# Source code for ezmsg.event.binned
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, replace
class BinnedEventAggregatorSettings(ez.Settings):
    """Configuration for :class:`BinnedEventAggregator`."""

    bin_duration: float = 0.05
    """
    Duration of each bin in seconds.
    This is the time interval over which events will be counted.
    """

    scale_output: bool = True
    """
    If True, the output will be scaled by the bin duration.
    This is useful for converting counts to rates.
    """

    # Name of the axis along which events are aggregated into bins.
    axis: str = "time"
@processor_state
class BinnedEventAggregatorState:
    """Carry-over state between successive calls to ``BinnedEventAggregator._process``."""

    # Number of samples received past the last completed bin boundary
    # (i.e. how many samples of the next, not-yet-complete bin we hold).
    n_overflow: int = 0
    # Summed event counts for the partial (overflow) bin; shape is the input
    # shape with the target axis removed. None until first reset.
    counts_in_overflow: npt.NDArray[np.int64] | None = None
class BinnedEventAggregator(
    BaseStatefulTransformer[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState]
):
    """Aggregate event data into fixed-duration bins along one axis.

    Samples along ``settings.axis`` are summed into consecutive bins of
    ``settings.bin_duration`` seconds. Samples left over after the last
    complete bin are accumulated in state and counted toward the first bin
    of the next message.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Key state on the non-target dims so a change in channel layout resets it."""
        targ_ax_idx = message.get_axis_idx(self.settings.axis)
        non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :]
        return hash(tuple(non_targ_dims))

    def _reset_state(self, message: AxisArray) -> None:
        """Clear the carry-over: no partial bin, zeroed counts buffer.

        The buffer shape is the message's data shape with the target axis
        removed (one count per non-target element).
        """
        self._state.n_overflow = 0
        targ_axis_idx = message.get_axis_idx(self.settings.axis)
        buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
        self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64)

    def _process(self, message: AxisArray) -> AxisArray:
        """Sum events into complete bins; stash the trailing partial bin in state.

        Returns an AxisArray of per-bin counts (or rates, when
        ``settings.scale_output`` is True) whose target axis has its gain
        multiplied by the samples-per-bin and its offset backed up by any
        samples carried over from the previous message.
        """
        # Quick maths
        targ_ax_idx = message.get_axis_idx(self.settings.axis)
        targ_axis = message.axes[self.settings.axis]
        # gain is presumably seconds-per-sample -- TODO confirm. int() truncates,
        # so the realized bin span may be slightly shorter than bin_duration.
        samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain))
        # We will be slicing the data several times, so create a variable to hold the slices
        var_slice = [slice(None)] * message.data.ndim
        # Store for later use
        n_prev_overflow = self._state.n_overflow
        if self._state.n_overflow > 0:
            # Calculate how many samples from the input msg we can fit into the first bin,
            # given the current overflow state
            n_first = samples_per_bin - self._state.n_overflow
            # Sum the number of samples in the first bin then add to self._state.counts_in_overflow
            # NOTE(review): .todense() implies message.data is a sparse array
            # (e.g. pydata/sparse) -- confirm against upstream producers.
            var_slice[targ_ax_idx] = slice(0, n_first)
            first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
            first_bin_counts += self._state.counts_in_overflow
            # NOTE(review): if the message holds fewer than n_first samples, the
            # "first bin" below is emitted even though it is incomplete, and the
            # shortfall is not carried forward -- TODO confirm messages always
            # span at least one bin boundary.
        else:
            n_first = 0
            first_bin_counts = self._state.counts_in_overflow
            assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration."
        # Calculate how many samples remain after the first bin
        n_remaining = message.data.shape[targ_ax_idx] - n_first
        n_full_bins = int(n_remaining / samples_per_bin)
        # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins
        split_idx = n_first + n_full_bins * samples_per_bin
        var_slice[targ_ax_idx] = slice(n_first, split_idx)
        full_bins_data = message.data[tuple(var_slice)]
        # Reshape and sum for full bins: the target axis is split into
        # (n_full_bins, samples_per_bin) and the per-bin axis is summed out.
        new_shape = (
            full_bins_data.shape[:targ_ax_idx]
            + (n_full_bins, samples_per_bin)
            + full_bins_data.shape[targ_ax_idx + 1 :]
        )
        middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()
        # Prepare output
        if self._state.n_overflow > 0:
            # Re-insert a length-1 target axis so the completed first bin can be
            # concatenated ahead of the full bins.
            first_bin_counts = first_bin_counts.reshape(
                first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]
            )
            output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx)
        else:
            output_data = middle_bin_counts
        if self.settings.scale_output:
            # Convert counts to rates (events per second).
            # NOTE(review): divides by the nominal bin_duration rather than the
            # realized samples_per_bin * gain; these differ when bin_duration is
            # not an integer multiple of the sample period -- confirm intended.
            output_data = output_data / self.settings.bin_duration
        # Create the new output axis
        # For the target axis, backup the offset by the number of samples in the overflow
        out_axis = replace(
            targ_axis,
            gain=targ_axis.gain * samples_per_bin,
            offset=targ_axis.offset - n_prev_overflow * targ_axis.gain,
        )
        out_msg = replace(
            message,
            data=output_data,
            axes={k: v if k != self.settings.axis else out_axis for k, v in message.axes.items()},
        )
        # Calculate and store the overflow state.
        var_slice[targ_ax_idx] = slice(split_idx, None)
        overflow_data = message.data[tuple(var_slice)]
        self._state.n_overflow = overflow_data.shape[targ_ax_idx]
        self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense()
        return out_msg
class BinnedEventAggregatorUnit(
    BaseTransformerUnit[BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator]
):
    """ezmsg Unit exposing :class:`BinnedEventAggregator` in a pipeline."""

    SETTINGS = BinnedEventAggregatorSettings