Source code for ezmsg.sigproc.butterworthzerophase

"""
Streaming zero-phase Butterworth filter implemented as a two-stage composite processor.

Stage 1: Forward causal Butterworth filter (from ezmsg.sigproc.butterworthfilter)
Stage 2: Backward acausal filter with buffering (ButterworthBackwardFilterTransformer)

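The output is delayed by `pad_length` samples to ensure the backward pass has sufficient
future context. The pad_length is derived from the settling time of the filter's impulse
response, using scipy's padding heuristic as a lower bound and an optional maximum duration
as an upper bound.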
"""

import functools
import typing

import numpy as np
import scipy.signal
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import BaseTransformerUnit
from ezmsg.baseproc.composite import CompositeProcessor
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .butterworthfilter import (
    ButterworthFilterSettings,
    ButterworthFilterTransformer,
    butter_design_fun,
)
from .filter import (
    _HAS_MLX_METAL,
    BACoeffs,
    FilterByDesignTransformer,
    SOSCoeffs,
    _sosfilt_mlx_metal_xp,
)
from .util.array import xp_asarray, xp_empty, xp_flip
from .util.axisarray_buffer import HybridAxisArrayBuffer

if _HAS_MLX_METAL:
    import mlx.core as _mx
else:
    _mx = None  # type: ignore


class ButterworthZeroPhaseSettings(ButterworthFilterSettings):
    """
    Settings for :obj:`ButterworthZeroPhase`.

    Configures a streaming zero-phase Butterworth filter built from a forward pass
    followed by a backward pass. The output lags the input by `pad_length` samples so
    that the backward pass has enough future context.

    `pad_length` is the number of samples the filter's impulse response needs to decay
    to `settle_cutoff` times its peak value, so it reflects the filter's actual time
    constant rather than only its order.
    """

    # Fields inherited from ButterworthFilterSettings:
    #   axis, coef_type, order, cuton, cutoff, wn_hz

    settle_cutoff: float = 0.01
    """
    Fraction of the peak impulse response that defines "settled".
    The pad_length is the number of samples until the impulse response has decayed
    to this fraction of its peak. Default is 0.01 (1% of peak).
    """

    max_pad_duration: float | None = None
    """
    Upper bound on the pad duration, in seconds.
    When set, pad_length is capped at this value times the sampling rate. Use it to
    limit latency for filters with very long impulse responses.
    Default is None (no limit).
    """
class ButterworthBackwardFilterTransformer(FilterByDesignTransformer[ButterworthFilterSettings, BACoeffs | SOSCoeffs]):
    """
    Backward (acausal) Butterworth filter with buffering.

    This transformer buffers its input and applies the filter in reverse, outputting
    only the "settled" portion where transients have decayed. This introduces a lag of
    ``pad_length`` samples.

    Intended to be used as stage 2 in a zero-phase filter pipeline, receiving
    forward-filtered data from a ButterworthFilterTransformer.
    """

    # Instance attributes (initialized in _reset_state)
    _buffer: HybridAxisArrayBuffer | None
    _coefs_cache: BACoeffs | SOSCoeffs | None
    _zi_tiled: typing.Any | None  # xp array in the namespace of the input data
    _sos_mx: typing.Any | None  # cached mlx.core.array of SOS coefs (SOS + MLX path)
    _pad_length: int

    def get_design_function(
        self,
    ) -> typing.Callable[[float], BACoeffs | SOSCoeffs | None]:
        """
        Return the filter design callable for the current settings.

        Returns:
            ``butter_design_fun`` with order, cuton, cutoff, coef_type and wn_hz bound
            from the settings; the remaining argument is the sampling rate in Hz.
        """
        return functools.partial(
            butter_design_fun,
            order=self.settings.order,
            cuton=self.settings.cuton,
            cutoff=self.settings.cutoff,
            coef_type=self.settings.coef_type,
            wn_hz=self.settings.wn_hz,
        )

    def _compute_pad_length(self, fs: float) -> int:
        """
        Compute pad length based on the filter's impulse response settling time.

        The pad_length is the index of the last impulse-response sample whose magnitude
        exceeds `settle_cutoff` times the peak, plus one. It is never smaller than a
        minimum derived from the filter order, and is optionally capped by
        `max_pad_duration` (the cap is applied last, so it wins over the minimum).

        Args:
            fs: Sampling frequency in Hz.

        Returns:
            Number of samples for the pad length, as a plain Python ``int``.
            0 when the design function yields no coefficients.
        """
        # Design the filter to compute impulse response
        coefs = self.get_design_function()(fs)
        if coefs is None:
            # Filter design failed or is disabled
            return 0

        # The class is parameterised on ButterworthFilterSettings, which does not declare
        # these two fields; fall back to the ButterworthZeroPhaseSettings defaults so the
        # transformer also works with the base settings type.
        settle_cutoff = getattr(self.settings, "settle_cutoff", 0.01)
        max_pad_duration = getattr(self.settings, "max_pad_duration", None)

        # Lower bound on the pad length (scipy-style heuristic based on filter order).
        if self.settings.coef_type == "ba":
            min_length = 3 * (self.settings.order + 1)
        else:
            n_sections = (self.settings.order + 1) // 2
            min_length = 3 * n_sections * 2

        # Use 10x the minimum as initial impulse length, or at least 10000 samples
        # (10000 samples allows for ~333ms at 30kHz, covering most practical cases)
        impulse_length = max(min_length * 10, 10000)

        # Convert the optional duration cap to samples once. Clamp at zero so that a
        # negative duration cannot produce an empty / negative-length impulse.
        max_samples: int | None = None
        if max_pad_duration is not None:
            max_samples = max(int(max_pad_duration * fs), 0)
            # No need to simulate further than the cap (+1 keeps at least one sample).
            impulse_length = min(impulse_length, max_samples + 1)

        impulse = np.zeros(impulse_length)
        impulse[0] = 1.0

        if self.settings.coef_type == "ba":
            b, a = coefs
            h = scipy.signal.lfilter(b, a, impulse)
        else:
            h = scipy.signal.sosfilt(coefs, impulse)

        # Find where impulse response settles to settle_cutoff of peak
        abs_h = np.abs(h)
        peak = abs_h.max()

        # Default to the minimum; this also covers an all-zero (or NaN) response.
        pad_length = min_length
        if peak > 0:
            threshold = settle_cutoff * peak
            above_threshold = np.where(abs_h > threshold)[0]
            if len(above_threshold) > 0:
                # Ensure at least the heuristic minimum
                pad_length = max(int(above_threshold[-1]) + 1, min_length)

        # Apply max_pad_duration cap if set (on every path, including peak == 0)
        if max_samples is not None:
            pad_length = min(pad_length, max_samples)

        # Guarantee a builtin int regardless of numpy scalar arithmetic above.
        return int(pad_length)

    def _reset_state(self, message: AxisArray) -> None:
        """
        Reset filter state when stream changes.

        Clears the cached coefficients, tiled initial conditions, MLX coefficient cache
        and buffer, recomputes the pad length for the message's sampling rate, and flags
        the filter for redesign on the next call to ``_process``.

        Args:
            message: The message whose target axis ``gain`` defines the sampling rate.
        """
        self._coefs_cache = None
        self._zi_tiled = None
        self._sos_mx = None
        self._buffer = None
        # Compute pad_length based on the message's sampling rate
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        fs = 1 / message.axes[axis].gain
        self._pad_length = self._compute_pad_length(fs)
        self.state.needs_redesign = True

    def _compute_zi_tiled(self, data, ax_idx: int, xp) -> None:
        """Compute and cache the tiled zi for the given data shape.

        Called once per stream (or after filter redesign). The result is
        broadcast-ready for multiplication by the edge sample on each chunk.
        Stored in the namespace of ``data`` (numpy or MLX).

        Args:
            data: Array whose non-filter dimensions determine the tiling.
            ax_idx: Index of the axis along which the filter is applied.
            xp: Array namespace of ``data``; ``np`` for numpy arrays.
        """
        if self.settings.coef_type == "ba":
            b, a = self._coefs_cache
            zi_base = scipy.signal.lfilter_zi(b, a)
        else:  # sos
            zi_base = scipy.signal.sosfilt_zi(self._coefs_cache)

        n_tail = data.ndim - ax_idx - 1

        if self.settings.coef_type == "ba":
            # zi occupies the filter axis; every other dim is tiled to data's size.
            zi_expand = (None,) * ax_idx + (slice(None),) + (None,) * n_tail
            n_tile = data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]
        else:  # sos
            # Leading dim is the section index; the per-section state sits on the filter axis.
            zi_expand = (slice(None),) + (None,) * ax_idx + (slice(None),) + (None,) * n_tail
            n_tile = (1,) + data.shape[:ax_idx] + (1,) + data.shape[ax_idx + 1 :]

        zi_tiled = np.tile(zi_base[zi_expand], n_tile)
        if xp is not np:
            zi_tiled = xp_asarray(xp, zi_tiled.astype(np.float32), dtype=data.dtype)
        self._zi_tiled = zi_tiled

    def _initialize_zi(self, data, ax_idx: int, xp):
        """
        Initialize filter state (zi) scaled by edge value.

        Args:
            data: Array about to be filtered (already reversed by the caller).
            ax_idx: Index of the axis along which the filter is applied.
            xp: Array namespace of ``data``.

        Returns:
            The cached tiled zi multiplied by the first sample of ``data`` along ``ax_idx``.
        """
        if self._zi_tiled is None:
            self._compute_zi_tiled(data, ax_idx, xp)
        # slice [0:1] along ax_idx — portable across numpy/MLX
        first_idx = tuple(slice(0, 1) if i == ax_idx else slice(None) for i in range(data.ndim))
        first_sample = data[first_idx]
        return self._zi_tiled * first_sample

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Buffer the incoming chunk and emit the settled part of the backward-filtered data.

        All buffered samples are filtered in reverse on every call; only the samples
        that have at least ``pad_length`` later samples behind them are emitted and
        dropped from the buffer.

        Args:
            message: Incoming chunk.

        Returns:
            The input message unchanged when the filter is disabled (no coefficients,
            ``order <= 0``) or the chunk is empty; a message with zero samples along the
            target axis while fewer than ``pad_length + 1`` samples are buffered;
            otherwise the filtered samples with the target axis taken from the buffer.
        """
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        ax_idx = message.get_axis_idx(axis)
        fs = 1 / message.axes[axis].gain

        # Check if we need to redesign filter
        if self._coefs_cache is None or self.state.needs_redesign:
            self._coefs_cache = self.get_design_function()(fs)
            self._pad_length = self._compute_pad_length(fs)
            self._zi_tiled = None  # Invalidate; recomputed on next use.
            self._sos_mx = None
            self.state.needs_redesign = False

            # Initialize buffer with duration based on pad_length
            # Add some margin to handle variable chunk sizes
            buffer_duration = (self._pad_length + 1) / fs
            self._buffer = HybridAxisArrayBuffer(duration=buffer_duration, axis=axis)

        # Early exit if filter is effectively disabled
        if self._coefs_cache is None or self.settings.order <= 0 or message.data.size <= 0:
            return message

        # Write new data to buffer
        self._buffer.write(message)
        n_available = self._buffer.available()
        n_output = n_available - self._pad_length

        xp = np if is_numpy_array(message.data) else get_namespace(message.data)

        # If we don't have enough data yet, return empty
        if n_output <= 0:
            new_shape = list(message.data.shape)
            new_shape[ax_idx] = 0
            empty_data = xp_empty(xp, tuple(new_shape), dtype=message.data.dtype)
            return replace(message, data=empty_data)

        # Peek all available data from buffer
        # Note: HybridAxisArrayBuffer moves the target axis to position 0
        buffered = self._buffer.peek(n_available)
        combined = buffered.data
        buffer_ax_idx = 0  # Buffer always puts time axis at position 0

        # Backward filter on reversed data — stay in the input's namespace.
        combined_rev = xp_flip(combined, axis=buffer_ax_idx)
        backward_zi = self._initialize_zi(combined_rev, buffer_ax_idx, xp)

        is_mlx = xp is not np and xp.__name__ == "mlx.core"
        use_mlx_metal = (
            self.settings.coef_type == "sos"
            and is_mlx
            and getattr(self.settings, "use_mlx_metal", True)
            and _HAS_MLX_METAL
        )

        if use_mlx_metal:
            if self._sos_mx is None:
                # Convert the SOS coefficients once and reuse them for later chunks.
                self._sos_mx = _mx.array(np.asarray(self._coefs_cache).astype(np.float32))
            y_bwd_rev, _ = _sosfilt_mlx_metal_xp(self._sos_mx, combined_rev, buffer_ax_idx, backward_zi)
        elif self.settings.coef_type == "ba":
            b, a = self._coefs_cache
            y_bwd_rev, _ = scipy.signal.lfilter(b, a, combined_rev, axis=buffer_ax_idx, zi=backward_zi)
        else:  # sos via scipy (non-MLX, or use_mlx_metal disabled)
            y_bwd_rev, _ = scipy.signal.sosfilt(self._coefs_cache, combined_rev, axis=buffer_ax_idx, zi=backward_zi)

        # Reverse back to get output in correct time order
        y_bwd = xp_flip(y_bwd_rev, axis=buffer_ax_idx)

        # Output the settled portion (first n_output samples)
        y = y_bwd[:n_output]

        # Advance buffer read head to discard output samples, keep pad_length
        self._buffer.seek(n_output)

        # Build output with adjusted time axis
        # LinearAxis offset is already correct from the buffer
        out_axis = buffered.axes[axis]

        # Move axis back to original position if needed
        if ax_idx != 0:
            y = xp.moveaxis(y, 0, ax_idx)

        return replace(
            message,
            data=y,
            axes={**message.axes, axis: out_axis},
        )
class ButterworthZeroPhaseTransformer(CompositeProcessor[ButterworthZeroPhaseSettings, AxisArray, AxisArray]):
    """
    Streaming zero-phase Butterworth filter built from two chained stages.

    Stage 1 ("forward"): standard causal Butterworth filter with state.
    Stage 2 ("backward"): acausal Butterworth filter with buffering.

    The output is delayed by ``pad_length`` samples.
    """

    @staticmethod
    def _initialize_processors(
        settings: ButterworthZeroPhaseSettings,
    ) -> dict[str, typing.Any]:
        """
        Build the two stage processors.

        Args:
            settings: Settings object handed unchanged to both stages, so they share
                one filter design.

        Returns:
            Mapping with the keys ``"forward"`` and ``"backward"``, in that order.
        """
        forward_stage = ButterworthFilterTransformer(settings)
        backward_stage = ButterworthBackwardFilterTransformer(settings)
        return {"forward": forward_stage, "backward": backward_stage}

    @classmethod
    def get_message_type(cls, dir: str) -> type[AxisArray]:
        """
        Return the message type handled in the given direction.

        Args:
            dir: Either ``"in"`` or ``"out"``.

        Returns:
            ``AxisArray`` for both directions.

        Raises:
            ValueError: If ``dir`` is neither ``"in"`` nor ``"out"``.
        """
        if dir not in ("in", "out"):
            raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
        return AxisArray
class ButterworthZeroPhase(
    BaseTransformerUnit[ButterworthZeroPhaseSettings, AxisArray, AxisArray, ButterworthZeroPhaseTransformer]
):
    """
    Unit wrapper around :obj:`ButterworthZeroPhaseTransformer`.

    Consumes and produces ``AxisArray`` messages and is configured with
    :obj:`ButterworthZeroPhaseSettings`.
    """

    SETTINGS = ButterworthZeroPhaseSettings