Source code for ezmsg.event.refractory

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import sparse
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace


[docs] class RefractorySettings(ez.Settings): dur: float = 0.001 """The minimum duration between events in seconds. If 0 (default), no refractory period is enforced."""
[docs] @processor_state class Refractory: width: int = 0 elapsed: npt.NDArray | None = None """Track number of samples since last event for each feature."""
[docs] class RefractoryTransformer(BaseStatefulTransformer[RefractorySettings, AxisArray, AxisArray, Refractory]): def _hash_message(self, message: AxisArray) -> int: return super()._hash_message(message) def _reset_state(self, message: AxisArray) -> None: fs = 1 / message.axes["time"].gain self._state.width = int(self.settings.dur * fs) ax_idx = message.get_axis_idx("time") # Get the shape of features (all dims except time) feat_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] n_feats = int(np.prod(feat_shape)) self._state.elapsed = np.zeros((n_feats,), dtype=int) + (self._state.width + 1) def _process(self, message: AxisArray) -> AxisArray: if self._state.width <= 2: return message ax_idx = message.get_axis_idx("time") n_samps = message.data.shape[ax_idx] # Get the sparse indices of the message.data # coords is a tuple of arrays, one per dimension coords = message.data.coords if coords.shape[1] == 0: # No events, update elapsed and return self._state.elapsed += n_samps return message # Separate time indices from feature indices samp_idx = coords[ax_idx] feat_dims = list(range(message.data.ndim)) feat_dims.pop(ax_idx) feat_coords = tuple(coords[d] for d in feat_dims) # Ravel feature indices to 1D for tracking feat_shape = tuple(message.data.shape[d] for d in feat_dims) if len(feat_coords) > 0: ravel_feat_inds = np.ravel_multi_index(feat_coords, feat_shape) else: ravel_feat_inds = np.zeros(len(samp_idx), dtype=int) # Sort by feature then by time to process events in order sort_order = np.lexsort((samp_idx, ravel_feat_inds)) samp_idx = samp_idx[sort_order] ravel_feat_inds = ravel_feat_inds[sort_order] feat_coords = tuple(fc[sort_order] for fc in feat_coords) # Create cross_idx as list with feature coords first, then time cross_idx = list(feat_coords) + [samp_idx] uq_feats, feat_splits = np.unique(ravel_feat_inds, return_index=True) ieis = np.diff(np.hstack(([samp_idx[0] + 1], samp_idx))) # Reset elapsed time at feature boundaries. ieis[feat_splits] = samp_idx[feat_splits] + self._state.elapsed[uq_feats] b_drop = ieis <= self._state.width drop_idx = np.where(b_drop)[0] final_drop = [] while len(drop_idx) > 0: d_idx = drop_idx[0] # Update next iei so its interval refers to the event before the to-be-dropped event. # but only if the next iei belongs to the same feature. if ((d_idx + 1) < len(ieis)) and (d_idx + 1) not in feat_splits: ieis[d_idx + 1] += ieis[d_idx] # We will later remove this event from samp_idx and cross_idx final_drop.append(d_idx) # Remove the dropped event from drop_idx. drop_idx = drop_idx[1:] # If the next event is now outside the refractory period then it will not be dropped. if len(drop_idx) > 0 and ieis[drop_idx[0]] > self._state.width: drop_idx = drop_idx[1:] samp_idx = np.delete(samp_idx, final_drop) cross_idx = [np.delete(_, final_drop) for _ in cross_idx] ravel_feat_inds = np.delete(ravel_feat_inds, final_drop) # Update elapsed state for all features self._state.elapsed += n_samps # For features that had events, set elapsed to time since last event if len(samp_idx) > 0: # Get the last event time for each feature that had events uq_final_feats, last_idx = np.unique(ravel_feat_inds[::-1], return_index=True) last_idx = len(ravel_feat_inds) - 1 - last_idx last_samps = samp_idx[last_idx] self._state.elapsed[uq_final_feats] = n_samps - last_samps # Build output coordinates in original dimension order out_coords = [None] * message.data.ndim for i, d in enumerate(feat_dims): out_coords[d] = cross_idx[i] out_coords[ax_idx] = cross_idx[-1] # Get the values for kept events kept_mask = np.ones(coords.shape[1], dtype=bool) kept_mask[sort_order[final_drop]] = False result_data = message.data.data[kept_mask] result = sparse.COO( out_coords, data=result_data, shape=message.data.shape, ) return replace(message, data=result)