Source code for ezmsg.event.peak
"""
Detects peaks in a signal.
.. note::
This module supports the `Array API standard <https://data-apis.org/array-api/>`_,
enabling use with NumPy, CuPy, PyTorch, and other compatible array libraries.
Signal data operations are array-API compliant. Event detection itself uses
NumPy regardless of input backend. The output container is configurable via
:class:`OutputFormat`: ``SPARSE`` produces a ``sparse.COO`` (default), while
``DENSE`` produces a dense array in the input's namespace so downstream
consumers can keep data on accelerators.
"""
import math
from enum import Enum
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import sparse
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.sigproc.scaler import AdaptiveStandardScalerTransformer
from ezmsg.util.messages.axisarray import AxisArray, replace # slice_along_axis,
[docs]
class OutputFormat(str, Enum):
    """Selects which container :class:`ThresholdCrossingTransformer` emits.

    Because the enum mixes in ``str``, members compare equal to their plain
    string values (``OutputFormat.DENSE == "dense"``).
    """

    #: Emit a ``sparse.COO`` array holding one entry per accepted event (the default).
    SPARSE = "sparse"

    #: Emit a dense array in the input's own namespace, non-zero at event positions.
    #: Prefer this when downstream nodes are namespace-aware (for example
    #: :class:`ezmsg.event.kernel_activation.BinnedKernelActivation`) and the data
    #: should remain on its current device (e.g. MLX, CuPy).
    DENSE = "dense"
[docs]
class ThresholdSettings(ez.Settings):
    """Settings for :class:`ThresholdCrossingTransformer` / :class:`ThresholdCrossing`.

    Durations are given in seconds and converted to whole samples (truncated) using
    the input's sample rate when the transformer state is reset.
    """

    threshold: float = -3.5
    """The value the signal must cross for an event to be detected.
    If ``threshold >= 0``, samples ``>= threshold`` count as over threshold;
    if negative, samples ``<= threshold`` count as over threshold."""

    max_peak_dur: float = 0.002
    """The maximum duration of a peak in seconds. When ``min_peak_dur``, ``align_on_peak``
    or ``return_peak_val`` is set, this is the window searched for the signal's return
    across threshold and for the peak."""

    min_peak_dur: float = 0.0
    """The minimum duration of a peak in seconds. If 0 (default), no minimum duration is enforced."""

    refrac_dur: float = 0.001
    """The minimum duration between peaks on the same feature, in seconds (default 0.001).
    If 0, no refractory period is enforced. On the cpu path, enforcement only activates
    when the refractory period spans more than 2 samples."""

    align_on_peak: bool = False
    """If False (default), the returned sample index indicates the first sample across threshold.
    If True, the sample index indicates the sample with the largest deviation after threshold crossing."""

    return_peak_val: bool = False
    """If True, the output's non-zero entries are the (unscaled) peak values instead of ``True``."""

    auto_scale_tau: float = 0.0
    """If > 0, the data is passed through an adaptive standard scaler (with this value as its
    ``time_constant``) prior to thresholding."""

    output_format: OutputFormat = OutputFormat.SPARSE
    """Output container. ``SPARSE`` (default) emits ``sparse.COO``. ``DENSE`` emits a
    dense array in the input's namespace so accelerator-resident data stays on device.
    When ``DENSE`` is combined with an MLX input and a basic configuration (no
    ``align_on_peak``, ``return_peak_val``, ``min_peak_dur``, or ``auto_scale_tau``),
    threshold detection + refractory enforcement run via an on-device Metal kernel
    automatically — there is no separate opt-in toggle."""
[docs]
@processor_state
class ThresholdCrossingState:
    """State for ThresholdCrossingTransformer."""

    max_width: int = 0
    """``max_peak_dur`` converted to samples."""

    min_width: int = 1
    """``min_peak_dur`` converted to samples."""

    refrac_width: int = 0
    """``refrac_dur`` converted to samples."""

    scaler: AdaptiveStandardScalerTransformer | None = None
    """Object performing adaptive z-scoring; only created when ``auto_scale_tau > 0``."""

    data: npt.NDArray | None = None
    """Time-first trailing buffer of the (possibly z-scored) signal. It always holds at least
    the previous chunk's last sample, which is the reference for detecting new crossings.
    It holds more when events are held back to be resolved together with the next chunk."""

    data_raw: npt.NDArray | None = None
    """Raw (unscaled) counterpart of ``data``, kept aligned with it so that peak values can be
    reported in original units. Only needed when auto-scaling is combined with ``return_peak_val``."""

    elapsed: object | None = None
    """Int32 array, samples since last accepted crossing per feature.
    Lives in the input's namespace: numpy (flat ``(prod(features),)``) for the cpu
    event-loop path, MLX (shape ``(*features,)``) for the on-device Metal path."""
