"""
Compute binned kernel activation from events.
This module provides efficient computation of kernel-convolved features
at a lower output rate than the input. For exponential and alpha kernels,
uses a state-based approach that is O(n_events + n_bins) instead of
O(n_samples).
Input may be either ``sparse.COO`` (the typical output of
:class:`ezmsg.event.peak.ThresholdCrossingTransformer` in default mode) or a
dense array (from the same transformer with ``output_format=DENSE``). When the
input is dense and the configuration is COUNT + SUM (the rate-computation
case), the binning runs on the input's array namespace and stays on device
(e.g., MLX, CuPy). Other configurations with dense input fall back to
event-extraction and use the same code path as sparse input.
"""
from enum import Enum
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import sparse
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace
[docs]
class ActivationKernelType(str, Enum):
    """Kernel shapes that the binned-activation transformer can apply to events."""

    EXPONENTIAL = "exponential"
    """Causal exponential decay, ``k(t) = exp(-t / tau)`` for ``t >= 0``."""

    ALPHA = "alpha"
    """Causal alpha function, ``k(t) = (t / tau) * exp(-t / tau)`` for ``t >= 0``."""

    COUNT = "count"
    """No kernel at all: events are simply counted per bin."""
[docs]
class BinAggregation(str, Enum):
    """Strategies for reducing the activation inside one output bin to a single value."""

    LAST = "last"
    """Report the activation at the end of the bin (default for activation features)."""

    MEAN = "mean"
    """Report the average activation across the bin."""

    SUM = "sum"
    """Report the sum over the bin (with the count kernel this is the total event count)."""

    MAX = "max"
    """Report the largest activation reached in the bin."""
[docs]
class BinnedKernelActivationSettings(ez.Settings):
    """Configuration for :class:`BinnedKernelActivation`."""

    kernel_type: ActivationKernelType = ActivationKernelType.EXPONENTIAL
    """Which kernel each event is convolved with."""

    tau: float = 0.05
    """Kernel time constant, in seconds (exponential: decay rate; alpha: peak time)."""

    bin_duration: float = 0.02
    """Length of one output bin, in seconds."""

    aggregation: BinAggregation = BinAggregation.LAST
    """Reduction used to turn the activation inside a bin into one value."""

    scale_by_value: bool = False
    """When True an event contributes its own value; when False every event contributes 1."""

    normalize: bool = True
    """When True the kernel is scaled so that its integral equals 1."""

    rate_normalize: bool = False
    """When True the output is divided by ``bin_duration`` (events/second for the COUNT kernel)."""
[docs]
@processor_state
class BinnedKernelActivationState:
    """Mutable per-stream state carried between chunks by BinnedKernelActivation."""

    # Per-channel activation level. For the COUNT kernel this doubles as the
    # carry-over count of the partial bin that straddles a chunk boundary.
    activation: npt.NDArray[np.float64] | None = None

    # Per-channel auxiliary variable; only allocated when the alpha kernel is selected.
    alpha_aux: npt.NDArray[np.float64] | None = None

    # Per-channel sample index at which the state was last brought up to date.
    # It is expressed relative to the start of the chunk being processed and is
    # shifted by the chunk length after each chunk.
    samples_since_update: npt.NDArray[np.int64] | None = None

    # Input sample rate in Hz, derived from the time axis gain when the state is reset.
    fs: float | None = None

    # Number of input samples (possibly fractional) already consumed by the
    # bin that is still incomplete at the end of the previous chunk.
    bin_accumulator: float = 0.0
