"""
Compute binned kernel activation from sparse events.
This module provides efficient computation of kernel-convolved features
at a lower output rate than the input. For exponential and alpha kernels,
uses a state-based approach that is O(n_events + n_bins) instead of
O(n_samples).
"""
from enum import Enum
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace
class ActivationKernelType(str, Enum):
    """Kernel families supported by the efficient binned-activation transform."""
    EXPONENTIAL = "exponential"
    """Exponential decay kernel: k(t) = exp(-t/tau) for t >= 0."""
    ALPHA = "alpha"
    """Alpha-function kernel: k(t) = (t/tau) * exp(-t/tau) for t >= 0."""
    COUNT = "count"
    """No kernel at all -- simply count events falling in each bin."""
class BinAggregation(str, Enum):
    """Strategy for collapsing the activation signal within each output bin."""
    LAST = "last"
    """Activation value at the end of the bin (default for activation features)."""
    MEAN = "mean"
    """Average of the activation over the bin."""
    SUM = "sum"
    """Sum of the activation over the bin (for COUNT this is the total count)."""
    MAX = "max"
    """Maximum activation reached within the bin."""
class BinnedKernelActivationSettings(ez.Settings):
    """Settings for BinnedKernelActivation.

    Note: ``tau`` is interpreted per kernel type (decay constant for
    EXPONENTIAL, time-to-peak for ALPHA); the processing code does not use it
    for COUNT. ``rate_normalize`` divides output by ``bin_duration``.
    """
    kernel_type: ActivationKernelType = ActivationKernelType.EXPONENTIAL
    """Type of kernel to apply."""
    tau: float = 0.050
    """Time constant in seconds. For exponential: decay rate. For alpha: peak time."""
    bin_duration: float = 0.020
    """Output bin duration in seconds."""
    aggregation: BinAggregation = BinAggregation.LAST
    """How to aggregate activation within each bin."""
    scale_by_value: bool = False
    """If True, weight each event by its value. If False, all events contribute 1."""
    normalize: bool = True
    """If True, normalize kernel so integral equals 1."""
    rate_normalize: bool = False
    """If True, divide output by bin_duration to get events/second (for COUNT kernel)."""
@processor_state
class BinnedKernelActivationState:
    """State for BinnedKernelActivation.

    All sample indices are relative to the start of the current chunk; the
    transformer re-bases them at chunk boundaries.
    """
    # Current activation level per channel (for exponential/alpha kernels;
    # running event count for COUNT). Shape: (n_channels,).
    activation: npt.NDArray[np.float64] | None = None
    # For alpha kernel only: auxiliary sum of pure exponentials from which the
    # alpha response is propagated. None for other kernel types.
    alpha_aux: npt.NDArray[np.float64] | None = None
    # Per-channel sample index (chunk-relative) at which `activation` was last
    # updated; may go negative after re-basing across chunk boundaries.
    samples_since_update: npt.NDArray[np.int64] | None = None
    # Input sample rate in Hz (cached from the first message's time-axis gain).
    fs: float | None = None
    # Fractional/partial-bin samples carried over from the previous chunk so
    # bins stay aligned across chunk boundaries.
    bin_accumulator: float = 0.0
