Source code for ezmsg.sigproc.window

import enum
import traceback
import typing

import ezmsg.core as ez
import numpy.typing as npt
import sparse
from array_api_compat import is_pydata_sparse_namespace, get_namespace
from ezmsg.util.messages.axisarray import (
    AxisArray,
    slice_along_axis,
    sliding_win_oneaxis,
    replace,
)

from .base import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
from .util.profile import profile_subpub


[docs] class Anchor(enum.Enum): BEGINNING = "beginning" END = "end" MIDDLE = "middle"
[docs] class WindowSettings(ez.Settings): axis: str | None = None newaxis: str | None = None # new axis for output. No new axes if None window_dur: float | None = None # Sec. passthrough if None window_shift: float | None = None # Sec. Use "1:1 mode" if None zero_pad_until: str = "full" # "full", "shift", "input", "none" anchor: str | Anchor = Anchor.BEGINNING
[docs] @processor_state class WindowState: buffer: npt.NDArray | sparse.SparseArray | None = None window_samples: int | None = None window_shift_samples: int | None = None shift_deficit: int = 0 """ Number of incoming samples to ignore. Only relevant when shift > window.""" newaxis_warned: bool = False out_newaxis: AxisArray.LinearAxis | None = None out_dims: list[str] | None = None
[docs] class WindowTransformer( BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState] ): """ Apply a sliding window along the specified axis to input streaming data. The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization can be difficult. Please read the argument descriptions carefully. """
[docs] def __init__(self, *args, **kwargs) -> None: """ Args: axis: The axis along which to segment windows. If None, defaults to the first dimension of the first seen AxisArray. Note: The windowed axis must be an AxisArray.LinearAxis, not an AxisArray.CoordinateAxis. newaxis: New axis on which windows are delimited, immediately preceding the target windowed axis. The data length along newaxis may be 0 if this most recent push did not provide enough data for a new window. If window_shift is None then the newaxis length will always be 1. window_dur: The duration of the window in seconds. If None, the function acts as a passthrough and all other parameters are ignored. window_shift: The shift of the window in seconds. If None (default), windowing operates in "1:1 mode", where each input yields exactly one most-recent window. zero_pad_until: Determines how the function initializes the buffer. Can be one of "input" (default), "full", "shift", or "none". If `window_shift` is None then this field is ignored and "input" is always used. - "input" (default) initializes the buffer with the input then prepends with zeros to the window size. The first input will always yield at least one output. - "shift" fills the buffer until `window_shift`. No outputs will be yielded until at least `window_shift` data has been seen. - "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen. anchor: Determines the entry in `axis` that gets assigned `0`, which references the value in `newaxis`. Can be of class :obj:`Anchor` or a string representation of an :obj:`Anchor`. """ super().__init__(*args, **kwargs) # Sanity-check settings # if self.settings.newaxis is None: # ez.logger.warning("`newaxis=None` will be replaced with `newaxis='win'`.") # object.__setattr__(self.settings, "newaxis", "win") if ( self.settings.window_shift is None and self.settings.zero_pad_until != "input" ): ez.logger.warning( "`zero_pad_until` must be 'input' if `window_shift` is None. " f"Ignoring received argument value: {self.settings.zero_pad_until}" ) object.__setattr__(self.settings, "zero_pad_until", "input") elif ( self.settings.window_shift is not None and self.settings.zero_pad_until == "input" ): ez.logger.warning( "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size " "of the first input. We recommend using `zero_pad_until='shift'` when `window_shift` is float-valued." ) try: object.__setattr__(self.settings, "anchor", Anchor(self.settings.anchor)) except ValueError: raise ValueError( f"Invalid anchor: {self.settings.anchor}. Valid anchor are: {', '.join([e.value for e in Anchor])}" )
def _hash_message(self, message: AxisArray) -> int: axis = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis) axis_info = message.get_axis(axis) fs = 1.0 / axis_info.gain samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :] return hash(samp_shape + (fs, message.key)) def _reset_state(self, message: AxisArray) -> None: _newaxis = self.settings.newaxis or "win" if not self._state.newaxis_warned and _newaxis in message.dims: ez.logger.warning( f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead" ) self._state.newaxis_warned = True self.settings.newaxis = f"{_newaxis}_win" axis = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis) axis_info = message.get_axis(axis) fs = 1.0 / axis_info.gain xp = get_namespace(message.data) self._state.window_samples = int(self.settings.window_dur * fs) if self.settings.window_shift is not None: # If window_shift is None, we are in "1:1 mode" and window_shift_samples is not used. self._state.window_shift_samples = int(self.settings.window_shift * fs) if self.settings.zero_pad_until == "none": req_samples = self._state.window_samples elif ( self.settings.zero_pad_until == "shift" and self.settings.window_shift is not None ): req_samples = self._state.window_shift_samples else: # i.e. zero_pad_until == "input" req_samples = message.data.shape[axis_idx] n_zero = max(0, self._state.window_samples - req_samples) init_buffer_shape = ( message.data.shape[:axis_idx] + (n_zero,) + message.data.shape[axis_idx + 1 :] ) self._state.buffer = xp.zeros(init_buffer_shape, dtype=message.data.dtype) # Prepare reusable parts of output if self._state.out_newaxis is None: self._state.out_dims = ( list(message.dims[:axis_idx]) + [_newaxis] + list(message.dims[axis_idx:]) ) self._state.out_newaxis = replace( axis_info, gain=0.0 if self.settings.window_shift is None else axis_info.gain * self._state.window_shift_samples, offset=0.0, # offset modified per-msg below ) def __call__(self, message: AxisArray) -> AxisArray: if self.settings.window_dur is None: # Shortcut for no windowing return message return super().__call__(message) def _process(self, message: AxisArray) -> AxisArray: axis = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis) axis_info = message.get_axis(axis) xp = get_namespace(message.data) # Add new data to buffer. # Currently, we concatenate the new time samples and clip the output. # np.roll is not preferred as it returns a copy, and there's no way to construct a # rolling view of the data. In current numpy implementations, np.concatenate # is generally faster than np.roll and slicing anyway, but this could still # be a performance bottleneck for large memory arrays. # A circular buffer might be faster. self._state.buffer = xp.concatenate( (self._state.buffer, message.data), axis=axis_idx ) # Create a vector of buffer timestamps to track axis `offset` in output(s) buffer_t0 = 0.0 buffer_tlen = self._state.buffer.shape[axis_idx] # Adjust so first _new_ sample at index 0. buffer_t0 -= self._state.buffer.shape[axis_idx] - message.data.shape[axis_idx] # Convert form indices to 'units' (probably seconds). buffer_t0 *= axis_info.gain buffer_t0 += axis_info.offset if self.settings.window_shift is not None and self._state.shift_deficit > 0: n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit) if n_skip > 0: self._state.buffer = slice_along_axis( self._state.buffer, slice(n_skip, None), axis_idx ) buffer_t0 += n_skip * axis_info.gain buffer_tlen -= n_skip self._state.shift_deficit -= n_skip # Generate outputs. # Preliminary copy of axes without the axes that we are modifying. _newaxis = self.settings.newaxis or "win" out_axes = {k: v for k, v in message.axes.items() if k not in [_newaxis, axis]} # Update targeted (windowed) axis so that its offset is relative to the new axis if self.settings.anchor == Anchor.BEGINNING: out_axes[axis] = replace(axis_info, offset=0.0) elif self.settings.anchor == Anchor.END: out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur) elif self.settings.anchor == Anchor.MIDDLE: out_axes[axis] = replace(axis_info, offset=-self.settings.window_dur / 2) # How we update .data and .axes[newaxis] depends on the windowing mode. if self.settings.window_shift is None: # one-to-one mode -- Each send yields exactly one window containing only the most recent samples. self._state.buffer = slice_along_axis( self._state.buffer, slice(-self._state.window_samples, None), axis_idx ) out_dat = self._state.buffer.reshape( self._state.buffer.shape[:axis_idx] + (1,) + self._state.buffer.shape[axis_idx:] ) win_offset = buffer_t0 + axis_info.gain * ( buffer_tlen - self._state.window_samples ) elif self._state.buffer.shape[axis_idx] >= self._state.window_samples: # Deterministic window shifts. sliding_win_fun = ( sparse_sliding_win_oneaxis if is_pydata_sparse_namespace(xp) else sliding_win_oneaxis ) out_dat = sliding_win_fun( self._state.buffer, self._state.window_samples, axis_idx, step=self._state.window_shift_samples, ) win_offset = buffer_t0 # Drop expired beginning of buffer and update shift_deficit multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx] self._state.shift_deficit = max( 0, multi_shift - self._state.buffer.shape[axis_idx] ) self._state.buffer = slice_along_axis( self._state.buffer, slice(multi_shift, None), axis_idx ) else: # Not enough data to make a new window. Return empty data. empty_data_shape = ( message.data.shape[:axis_idx] + (0, self._state.window_samples) + message.data.shape[axis_idx + 1 :] ) out_dat = xp.zeros(empty_data_shape, dtype=message.data.dtype) # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero. win_offset = axis_info.offset if self.settings.anchor == Anchor.END: win_offset += self.settings.window_dur elif self.settings.anchor == Anchor.MIDDLE: win_offset += self.settings.window_dur / 2 self._state.out_newaxis = replace(self._state.out_newaxis, offset=win_offset) msg_out = replace( message, data=out_dat, dims=self._state.out_dims, axes={**out_axes, _newaxis: self._state.out_newaxis}, ) return msg_out
[docs] class Window( BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer] ): SETTINGS = WindowSettings INPUT_SIGNAL = ez.InputStream(AxisArray) OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
[docs] @ez.subscriber(INPUT_SIGNAL, zero_copy=True) @ez.publisher(OUTPUT_SIGNAL) @profile_subpub(trace_oldest=False) async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: """ override superclass on_signal so we can opt to yield once or multiple times after dropping the win axis. """ # TODO: The transfomer overwrites settings.newaxis from None to "win", # then we no longer know if the user wants to trim out the newaxis from the unit. xp = get_namespace(message.data) try: ret = self.processor(message) if ret.data.size > 0: if ( self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None ): # Multi-win mode or pass-through mode. yield self.OUTPUT_SIGNAL, ret else: # We need to split out_msg into multiple yields, dropping newaxis. axis_idx = ret.get_axis_idx("win") win_axis = ret.axes["win"] offsets = win_axis.value( xp.asarray(range(ret.data.shape[axis_idx])) ) for msg_ix in range(ret.data.shape[axis_idx]): # Need to drop 'win' and replace self.SETTINGS.axis from axes. _out_axes = { **{ k: v for k, v in ret.axes.items() if k not in ["win", self.SETTINGS.axis] }, self.SETTINGS.axis: replace( ret.axes[self.SETTINGS.axis], offset=offsets[msg_ix] ), } _ret = replace( ret, data=slice_along_axis(ret.data, msg_ix, axis_idx), dims=ret.dims[:axis_idx] + ret.dims[axis_idx + 1 :], axes=_out_axes, ) yield self.OUTPUT_SIGNAL, _ret except Exception: ez.logger.info(traceback.format_exc())
[docs] def windowing( axis: str | None = None, newaxis: str | None = None, window_dur: float | None = None, window_shift: float | None = None, zero_pad_until: str = "full", anchor: str | Anchor = Anchor.BEGINNING, ) -> WindowTransformer: return WindowTransformer( WindowSettings( axis=axis, newaxis=newaxis, window_dur=window_dur, window_shift=window_shift, zero_pad_until=zero_pad_until, anchor=anchor, ) )