[docs]
class ThresholdCrossingTransformer(
    BaseStatefulTransformer[ThresholdSettings, AxisArray, AxisArray, ThresholdCrossingState]
):
    """Transformer that detects threshold crossing events.

    The ``"time"`` axis of the input is moved to the front. The output data marks accepted
    events, either as ``sparse.COO`` or as a dense array (see :class:`OutputFormat`).
    Events that cannot be resolved within the current chunk are held back and emitted
    with the next chunk. Each output may therefore differ in length from its input and
    carry a time offset shifted earlier.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties whose change requires a state reset."""
        ax_idx = message.get_axis_idx("time")
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((message.key, sample_shape, message.axes["time"].gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Reset the state variables from the first message of a (re)configured stream."""
        xp = get_namespace(message.data)
        ax_idx = message.get_axis_idx("time")

        # Precalculate some simple math we'd otherwise have to calculate on every iteration.
        fs = 1 / message.axes["time"].gain
        self._state.max_width = int(self.settings.max_peak_dur * fs)
        self._state.min_width = int(self.settings.min_peak_dur * fs)
        self._state.refrac_width = int(self.settings.refrac_dur * fs)

        # We'll need the first sample (keep time dim!) for a few of our state initializations.
        perm = (ax_idx,) + tuple(i for i in range(message.data.ndim) if i != ax_idx)
        data = xp.permute_dims(message.data, perm)
        first_samp = data[:1]

        # Prepare optional state variables.
        self._state.scaler = None
        self._state.data_raw = None
        if self.settings.auto_scale_tau > 0:
            self._state.scaler = AdaptiveStandardScalerTransformer(
                time_constant=self.settings.auto_scale_tau, axis="time"
            )
            if self.settings.return_peak_val:
                # The scaler discards the raw values, so keep a parallel raw buffer.
                self._state.data_raw = first_samp

        # We always need at least the previous iteration's last sample for tracking whether we are
        # newly over threshold, and potentially for aligning on peak or returning the peak value.
        self._state.data = first_samp if self._state.scaler is None else xp.zeros_like(first_samp)

        # Samples since last accepted crossing, per feature. Initialise at refrac_width+1 so
        # the first sample is eligible. Live in the namespace the active path will consume:
        # MLX feature-shaped int32 if Metal will run, numpy flat int32 otherwise.
        feature_shape = tuple(data.shape[1:])
        if self._can_use_mlx_metal(xp, message.data):
            import mlx.core as mx

            self._state.elapsed = mx.full(feature_shape, self._state.refrac_width + 1, dtype=mx.int32)
        else:
            n_features = math.prod(feature_shape)  # prod(()) == 1 for time-only input
            self._state.elapsed = np.full((n_features,), self._state.refrac_width + 1, dtype=np.int32)

    def _can_use_mlx_metal(self, xp, data) -> bool:
        """The Metal kernel runs automatically for MLX + DENSE + a config it supports.

        It cannot recover peak values, align on peaks, enforce a min peak width, or
        auto-scale, so any of those settings disable the path.
        """
        if self.settings.output_format != OutputFormat.DENSE:
            return False
        if self.settings.align_on_peak or self.settings.return_peak_val:
            return False
        if self.settings.min_peak_dur > 0.0 or self.settings.auto_scale_tau > 0.0:
            return False
        if is_numpy_array(data):
            return False
        return getattr(xp, "__name__", "") == "mlx.core"

    def _process_mlx_metal(self, message: AxisArray):
        """Run the on-device Metal kernel for threshold detection + refractory.

        Caller (``_process``) has already short-circuited empty messages, so this
        always operates on at least one sample. Returns ``(events, t0_offset_samples)``
        with ``t0_offset_samples == 0`` since the kernel processes the chunk in place.
        """
        from ezmsg.event.util.peak_mlx_metal import threshold_crossings_mlx_metal

        # _state.data carries the previous chunk's last sample as the over/under reference.
        # It is empty when the stream's first message was empty (_reset_state sliced [:1] of an
        # empty chunk). In that case use this chunk's first sample as the reference, which treats
        # it as never-a-crossing — the same convention the cpu path uses.
        if self._state.data.shape[0] > 0:
            prev_sample = self._state.data[0]
        else:
            prev_sample = message.data[0]
        events, last_sample, elapsed_out = threshold_crossings_mlx_metal(
            message.data,
            prev_sample,
            self._state.elapsed,
            threshold=self.settings.threshold,
            refrac_width=self._state.refrac_width,
        )
        self._state.data = last_sample[None, ...]
        self._state.elapsed = elapsed_out
        return events, 0

    def _wrap_output(self, message: AxisArray, data, t0_offset_samples: int = 0) -> AxisArray:
        """Build the output AxisArray, shifting the time offset if the result begins
        earlier than the input chunk (the cpu path can hold back unfinished events
        from the previous chunk into this output)."""
        if t0_offset_samples == 0:
            return replace(message, data=data)
        time_axis = message.axes["time"]
        return replace(
            message,
            data=data,
            axes={
                **message.axes,
                "time": replace(time_axis, offset=time_axis.offset + t0_offset_samples * time_axis.gain),
            },
        )

    def _empty_output(self, message: AxisArray, xp) -> AxisArray:
        """Build an empty output AxisArray that matches the active path's container/dtype."""
        feature_shape = tuple(message.data.shape[1:])
        if self._can_use_mlx_metal(xp, message.data):
            import mlx.core as mx

            return replace(message, data=mx.zeros((0,) + feature_shape, dtype=mx.int8))
        out_dtype = message.data.dtype if self.settings.return_peak_val else bool
        if self.settings.output_format == OutputFormat.SPARSE:
            data = sparse.COO(
                np.zeros((message.data.ndim, 0), dtype=np.int64),
                data=np.array([], dtype=out_dtype),
                shape=message.data.shape,
            )
        else:
            data = xp.zeros(message.data.shape, dtype=out_dtype)
        return replace(message, data=data)

    @staticmethod
    def _flat_feature_idx(cross_idx, feature_shape) -> npt.NDArray:
        """Map each event's feature coordinates to an index into the flat ``elapsed`` buffer."""
        if len(cross_idx) < 2:
            # Time-only input: every event belongs to the single feature 0.
            return np.zeros(cross_idx[0].shape, dtype=np.intp)
        return np.ravel_multi_index(tuple(cross_idx[1:]), tuple(feature_shape))

    def _enforce_refractory(self, cross_idx, feature_shape):
        """Drop crossings within ``refrac_width`` samples of the previous kept crossing
        on the same feature. This includes the last accepted crossing from earlier
        chunks, which is tracked in ``elapsed``.

        ``cross_idx`` must be grouped by feature (C order) and time-sorted within each feature.
        Returns a new list of index arrays with the dropped events removed.
        """
        refrac = self._state.refrac_width
        # Find the unique set of features that have at least one cross-over,
        # and the indices of the first crossover for each.
        ravel_feat_inds = self._flat_feature_idx(cross_idx, feature_shape)
        uq_feats, feat_splits = np.unique(ravel_feat_inds, return_index=True)
        # Inter-event intervals (IEIs); the first entry is a placeholder overwritten below.
        ieis = np.diff(np.hstack(([cross_idx[0][0] + 1], cross_idx[0])))
        # At feature boundaries the interval is measured from that feature's last accepted event.
        ieis[feat_splits] = cross_idx[0][feat_splits] + self._state.elapsed[uq_feats]
        split_set = set(feat_splits.tolist())
        drop_idx = np.where(ieis <= refrac)[0]
        final_drop = []
        while len(drop_idx) > 0:
            d_idx = drop_idx[0]
            # Update the next iei so it refers to the event before the to-be-dropped event,
            # but only if the next iei belongs to the same feature.
            if (d_idx + 1) < len(ieis) and (d_idx + 1) not in split_set:
                ieis[d_idx + 1] += ieis[d_idx]
            final_drop.append(d_idx)
            drop_idx = drop_idx[1:]
            # If the next candidate is now outside the refractory period then it will not be dropped.
            if len(drop_idx) > 0 and ieis[drop_idx[0]] > refrac:
                drop_idx = drop_idx[1:]
        final_drop = np.asarray(final_drop, dtype=np.intp)
        return [np.delete(_, final_drop) for _ in cross_idx]

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Process incoming samples and detect threshold crossings.

        Args:
            message: The input AxisArray containing signal data; it must have a "time" axis.

        Returns:
            AxisArray with the time axis first. Its data marks detected events as
            ``sparse.COO`` or a dense array, per ``settings.output_format``. Entries are
            ``True``, or the peak values when ``return_peak_val`` is set.
        """
        xp = get_namespace(message.data)
        ax_idx = message.get_axis_idx("time")

        # If the time axis is not the first axis, move it to the front.
        if ax_idx != 0:
            perm = (ax_idx,) + tuple(i for i in range(message.data.ndim) if i != ax_idx)
            message = replace(
                message,
                data=xp.permute_dims(message.data, perm),
                dims=["time"] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
            )

        # An empty chunk produces an empty output regardless of backend or buffered state
        # (the cpu path's prepended buffer alone yields no new crossings, and the metal
        # kernel has nothing to scan). Short-circuit before backend dispatch.
        if message.data.shape[0] == 0:
            return self._empty_output(message, xp)

        # MLX-on-device fast path: bypass the numpy event-detection logic and run
        # the fused metal kernel for threshold detection + refractory enforcement.
        if self._can_use_mlx_metal(xp, message.data):
            data, t0_off = self._process_mlx_metal(message)
            return self._wrap_output(message, data, t0_off)

        # Append the raw data to our raw buffer. This buffer only exists if we are
        # autoscaling AND need to capture the true peak value.
        if self._state.data_raw is not None:
            self._state.data_raw = xp.concat((self._state.data_raw, message.data), axis=0)

        # Run the message through the standard scaler if needed. Raw values are lost unless copied above.
        if self._state.scaler is not None:
            message = self._state.scaler(message)

        # Prepend the previous iteration's trailing (maybe z-scored) samples to the current data.
        data = xp.concat((self._state.data, message.data), axis=0)
        # Take note of how many samples were prepended; it determines the output time offset.
        n_prepended = self._state.data.shape[0]
        if n_prepended == 0:
            # No reference sample from previous iteration (e.g. first message after an empty-data reset).
            # Duplicate the first sample as the reference, matching the convention that _reset_state
            # stores data[:1] so it gets prepended on the next call.
            data = xp.concat((data[:1], data), axis=0)
            n_prepended = 1
            if self._state.data_raw is not None:
                self._state.data_raw = xp.concat((self._state.data_raw[:1], self._state.data_raw), axis=0)

        # Identify which data points are over threshold.
        threshold = self.settings.threshold
        overs = data >= threshold if threshold >= 0 else data <= threshold
        # Find threshold _crossing_: where sample k is over and sample k-1 is not over.
        b_cross_over = overs[1:] & ~overs[:-1]
        # Convert boolean arrays to numpy for event detection (np.where, lexsort, etc.).
        overs_np = overs if is_numpy_array(overs) else np.asarray(overs)
        b_cross_over_np = b_cross_over if is_numpy_array(b_cross_over) else np.asarray(b_cross_over)
        cross_idx = list(np.where(b_cross_over_np))  # List of indices into each dim
        # We ignored the first sample when looking for crosses so we increment the sample index by 1.
        cross_idx[0] += 1

        # Sort events by feature first, then by time within each feature.
        # np.where on a time-first array returns events sorted by time; we need them grouped by feature
        # for the refractory period logic and elapsed tracking to work correctly.
        if len(cross_idx[0]) > 0 and len(cross_idx) > 1:
            sort_order = np.lexsort([cross_idx[0]] + cross_idx[1:][::-1])
            cross_idx = [_[sort_order] for _ in cross_idx]

        # Note: the 0th sample only serves as a reference and is not part of the output;
        # it is trimmed at the very end. Until then the offset is useful for bookkeeping.

        # Optionally drop crossings during refractory period.
        # TODO: This should go in its own transformer. https://github.com/ezmsg-org/ezmsg-event/issues/12
        # NOTE(review): b_drop uses `ieis <= refrac_width` and same-feature crossings are >= 2 samples
        # apart, so a refrac_width of exactly 2 could drop events, yet this guard skips it — confirm intent.
        if self._state.refrac_width > 2 and len(cross_idx[0]) > 0:
            cross_idx = self._enforce_refractory(cross_idx, overs_np.shape[1:])

        # Calculate the 'value' at each event.
        hold_idx = overs_np.shape[0] - 1
        if len(cross_idx[0]) == 0:
            # No events; no values to calculate.
            result_val = np.ones(
                cross_idx[0].shape,
                dtype=data.dtype if self.settings.return_peak_val else bool,
            )
        elif not (self._state.min_width > 1 or self.settings.align_on_peak or self.settings.return_peak_val):
            # No postprocessing required. Events are at least 1 sample long by construction,
            # so a min_width <= 1 can never drop one.
            result_val = np.ones(cross_idx[0].shape, dtype=bool)
        else:
            cross_idx, result_val, hold_idx = self._postprocess_peaks(cross_idx, hold_idx, overs_np, data)

        # Save data for next iteration.
        self._state.data = data[hold_idx:]
        if self._state.data_raw is not None:
            # Likely because we are using the scaler, we need a separate copy of the raw data.
            self._state.data_raw = self._state.data_raw[hold_idx:]

        # Advance `elapsed` by the number of samples consumed this iteration...
        self._state.elapsed += hold_idx
        # ...yet for features that actually had events, replace it with the time since the last event.
        # Multiple writes to the same index are fine: events are time-sorted per feature, so the last wins.
        feat_flat = self._flat_feature_idx(cross_idx, overs_np.shape[1:])
        self._state.elapsed[feat_flat] = hold_idx - cross_idx[0]
        # Saturate: any value > refrac_width behaves identically in the refractory test, and
        # capping prevents int32 overflow (which would falsely drop events) on long silent channels.
        np.minimum(self._state.elapsed, self._state.refrac_width + 1, out=self._state.elapsed)

        # Build the result data; the AxisArray wrapping (with the time-offset shift for held-back
        # samples) is handled by _wrap_output.
        # Note: The first of the held back samples for next iteration is part of this iteration's return.
        # Likewise, the first prepended sample on this iteration was part of the previous iteration's return.
        cross_idx[0] = cross_idx[0] - 1  # Discard first prepended sample.
        out_shape = (hold_idx,) + tuple(data.shape[1:])
        if self.settings.output_format == OutputFormat.SPARSE:
            result = sparse.COO(cross_idx, data=result_val, shape=out_shape)
        else:
            # Dense in the input's namespace so accelerator data stays on device downstream.
            out_dtype = data.dtype if self.settings.return_peak_val else bool
            dense_np = np.zeros(out_shape, dtype=out_dtype)
            if cross_idx[0].size > 0:
                dense_np[tuple(cross_idx)] = result_val
            result = dense_np if xp is np else xp.asarray(dense_np)
        return self._wrap_output(message, result, t0_offset_samples=-(n_prepended - 1))

    def _postprocess_peaks(self, cross_idx, hold_idx, overs_np, data):
        """Apply min_peak_dur / align_on_peak / return_peak_val to detected crossings.

        Each event's window of ``max(max_width, 2)`` samples is searched for the return
        across threshold and for the peak. Events that fail ``min_peak_dur`` are dropped.
        With ``align_on_peak``, each event's sample index is shifted to its peak location.

        An event counts as unfinished only if it has not crossed back AND its window runs
        past the available data. Unfinished events are held back for the next chunk, by
        lowering ``hold_idx``. An event that stays over threshold for its whole in-data
        window is complete, with length equal to the window.

        With ``align_on_peak``, ``hold_idx`` is lowered further so that every emitted peak
        falls inside this chunk's output. Events that would spill past it are held back
        and re-detected next chunk.

        Returns:
            ``(cross_idx, result_val, hold_idx)``.
        """
        if cross_idx[0].size == 0:
            return list(cross_idx), np.ones((0,), dtype=bool), hold_idx

        n_samp = overs_np.shape[0]
        # Need at least the crossing sample plus one more to test for a crossback.
        win = max(self._state.max_width, 2)
        starts = cross_idx[0]

        # Extract win-length vectors of `overs` values for each event. Pad with the last sample
        # until the expected end of the event so events near the end of the data still resolve.
        n_pad = max(0, int(starts.max()) + win - n_samp)
        pad_width = ((0, n_pad),) + ((0, 0),) * (overs_np.ndim - 1)
        overs_padded = np.pad(overs_np, pad_width, mode="edge")
        s_idx = np.arange(win)[None, :] + starts[:, None]
        ep_overs = overs_padded[(s_idx,) + tuple(_[:, None] for _ in cross_idx[1:])]

        tail = ep_overs[:, 1:]
        b_ev_crossback = np.any(~tail, axis=-1)
        # Event length = offset of the first non-over sample; events still over for the whole
        # window are clamped to the window length.
        ev_len = np.where(b_ev_crossback, tail.argmin(axis=-1) + 1, win)
        # Only events whose window runs past the real data can still change with more data.
        b_unf = ~b_ev_crossback & (starts + win > n_samp)

        if self._state.min_width > 1:
            # Drop events that crossed back but fail min_width.
            b_long = ~(b_ev_crossback & (ev_len < self._state.min_width))
            cross_idx = [_[b_long] for _ in cross_idx]
            ev_len = ev_len[b_long]
            b_unf = b_unf[b_long]
            starts = cross_idx[0]

        # Peak offset (relative to the crossing) for each finished event.
        pk_offset = np.zeros_like(ev_len)
        raw_source_np = None
        if self.settings.align_on_peak or self.settings.return_peak_val:
            data_np = data if is_numpy_array(data) else np.asarray(data)
            raw_source_np = data_np
            if self._state.data_raw is not None:
                raw = self._state.data_raw
                raw_source_np = raw if is_numpy_array(raw) else np.asarray(raw)
            b_fin = ~b_unf
            # Process peaks in batches by length so short peaks don't give incorrect argmax results.
            for ep_len in np.unique(ev_len[b_fin]):
                b_grp = b_fin & (ev_len == ep_len)
                ep_resamp = np.arange(ep_len)[None, :] + starts[b_grp, None]
                ep_inds_tuple = (ep_resamp,) + tuple(_[b_grp, None] for _ in cross_idx[1:])
                eps = data_np[ep_inds_tuple]
                if self.settings.threshold >= 0:
                    pk_offset[b_grp] = np.argmax(eps, axis=1)
                else:
                    pk_offset[b_grp] = np.argmin(eps, axis=1)

        # h: first data index whose events are deferred to the next chunk.
        h = int(starts[b_unf].min()) if np.any(b_unf) else hold_idx + 1
        if self.settings.align_on_peak:
            # Emitted (aligned) events must land at index <= h - 1, the last sample of this output.
            # Defer any event that would spill past that; h strictly decreases so this terminates.
            peaks = starts + pk_offset
            while True:
                b_spill = (starts < h) & (peaks >= h)
                if not np.any(b_spill):
                    break
                h = int(starts[b_spill].min())

        b_keep = starts < h
        cross_idx = [_[b_keep] for _ in cross_idx]
        pk_offset = pk_offset[b_keep]
        # Hold back at least 1 sample before deferred events so their crossing is re-detected.
        hold_idx = max(h - 1, 0)

        peak_pos = cross_idx[0] + pk_offset
        if self.settings.return_peak_val:
            result_val = raw_source_np[(peak_pos,) + tuple(cross_idx[1:])]
        else:
            result_val = np.ones(cross_idx[0].shape, dtype=bool)
        if self.settings.align_on_peak:
            cross_idx[0] = peak_pos
        return cross_idx, result_val, hold_idx
[docs]
class ThresholdCrossing(
    BaseTransformerUnit[
        ThresholdSettings,
        AxisArray,
        AxisArray,
        ThresholdCrossingTransformer,
    ]
):
    """ezmsg Unit that runs :class:`ThresholdCrossingTransformer`, configured via :class:`ThresholdSettings`."""

    SETTINGS = ThresholdSettings