class BinnedKernelActivation(
    BaseStatefulTransformer[
        BinnedKernelActivationSettings,
        AxisArray,
        AxisArray,
        BinnedKernelActivationState,
    ]
):
    """
    Compute binned kernel activation from sparse events.

    For exponential and alpha kernels, uses an efficient state-based algorithm:

    - Exponential: activation[t] = sum_i exp(-(t - t_i) / tau)
    - Alpha: activation[t] = sum_i (t - t_i) / tau * exp(-(t - t_i) / tau)

    The algorithm only computes at event times and bin boundaries, giving
    O(n_events + n_bins) complexity instead of O(n_samples).

    Input: AxisArray with sparse.COO data (event times and values)
    Output: AxisArray with dense binned activation features

    Features:

    - Efficient for sparse events (much faster than dense convolution)
    - Handles chunk boundaries seamlessly
    - Supports exponential, alpha, and count kernels
    - Configurable bin aggregation (last, mean, sum, max)
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the stream properties whose change requires a state reset.

        Raises:
            ValueError: if the message lacks a "time" axis with a gain, i.e.
                the sample rate cannot be determined.
        """
        n_channels = message.data.shape[message.get_axis_idx("ch")] if "ch" in message.dims else 1
        if "time" not in message.axes or not hasattr(message.axes["time"], "gain"):
            raise ValueError("Could not determine sample rate from input message")
        return hash((message.data.ndim, message.data.dtype.kind, n_channels, message.axes["time"].gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize state for a new input stream."""
        n_channels = message.data.shape[message.get_axis_idx("ch")] if "ch" in message.dims else 1
        self._state.activation = np.zeros(n_channels, dtype=np.float64)
        self._state.samples_since_update = np.zeros(n_channels, dtype=np.int64)
        self._state.bin_accumulator = 0.0
        # The alpha kernel needs an auxiliary sum of pure exponentials
        # (see _decay_to_sample for the propagation rule).
        if self.settings.kernel_type == ActivationKernelType.ALPHA:
            self._state.alpha_aux = np.zeros(n_channels, dtype=np.float64)
        # Cache sample rate -- "time" is guaranteed present in axes because
        # _hash_message would have raised otherwise.
        time_axis = message.axes["time"]
        if time_axis.gain <= 0:
            # BUGFIX: previously fs was silently left as None for non-positive
            # gain, deferring the failure to a cryptic TypeError in _process.
            raise ValueError(
                f"Non-positive time-axis gain {time_axis.gain}; cannot derive sample rate"
            )
        self._state.fs = 1.0 / time_axis.gain

    def _decay_to_sample(self, channel: int, target_sample: int) -> None:
        """
        Propagate (decay) the kernel state of ``channel`` forward to
        ``target_sample`` (chunk-relative samples).

        No-op for non-positive time steps and for the COUNT kernel (counts do
        not decay).
        """
        dt = target_sample - self._state.samples_since_update[channel]
        if dt <= 0:
            return
        tau_samples = self.settings.tau * self._state.fs
        decay = np.exp(-dt / tau_samples)
        if self.settings.kernel_type == ActivationKernelType.EXPONENTIAL:
            self._state.activation[channel] *= decay
        elif self.settings.kernel_type == ActivationKernelType.ALPHA:
            # Alpha kernel propagation. With
            #   activation(t) = sum_i (t - t_i)/tau * exp(-(t - t_i)/tau)
            #   aux(t)        = sum_i            exp(-(t - t_i)/tau)
            # shifting every term by dt gives exactly
            #   aux(t+dt)        = aux(t) * decay
            #   activation(t+dt) = (activation(t) + aux(t) * dt/tau) * decay
            # BUGFIX: was `activation*decay + aux*(1-decay)`, which is the
            # step response of a first-order low-pass driven by a constant
            # aux, not the alpha kernel (and contradicted its own comment).
            aux = self._state.alpha_aux[channel]
            self._state.alpha_aux[channel] = aux * decay
            self._state.activation[channel] = (
                self._state.activation[channel] + aux * dt / tau_samples
            ) * decay
        self._state.samples_since_update[channel] = target_sample

    def _add_event(self, channel: int, sample: int, value: float) -> None:
        """Decay state to the event time, then add the event's contribution."""
        self._decay_to_sample(channel, sample)
        weight = value if self.settings.scale_by_value else 1.0
        if self.settings.normalize:
            # Normalize so the kernel's integral (in samples) equals 1.
            weight /= self.settings.tau * self._state.fs
        if self.settings.kernel_type == ActivationKernelType.EXPONENTIAL:
            self._state.activation[channel] += weight
        elif self.settings.kernel_type == ActivationKernelType.ALPHA:
            # Alpha events enter through the auxiliary exponential sum; the
            # activation itself rises on subsequent decay steps.
            self._state.alpha_aux[channel] += weight
        elif self.settings.kernel_type == ActivationKernelType.COUNT:
            self._state.activation[channel] += weight

    def _get_activation_at_sample(self, channel: int, sample: int) -> float:
        """Return the activation of ``channel`` decayed to ``sample``."""
        self._decay_to_sample(channel, sample)
        return self._state.activation[channel]

    @staticmethod
    def _extract_events(sparse_data) -> list:
        """Extract (sample, channel, value) triples from sparse.COO-like data,
        sorted by sample index. Returns [] for data without coords/values."""
        events = []
        if hasattr(sparse_data, "coords") and hasattr(sparse_data, "data"):
            coords = sparse_data.coords
            values = sparse_data.data
            for event_idx in range(len(values)):
                sample_idx = int(coords[0, event_idx])
                channel_idx = int(coords[1, event_idx]) if coords.shape[0] > 1 else 0
                events.append((sample_idx, channel_idx, float(values[event_idx])))
        events.sort(key=lambda e: e[0])
        return events

    def _process(self, message: AxisArray) -> AxisArray:
        """Compute binned activation from sparse events."""
        sparse_data = message.data
        n_samples = sparse_data.shape[0]
        n_channels = sparse_data.shape[1] if sparse_data.ndim > 1 else 1

        # Bin geometry, carrying fractional leftover samples across chunks.
        samples_per_bin = self.settings.bin_duration * self._state.fs
        total_samples = n_samples + self._state.bin_accumulator
        n_bins = int(total_samples / samples_per_bin)

        events = self._extract_events(sparse_data)

        if n_bins == 0:
            # Not enough samples for a full bin yet: fold this chunk into the
            # accumulator, update kernel state with its events, emit empty.
            self._state.bin_accumulator = total_samples
            for sample, channel, value in events:
                self._add_event(channel, sample, value)
            # BUGFIX: re-base the state clock so sample 0 of the next chunk is
            # the origin. This was previously skipped on the early-return
            # path, miscomputing decay across an empty-output chunk boundary.
            self._state.samples_since_update -= n_samples
            return replace(
                message,
                data=np.zeros((0, n_channels), dtype=np.float64),
                axes={
                    **message.axes,
                    "time": replace(message.axes["time"], gain=self.settings.bin_duration),
                },
            )

        # Bin boundaries in input samples relative to chunk start; the first
        # bin completes the partial bin carried over from the previous chunk.
        accumulator_before = self._state.bin_accumulator  # for output offset below
        first_bin_end = samples_per_bin - self._state.bin_accumulator
        bin_ends = first_bin_end + np.arange(n_bins) * samples_per_bin
        # Leftover samples after the last full bin roll into the next chunk.
        self._state.bin_accumulator = total_samples - n_bins * samples_per_bin

        output = np.zeros((n_bins, n_channels), dtype=np.float64)
        event_idx = 0
        if self.settings.aggregation == BinAggregation.LAST:
            # Sample the activation at each bin's end.
            for bin_idx, bin_end in enumerate(bin_ends):
                bin_end_sample = int(bin_end)
                while event_idx < len(events) and events[event_idx][0] < bin_end_sample:
                    sample, channel, value = events[event_idx]
                    self._add_event(channel, sample, value)
                    event_idx += 1
                for ch in range(n_channels):
                    output[bin_idx, ch] = self._get_activation_at_sample(ch, bin_end_sample)
        elif self.settings.aggregation == BinAggregation.SUM:
            # Accumulate event weights within each bin. For COUNT, the first
            # bin also includes counts carried over from the previous chunk's
            # partial bin (stored in state.activation).
            for bin_idx, bin_end in enumerate(bin_ends):
                bin_end_sample = int(bin_end)
                if bin_idx == 0 and self.settings.kernel_type == ActivationKernelType.COUNT:
                    bin_sum = self._state.activation.copy()
                    # Reset carried counts now that they have been emitted.
                    self._state.activation = np.zeros(n_channels, dtype=np.float64)
                else:
                    bin_sum = np.zeros(n_channels, dtype=np.float64)
                # NOTE(review): in-bin events are summed raw here, while the
                # trailing events below go through _add_event (which applies
                # `normalize`); confirm this asymmetry is intended.
                while event_idx < len(events) and events[event_idx][0] < bin_end_sample:
                    sample, channel, value = events[event_idx]
                    weight = value if self.settings.scale_by_value else 1.0
                    bin_sum[channel] += weight
                    event_idx += 1
                output[bin_idx] = bin_sum
        elif self.settings.aggregation == BinAggregation.MEAN:
            # Approximate the mean over each bin by sampling the activation at
            # the bin midpoint (exact integration would be possible for the
            # exponential kernel but is not done here).
            bin_start = 0
            for bin_idx, bin_end in enumerate(bin_ends):
                bin_end_sample = int(bin_end)
                while event_idx < len(events) and events[event_idx][0] < bin_end_sample:
                    sample, channel, value = events[event_idx]
                    self._add_event(channel, sample, value)
                    event_idx += 1
                midpoint = (bin_start + bin_end_sample) // 2
                for ch in range(n_channels):
                    output[bin_idx, ch] = self._get_activation_at_sample(ch, midpoint)
                bin_start = bin_end_sample
        elif self.settings.aggregation == BinAggregation.MAX:
            # BUGFIX: MAX previously fell through every branch and silently
            # returned all zeros. Between events the activation is
            # non-increasing for exponential/count kernels, so the bin max is
            # attained at the bin start, just after an in-bin event, or at the
            # bin end. For the alpha kernel, which can peak between events,
            # this is an approximation evaluated at the same candidate points.
            bin_start = 0
            for bin_idx, bin_end in enumerate(bin_ends):
                bin_end_sample = int(bin_end)
                bin_max = np.array(
                    [self._get_activation_at_sample(ch, bin_start) for ch in range(n_channels)],
                    dtype=np.float64,
                )
                while event_idx < len(events) and events[event_idx][0] < bin_end_sample:
                    sample, channel, value = events[event_idx]
                    self._add_event(channel, sample, value)
                    bin_max[channel] = max(bin_max[channel], self._state.activation[channel])
                    event_idx += 1
                for ch in range(n_channels):
                    bin_max[ch] = max(bin_max[ch], self._get_activation_at_sample(ch, bin_end_sample))
                output[bin_idx] = bin_max
                bin_start = bin_end_sample

        # Fold any trailing events (after the last full bin) into the state so
        # the next chunk sees them.
        while event_idx < len(events):
            sample, channel, value = events[event_idx]
            self._add_event(channel, sample, value)
            event_idx += 1

        # Re-base the state clock so sample 0 of the next chunk is the origin.
        self._state.samples_since_update -= n_samples

        # Optional conversion of per-bin totals to events/second.
        if self.settings.rate_normalize:
            output = output / self.settings.bin_duration

        # The first output bin starts accumulator_time seconds before this
        # chunk's first input sample.
        input_offset = message.axes["time"].offset if "time" in message.axes else 0.0
        accumulator_time = accumulator_before / self._state.fs
        output_offset = input_offset - accumulator_time
        return replace(
            message,
            data=output,
            axes={
                **message.axes,
                "time": AxisArray.TimeAxis(
                    fs=1.0 / self.settings.bin_duration,
                    offset=output_offset,
                ),
            },
        )
class BinnedKernelActivationUnit(
    BaseTransformerUnit[
        BinnedKernelActivationSettings,
        AxisArray,
        AxisArray,
        BinnedKernelActivation,
    ]
):
    """ezmsg Unit wrapping the :class:`BinnedKernelActivation` transformer.

    Exposes the transformer in an ezmsg graph; configure via
    :class:`BinnedKernelActivationSettings`.
    """
    # Settings class consumed by the BaseTransformerUnit machinery.
    SETTINGS = BinnedKernelActivationSettings