[docs]
class BinnedKernelActivation(
    BaseStatefulTransformer[
        BinnedKernelActivationSettings,
        AxisArray,
        AxisArray,
        BinnedKernelActivationState,
    ]
):
    """
    Compute binned kernel activation from sparse or dense event input.

    For exponential and alpha kernels, uses an efficient state-based algorithm:

    - Exponential: activation[t] = sum_i exp(-(t - t_i) / tau)
    - Alpha: activation[t] = sum_i (t - t_i) / tau * exp(-(t - t_i) / tau)

    The state is only advanced at event times and at the sample points needed
    for the bin outputs, giving O(n_events + n_bins) work instead of O(n_samples).

    Input: AxisArray whose data is ``sparse.COO`` (event times and values) or a
    dense array where non-zero entries are events; time is the first axis.
    Output: AxisArray with dense binned activation features.

    Features:

    - Efficient for sparse events (much faster than dense convolution)
    - Carries state and partial bins across chunk boundaries
    - Supports exponential, alpha, and count kernels
    - Configurable bin aggregation (last, mean, sum, max)

    Notes:

    - The COUNT kernel has no kernel shape, so ``normalize`` does not apply to it;
      counts are always raw (or value-weighted) event totals.
    - MEAN is approximated by the activation at the bin midpoint.
    - MAX is the maximum over the activation sampled at the bin start, right
      after every event in the bin, and at the bin end. This is exact for the
      exponential and count kernels with non-negative weights; for the alpha
      kernel (which peaks between events) it is a lower bound.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Return a hash of the stream layout; a change triggers a state reset.

        The hash covers the data rank, dtype, channel count, and time-axis gain.

        Raises:
            ValueError: If the message has no ``time`` axis exposing a ``gain``.
        """
        n_channels = message.data.shape[message.get_axis_idx("ch")] if "ch" in message.dims else 1
        if "time" not in message.axes or not hasattr(message.axes["time"], "gain"):
            raise ValueError("Could not determine sample rate from input message")
        # str(dtype) works for numpy ('bool', 'float32', ...) and mlx (which doesn't expose dtype.kind).
        return hash((message.data.ndim, str(message.data.dtype), n_channels, message.axes["time"].gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize state for a new input stream.

        Raises:
            ValueError: If the time-axis gain is not strictly positive, because
                no sample rate can be derived from it.
        """
        n_channels = message.data.shape[message.get_axis_idx("ch")] if "ch" in message.dims else 1
        self._state.activation = np.zeros(n_channels, dtype=np.float64)
        self._state.samples_since_update = np.zeros(n_channels, dtype=np.int64)
        self._state.bin_accumulator = 0.0

        # The alpha kernel needs a second state variable per channel.
        if self.settings.kernel_type == ActivationKernelType.ALPHA:
            self._state.alpha_aux = np.zeros(n_channels, dtype=np.float64)

        # "time" is guaranteed to be present: _hash_message raises otherwise.
        gain = message.axes["time"].gain
        if not gain > 0:
            # Fail loudly instead of silently keeping a missing/stale sample rate.
            raise ValueError(f"Time axis gain must be positive to derive a sample rate, got {gain!r}")
        self._state.fs = 1.0 / gain

    def _decay_to_sample(self, channel: int, target_sample: int) -> None:
        """Advance one channel's state to ``target_sample`` (chunk-relative index).

        Does nothing when ``target_sample`` is not later than the channel's last
        update. The COUNT kernel has no decay; only its timestamp is advanced.
        """
        dt = target_sample - self._state.samples_since_update[channel]
        if dt <= 0:
            return

        kernel = self.settings.kernel_type
        if kernel != ActivationKernelType.COUNT:
            tau_samples = self.settings.tau * self._state.fs
            decay = np.exp(-dt / tau_samples)
            if kernel == ActivationKernelType.EXPONENTIAL:
                self._state.activation[channel] *= decay
            elif kernel == ActivationKernelType.ALPHA:
                # With aux = sum_i w_i * exp(-(t - t_i) / tau) and
                # activation = sum_i w_i * (t - t_i) / tau * exp(-(t - t_i) / tau),
                # advancing by dt gives exactly:
                #   activation(t + dt) = (activation(t) + aux(t) * dt / tau) * decay
                #   aux(t + dt)        = aux(t) * decay
                # This update is independent of how dt is split into steps.
                aux = self._state.alpha_aux[channel]
                self._state.activation[channel] = (
                    self._state.activation[channel] + aux * dt / tau_samples
                ) * decay
                self._state.alpha_aux[channel] = aux * decay

        self._state.samples_since_update[channel] = target_sample

    def _add_event(self, channel: int, sample: int, value: float) -> None:
        """Advance ``channel`` to ``sample`` and add one event's contribution.

        The weight is ``value`` when ``scale_by_value`` is set, else 1. For the
        exponential and alpha kernels it is divided by ``tau * fs`` when
        ``normalize`` is set. COUNT weights are never normalized so that the
        partial-bin carry-over matches the raw in-bin counts.
        """
        self._decay_to_sample(channel, sample)

        kernel = self.settings.kernel_type
        weight = value if self.settings.scale_by_value else 1.0
        if self.settings.normalize and kernel != ActivationKernelType.COUNT:
            # Both kernels integrate to tau (in samples), so this gives unit area.
            weight /= self.settings.tau * self._state.fs

        if kernel == ActivationKernelType.ALPHA:
            # k(0) == 0 for the alpha kernel: the event only feeds the auxiliary state.
            self._state.alpha_aux[channel] += weight
        elif kernel in (ActivationKernelType.EXPONENTIAL, ActivationKernelType.COUNT):
            self._state.activation[channel] += weight

    def _get_activation_at_sample(self, channel: int, sample: int) -> float:
        """Advance ``channel`` to ``sample`` and return its activation."""
        self._decay_to_sample(channel, sample)
        return self._state.activation[channel]

    @staticmethod
    def _extract_events(sparse_data) -> list[tuple[int, int, float]]:
        """Return ``(sample, channel, value)`` tuples sorted by sample index.

        Arrays without ``coords``/``data`` attributes yield no events. For 1-D
        input every event is assigned to channel 0.
        """
        if not (hasattr(sparse_data, "coords") and hasattr(sparse_data, "data")):
            return []
        coords = sparse_data.coords
        values = sparse_data.data
        channels = coords[1] if coords.shape[0] > 1 else [0] * len(values)
        events = [(int(s), int(c), float(v)) for s, c, v in zip(coords[0], channels, values)]
        # Stable sort: events sharing a sample keep their original relative order.
        events.sort(key=lambda ev: ev[0])
        return events

    def _apply_events(
        self,
        events: list[tuple[int, int, float]],
        start: int,
        end_sample: int | None,
    ) -> int:
        """Feed ``events[start:]`` with sample < ``end_sample`` into the state.

        ``end_sample=None`` consumes every remaining event. Returns the index of
        the first event that was not consumed.
        """
        idx = start
        n_events = len(events)
        while idx < n_events and (end_sample is None or events[idx][0] < end_sample):
            sample, channel, value = events[idx]
            self._add_event(channel, sample, value)
            idx += 1
        return idx

    def _process(self, message: AxisArray) -> AxisArray:
        """Compute binned activation from sparse or dense event input.

        Dispatch:

        - Dense input + COUNT + SUM: fast path that stays on the input's array
          namespace (e.g., MLX, CuPy on device).
        - Dense input + any other config: non-zero entries are converted to
          sparse events and handled by the event-based path.
        - Sparse input: event-based path.
        """
        data = message.data
        if not isinstance(data, sparse.SparseArray):
            if (
                self.settings.kernel_type == ActivationKernelType.COUNT
                and self.settings.aggregation == BinAggregation.SUM
            ):
                return self._process_dense_count_sum(message)
            # Fall back: convert dense to sparse so the event-based path can run.
            data_np = data if is_numpy_array(data) else np.asarray(data)
            message = replace(message, data=sparse.COO.from_numpy(data_np))
        return self._process_events(message)

    def _process_events(self, message: AxisArray) -> AxisArray:
        """Compute binned activation from sparse events.

        Returns an AxisArray of shape ``(n_bins, n_channels)``. ``n_bins`` is 0
        when the chunk does not complete a bin; the events are still absorbed
        into the state in that case.
        """
        sparse_data = message.data
        n_samples = sparse_data.shape[0]
        n_channels = sparse_data.shape[1] if sparse_data.ndim > 1 else 1

        # Bin geometry, in input samples. The accumulator holds how much of the
        # current (incomplete) bin was already covered by previous chunks.
        samples_per_bin = self.settings.bin_duration * self._state.fs
        accumulator_before = self._state.bin_accumulator
        total_samples = n_samples + accumulator_before
        n_bins = int(total_samples / samples_per_bin)

        events = self._extract_events(sparse_data)

        if n_bins == 0:
            # Not enough samples for a full bin yet: absorb the events and wait.
            self._state.bin_accumulator = total_samples
            self._apply_events(events, 0, None)
            # Re-express "last update" indices relative to the next chunk's start,
            # exactly as the full-bin path does; otherwise later decay is wrong.
            self._state.samples_since_update -= n_samples
            return replace(
                message,
                data=np.zeros((0, n_channels), dtype=np.float64),
                axes={
                    **message.axes,
                    "time": replace(message.axes["time"], gain=self.settings.bin_duration),
                },
            )

        # Bin end positions relative to the chunk start, truncated to whole samples.
        first_bin_end = samples_per_bin - accumulator_before
        bin_ends = first_bin_end + np.arange(n_bins) * samples_per_bin
        bin_end_samples = [int(b) for b in bin_ends]
        self._state.bin_accumulator = total_samples - n_bins * samples_per_bin

        output = np.zeros((n_bins, n_channels), dtype=np.float64)
        aggregation = self.settings.aggregation
        n_events = len(events)
        event_idx = 0

        if aggregation == BinAggregation.LAST:
            for bin_idx, bin_end_sample in enumerate(bin_end_samples):
                event_idx = self._apply_events(events, event_idx, bin_end_sample)
                for ch in range(n_channels):
                    output[bin_idx, ch] = self._get_activation_at_sample(ch, bin_end_sample)

        elif aggregation == BinAggregation.SUM:
            # SUM adds the (optionally value-weighted) events that fall in each bin.
            count_kernel = self.settings.kernel_type == ActivationKernelType.COUNT
            use_value = self.settings.scale_by_value
            for bin_idx, bin_end_sample in enumerate(bin_end_samples):
                if bin_idx == 0 and count_kernel:
                    # Counts from the previous chunk's partial bin belong to bin 0.
                    bin_sum = self._state.activation.copy()
                    self._state.activation = np.zeros(n_channels, dtype=np.float64)
                else:
                    bin_sum = np.zeros(n_channels, dtype=np.float64)
                while event_idx < n_events and events[event_idx][0] < bin_end_sample:
                    _, channel, value = events[event_idx]
                    bin_sum[channel] += value if use_value else 1.0
                    event_idx += 1
                output[bin_idx] = bin_sum

        elif aggregation == BinAggregation.MEAN:
            # Midpoint approximation of the bin average. Only events before the
            # midpoint are applied before sampling, so every channel is
            # evaluated at the same instant; later events are picked up by the
            # next iteration or by the final drain below.
            bin_start_sample = 0
            for bin_idx, bin_end_sample in enumerate(bin_end_samples):
                midpoint = (bin_start_sample + bin_end_sample) // 2
                event_idx = self._apply_events(events, event_idx, midpoint)
                for ch in range(n_channels):
                    output[bin_idx, ch] = self._get_activation_at_sample(ch, midpoint)
                bin_start_sample = bin_end_sample

        elif aggregation == BinAggregation.MAX:
            # Sampled maximum: bin start, right after each event, and bin end.
            # NOTE: for the first bin only the part inside this chunk is seen;
            # a peak reached in the previous chunk's partial bin is not tracked.
            bin_start_sample = 0
            for bin_idx, bin_end_sample in enumerate(bin_end_samples):
                bin_max = np.array(
                    [self._get_activation_at_sample(ch, bin_start_sample) for ch in range(n_channels)],
                    dtype=np.float64,
                )
                while event_idx < n_events and events[event_idx][0] < bin_end_sample:
                    sample, channel, value = events[event_idx]
                    self._add_event(channel, sample, value)
                    bin_max[channel] = max(bin_max[channel], self._state.activation[channel])
                    event_idx += 1
                for ch in range(n_channels):
                    bin_max[ch] = max(bin_max[ch], self._get_activation_at_sample(ch, bin_end_sample))
                output[bin_idx] = bin_max
                bin_start_sample = bin_end_sample

        # Events after the last complete bin still have to enter the state.
        self._apply_events(events, event_idx, None)

        # Re-express "last update" indices relative to the next chunk's start.
        self._state.samples_since_update -= n_samples

        if self.settings.rate_normalize:
            output = output / self.settings.bin_duration

        # The first output bin began accumulator_before samples before this chunk.
        input_offset = message.axes["time"].offset if "time" in message.axes else 0.0
        output_offset = input_offset - accumulator_before / self._state.fs

        return replace(
            message,
            data=output,
            axes={
                **message.axes,
                "time": AxisArray.TimeAxis(
                    fs=1.0 / self.settings.bin_duration,
                    offset=output_offset,
                ),
            },
        )

    def _process_dense_count_sum(self, message: AxisArray) -> AxisArray:
        """Fast path: dense input + COUNT kernel + SUM aggregation.

        Bins are summed using cumulative-sum arithmetic in the input's array
        namespace, so accelerator-resident inputs (MLX, CuPy) stay on device.
        Carry-over for the partial bin spanning chunk boundaries is held in
        ``state.activation`` (numpy) and shuttled across boundaries.
        """
        data = message.data
        xp = get_namespace(data)
        n_samples = data.shape[0]
        feature_shape = tuple(data.shape[1:])

        samples_per_bin = self.settings.bin_duration * self._state.fs
        accumulator_before = self._state.bin_accumulator
        total_samples = n_samples + accumulator_before
        n_bins = int(total_samples / samples_per_bin)

        # Per-sample contribution: 1 per non-zero, or the value itself if scaling.
        # The .astype() method form works for both numpy and mlx
        # (mlx.core has no top-level astype).
        if n_samples == 0:
            contrib = xp.zeros((0,) + feature_shape, dtype=xp.float32)
        elif self.settings.scale_by_value:
            contrib = data.astype(xp.float32)
        else:
            contrib = (data != 0).astype(xp.float32)

        # Pull the carry-over into the input namespace for on-device math.
        overflow_xp = xp.asarray(self._state.activation.reshape(feature_shape)).astype(xp.float32)

        if n_bins == 0:
            # No complete bins this chunk: everything goes into the carry-over.
            if n_samples > 0:
                new_overflow = overflow_xp + xp.sum(contrib, axis=0)
            else:
                new_overflow = overflow_xp
            self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape)
            self._state.bin_accumulator = total_samples
            return replace(
                message,
                data=xp.zeros((0,) + feature_shape, dtype=xp.float32),
                axes={
                    **message.axes,
                    "time": replace(message.axes["time"], gain=self.settings.bin_duration),
                },
            )

        # Bin boundaries (in input-sample space, integer-truncated as in the event-based path).
        first_bin_end = samples_per_bin - accumulator_before
        bin_ends_float = first_bin_end + np.arange(n_bins) * samples_per_bin
        bin_end_samples = bin_ends_float.astype(np.int64)
        bin_start_samples = np.concatenate(([np.int64(0)], bin_end_samples[:-1]))

        # Cumulative sum, prepended with zeros so cumsum_padded[k] = sum(contrib[:k]).
        # cumsum exists in both numpy and mlx; the array-API `cumulative_sum` does not exist in mlx.
        cumsum = xp.cumsum(contrib, axis=0)
        zero_row = xp.zeros((1,) + feature_shape, dtype=cumsum.dtype)
        cumsum_padded = xp.concat((zero_row, cumsum), axis=0)

        end_idx = xp.asarray(bin_end_samples)
        start_idx = xp.asarray(bin_start_samples)
        bin_sums = xp.take(cumsum_padded, end_idx, axis=0) - xp.take(cumsum_padded, start_idx, axis=0)

        # Add carry-over from the previous chunk's partial bin into bin 0 only.
        overflow_pad = overflow_xp[None, ...]
        if n_bins > 1:
            overflow_pad_rest = xp.zeros((n_bins - 1,) + feature_shape, dtype=bin_sums.dtype)
            overflow_pad = xp.concat((overflow_pad, overflow_pad_rest), axis=0)
        output = bin_sums + overflow_pad

        # New carry-over: events past the last complete bin remain in the partial bin.
        last_bin_end = int(bin_end_samples[-1])
        if last_bin_end < n_samples:
            new_overflow = xp.sum(contrib[last_bin_end:], axis=0)
        else:
            new_overflow = xp.zeros(feature_shape, dtype=cumsum.dtype)
        self._state.activation = np.asarray(new_overflow).reshape(self._state.activation.shape)
        self._state.bin_accumulator = total_samples - n_bins * samples_per_bin

        if self.settings.rate_normalize:
            output = output / self.settings.bin_duration

        # The first output bin began accumulator_before samples before this chunk.
        input_offset = message.axes["time"].offset if "time" in message.axes else 0.0
        output_offset = input_offset - accumulator_before / self._state.fs

        return replace(
            message,
            data=output,
            axes={
                **message.axes,
                "time": AxisArray.TimeAxis(
                    fs=1.0 / self.settings.bin_duration,
                    offset=output_offset,
                ),
            },
        )
[docs]
class BinnedKernelActivationUnit(
    BaseTransformerUnit[
        BinnedKernelActivationSettings,
        AxisArray,
        AxisArray,
        BinnedKernelActivation,
    ]
):
    """ezmsg unit that wraps :class:`BinnedKernelActivation` as its processor."""

    SETTINGS = BinnedKernelActivationSettings