Source code for ezmsg.event.refractory
import numpy as np
import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state
[docs]
class RefractorySettings(ez.Settings):
dur: float = 0.001
"""The minimum duration between events in seconds. If 0 (default), no refractory period is enforced."""
[docs]
@processor_state
class Refractory:
width: int = 0
elapsed: npt.NDArray | None = None
"""Track number of samples since last event for each feature."""
[docs]
class RefractoryTransformer(
BaseStatefulTransformer[RefractorySettings, AxisArray, AxisArray, Refractory]
):
def _hash_message(self, message: AxisArray) -> int:
return super()._hash_message(message)
def _reset_state(self, message: AxisArray) -> None:
fs = 1 / message.axes["time"].gain
self._state.width = int(self.settings.dur * fs)
ax_idx = message.get_axis_idx("time")
first_samp = slice_along_axis(message.data, slice(None, 1, None), ax_idx)
self._state.elapsed = np.zeros(first_samp.shape, dtype=int) + (
self._state.width + 1
)
def _process(self, message: AxisArray) -> AxisArray:
if self._state.width <= 2:
return message
# TODO: Get the sparse indices of the message.data
if len(samp_idx) <= 0:
return message
uq_feats, feat_splits = np.unique(cross_idx[0], return_index=True)
ieis = np.diff(np.hstack(([samp_idx[0] + 1], samp_idx)))
# Reset elapsed time at feature boundaries.
ieis[feat_splits] = samp_idx[feat_splits] + self._state.elapsed[uq_feats]
b_drop = ieis <= self._state.refrac_width
drop_idx = np.where(b_drop)[0]
final_drop = []
while len(drop_idx) > 0:
d_idx = drop_idx[0]
# Update next iei so its interval refers to the event before the to-be-dropped event.
# but only if the next iei belongs to the same feature.
if ((d_idx + 1) < len(ieis)) and (d_idx + 1) not in feat_splits:
ieis[d_idx + 1] += ieis[d_idx]
# We will later remove this event from samp_idx and cross_idx
final_drop.append(d_idx)
# Remove the dropped event from drop_idx.
drop_idx = drop_idx[1:]
# If the next event is now outside the refractory period then it will not be dropped.
if len(drop_idx) > 0 and ieis[drop_idx[0]] > self._state.refrac_width:
drop_idx = drop_idx[1:]
samp_idx = np.delete(samp_idx, final_drop)
cross_idx = tuple(np.delete(_, final_drop) for _ in cross_idx)