Source code for ezmsg.sigproc.filter

"""Core IIR/FIR filtering infrastructure with BA and SOS coefficient support."""

import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import scipy.signal
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
    BaseConsumerUnit,
    BaseStatefulTransformer,
    BaseTransformerUnit,
    SettingsType,
    TransformerType,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace

from .util.array import array_device, xp_asarray, xp_create


@dataclass
class FilterCoefficients:
    """Container for a filter expressed as transfer-function (b, a) arrays.

    Each field defaults to a freshly allocated ``[1.0, 0.0]`` array, so
    instances never share their default coefficient buffers.
    """

    # Numerator coefficients (passed as ``b`` to the scipy.signal routines).
    b: np.ndarray = field(default_factory=lambda: np.asarray((1.0, 0.0)))
    # Denominator coefficients (passed as ``a`` to the scipy.signal routines).
    a: np.ndarray = field(default_factory=lambda: np.asarray((1.0, 0.0)))
# Type aliases
BACoeffs = tuple[npt.NDArray, npt.NDArray]
SOSCoeffs = npt.NDArray
FilterCoefsType = typing.TypeVar("FilterCoefsType", BACoeffs, SOSCoeffs)


def _normalize_coefs(
    coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray | None,
) -> tuple[str, tuple[npt.NDArray, ...] | None]:
    """Coerce any supported coefficient container into a tuple of arrays.

    The scipy.signal filter functions are invoked as ``func(*coefs, ...)``,
    so the coefficients must always be a tuple.

    Args:
        coefs: A :obj:`FilterCoefficients`, a ``(b, a)`` tuple, a bare
            ``np.ndarray`` (treated as SOS), or None.

    Returns:
        ``(coef_type, coefs)`` where ``coef_type`` is ``"sos"`` when a bare
        ndarray was supplied and ``"ba"`` otherwise, and ``coefs`` is the
        tuple form (or None when the input was None).
    """
    coef_type = "ba"
    if coefs is not None:
        if isinstance(coefs, np.ndarray):
            coef_type = "sos"
            coefs = (coefs,)  # sos funcs just want a single ndarray.
        elif isinstance(coefs, FilterCoefficients):
            coefs = (coefs.b, coefs.a)
        elif not isinstance(coefs, tuple):
            coefs = (coefs,)
    return coef_type, coefs


def _sosfilt_xp(sos, x, axis_idx, zi, xp):
    """SOS filtering via parallel prefix scan (direct-form II transposed).

    Solves the IIR linear recurrence z[n+1] = A @ z[n] + B * x[n] using a
    Hillis-Steele inclusive prefix scan in O(log N) sequential steps instead
    of O(N), minimizing Python-level loop overhead for lazy-evaluation
    backends like MLX.

    Args:
        sos: (n_sections, 6) SOS coefficient array. Each row is
            [b0, b1, b2, a0, a1, a2]. a0 is assumed to be 1.0 (it is never
            read).
        x: Input data array.
        axis_idx: The axis along which to filter.
        zi: Initial conditions, shape
            (n_sections, *x.shape[:axis_idx], 2, *x.shape[axis_idx+1:]).
        xp: Array API namespace.

    Returns:
        (y, zf) tuple — filtered output and final filter state. When ``x`` is
        empty along ``axis_idx`` the inputs are returned unchanged.
    """
    n_sections = sos.shape[0]
    N = x.shape[axis_idx]

    # Nothing to filter: the state cannot advance and indexing the first/last
    # scan element below would fail.
    if N == 0:
        return x, zi

    # Move time to axis 0 for uniform batch handling.
    x = xp.moveaxis(x, axis_idx, 0)  # (N, *batch)
    zi = xp.moveaxis(zi, axis_idx + 1, 1)  # (n_sections, 2, *batch)

    # Flatten batch dims into one.
    batch_shape = x.shape[1:]
    batch_size = 1
    for s in batch_shape:
        batch_size *= s
    x = xp.reshape(x, (N, batch_size))  # (N, B)
    zi = xp.reshape(zi, (n_sections, 2, batch_size))  # (S, 2, B)

    # Pre-allocate output zi.
    zi_out = xp.zeros((n_sections, 2, batch_size), dtype=x.dtype)

    for s in range(n_sections):
        _b0 = float(sos[s, 0])
        _b1 = float(sos[s, 1])
        _b2 = float(sos[s, 2])
        _a1 = float(sos[s, 4])
        _a2 = float(sos[s, 5])

        z_init = zi[s]  # (2, B)

        # State recurrence: z[n+1] = A @ z[n] + B_vec * x[n]
        #   A     = [[-a1, 1], [-a2, 0]]
        #   B_vec = [b1 - a1*b0, b2 - a2*b0]
        # Output: y[n] = b0 * x[n] + z[n][0]
        A_mat = xp_asarray(xp, np.array([[-_a1, 1.0], [-_a2, 0.0]]))  # (2, 2)
        B_vec = xp_asarray(xp, np.array([_b1 - _a1 * _b0, _b2 - _a2 * _b0]))  # (2,)

        # Initialize scan elements:
        #   A_scan[n] = A for all n
        #   c_scan[n] = B_vec * x[n]
        A_scan = xp.zeros((N, 2, 2), dtype=A_mat.dtype)
        A_scan[:] = A_mat  # broadcast A_mat into every row
        c_scan = B_vec[None, :, None] * x[:, None, :]  # (N, 2, B)

        # Hillis-Steele inclusive prefix scan.
        # Operator: (A_r, c_r) ∘ (A_l, c_l) = (A_r @ A_l, A_r @ c_l + c_r)
        # After the scan, A_scan[n] = A^(n+1) and
        # c_scan[n] = Σ_{k=0..n} A^(n-k) @ B_vec * x[k].
        stride = 1
        while stride < N:
            right_A = A_scan[stride:]  # (N-stride, 2, 2)
            left_A = A_scan[:-stride]  # (N-stride, 2, 2)
            right_c = c_scan[stride:]  # (N-stride, 2, B)
            left_c = c_scan[:-stride]  # (N-stride, 2, B)
            # Both combined terms must be built from the *pre-update* values.
            # On backends where slices are views (e.g. NumPy) writing A_scan
            # first would silently change right_A before c_scan is updated,
            # so compute everything before writing anything back.
            new_A = right_A @ left_A
            new_c = right_A @ left_c + right_c
            A_scan[stride:] = new_A
            c_scan[stride:] = new_c
            stride *= 2

        # Recover all states: z[n+1] = A_scan[n] @ z_init + c_scan[n]
        z_from_scan = A_scan @ z_init[None, :, :] + c_scan  # (N, 2, B)

        # z[0..N-1] for output: prepend z_init, drop z[N].
        z_needed = xp.zeros((N, 2, batch_size), dtype=x.dtype)
        z_needed[0] = z_init
        z_needed[1:] = z_from_scan[:-1]

        # y[n] = b0 * x[n] + z[n][0]; output becomes input for the next section.
        x = _b0 * x + z_needed[:, 0, :]  # (N, B)

        # Final state for this section: z[N]
        zi_out[s] = z_from_scan[-1]

    # Restore shapes.
    x = xp.reshape(x, (N,) + batch_shape)
    zi_out = xp.reshape(zi_out, (n_sections, 2) + batch_shape)
    x = xp.moveaxis(x, 0, axis_idx)
    zi_out = xp.moveaxis(zi_out, 1, axis_idx + 1)

    return x, zi_out


def _fir_filt_fft(b, data, zi, axis_idx, xp):
    """FIR filtering via FFT convolution with streaming state.

    Args:
        b: FIR filter taps, shape (1, ..., M+1, ..., 1) with filter length at axis_idx.
        data: Input array.
        zi: State array holding the last M input samples along axis_idx.
        axis_idx: The axis along which to filter.
        xp: Array API namespace.

    Returns:
        (filtered_data, new_zi) tuple.
    """
    M = zi.shape[axis_idx]  # filter order (num taps - 1)
    if M == 0:
        # Zero-order FIR: just scale
        return data * b, zi

    N = data.shape[axis_idx]

    # Prepend state (last M input samples from previous chunk)
    extended = xp.concat([zi, data], axis=axis_idx)

    # FFT convolution. The linear convolution of (M+N) samples with (M+1)
    # taps is N + 2M samples long, so this length avoids circular wrap-around.
    fft_len = N + 2 * M
    B = xp.fft.rfft(b, n=fft_len, axis=axis_idx)
    X = xp.fft.rfft(extended, n=fft_len, axis=axis_idx)
    full = xp.fft.irfft(B * X, n=fft_len, axis=axis_idx)

    # Extract valid output: length N starting at offset M
    out = slice_along_axis(full, slice(M, M + N), axis_idx)

    # Update state: last M samples of extended input
    new_zi = slice_along_axis(extended, slice(N, N + M), axis_idx)

    return out, new_zi


def _fir_filt_conv(b_1d, data, zi, axis_idx, xp):
    """FIR filtering via direct convolution using xp.conv_general.

    Args:
        b_1d: 1D FIR filter taps, shape (M+1,).
        data: Input array.
        zi: State array holding the last M input samples along axis_idx.
        axis_idx: The axis along which to filter.
        xp: Array API namespace (must have conv_general).

    Returns:
        (filtered_data, new_zi) tuple.
    """
    M = zi.shape[axis_idx]  # filter order (num taps - 1)
    if M == 0:
        return data * b_1d[0], zi

    N = data.shape[axis_idx]

    # Prepend state (last M input samples from previous chunk)
    extended = xp.concat([zi, data], axis=axis_idx)

    # Reshape N-D data into (batch, length, channels) for conv_general
    shape = extended.shape
    batch_size = 1
    for i in range(axis_idx):
        batch_size *= shape[i]
    chan_size = 1
    for i in range(axis_idx + 1, len(shape)):
        chan_size *= shape[i]
    L = shape[axis_idx]  # M + N

    input_3d = xp.reshape(extended, (batch_size, L, chan_size))

    # conv_general expects weight shape (out_channels, kernel_size, in_channels/groups)
    # With groups=chan_size, each channel is convolved independently.
    # We want each output channel to use the same kernel b_1d.
    # Weight shape: (chan_size, M+1, 1)
    kernel = xp.reshape(b_1d, (1, M + 1, 1))
    weight = xp.broadcast_to(kernel, (chan_size, M + 1, 1))

    # conv_general with flip=True gives correlation->convolution
    # padding=0 (default "VALID"), groups=chan_size for per-channel conv
    # Input: (batch_size, M+N, chan_size), Weight: (chan_size, M+1, 1)
    # Output: (batch_size, N, chan_size)
    out_3d = xp.conv_general(input_3d, weight, groups=chan_size, flip=True)

    # Reshape back to original data shape
    out_shape = list(data.shape)
    out_shape[axis_idx] = N
    dat_out = xp.reshape(out_3d, tuple(out_shape))

    # Update state: last M samples of extended input
    new_zi = slice_along_axis(extended, slice(N, N + M), axis_idx)

    return dat_out, new_zi
class FilterBaseSettings(ez.Settings):
    """Settings common to the filters in this module: target axis and coefficient format."""

    # When left as None, the transformers in this module fall back to the
    # first dimension of the incoming message.
    axis: str | None = None
    """The name of the axis to operate on."""

    # Selects between scipy.signal.lfilter ("ba") and scipy.signal.sosfilt ("sos").
    coef_type: str = "ba"
    """The type of filter coefficients. One of "ba" or "sos"."""
class FilterSettings(FilterBaseSettings):
    """Settings for :obj:`FilterTransformer`: axis, coefficient format and the coefficients themselves."""

    # The annotation lists every container the rest of this module actually
    # passes in: a FilterCoefficients, a (b, a) tuple, or a bare SOS ndarray.
    coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray | None = None
    """The pre-calculated filter coefficients."""

    # Note: a FilterCoefficients instance always carries (b, a), so
    # coef_type = "ba" is assumed whenever coefs is given in that form.
    # SOS coefficients must be supplied as a plain ndarray.
@processor_state
class FilterState:
    """Mutable per-stream state kept by :obj:`FilterTransformer`."""

    # Filter memory carried between chunks; None means "not initialised yet".
    zi: npt.NDArray | None = None

    # reshaped taps for FFT path (broadcast shape)
    fir_b: typing.Any | None = None

    # 1D taps for conv path
    fir_b_1d: typing.Any | None = None

    # 'conv', 'fft', or None (scipy)
    fir_method: str | None = None
class FilterTransformer(BaseStatefulTransformer[FilterSettings, AxisArray, AxisArray, FilterState]):
    """
    Filter data using the provided coefficients.

    Three processing paths exist, chosen in :meth:`_reset_state`:

    * ``fir_method == "conv"``: FIR taps on a non-NumPy backend that exposes
      ``conv_general``.
    * ``fir_method == "fft"``: FIR taps on any other non-NumPy backend.
    * ``fir_method is None``: everything else, via ``scipy.signal.lfilter``
      or ``scipy.signal.sosfilt``.
    """

    def __call__(self, message: AxisArray) -> AxisArray:
        """Filter ``message``; pass it through untouched when no coefficients are set."""
        if self.settings.coefs is None:
            return message
        if self._state.zi is None:
            # zi is cleared by update_coefficients when the filter order
            # changes; rebuild the state before delegating to the base class.
            self._reset_state(message)
            self._hash = self._hash_message(message)
        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key and its per-sample shape (all dims except the filter axis)."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape))

    def _reset_state(self, message: AxisArray) -> None:
        """Build initial filter state (and cached FIR taps) sized for ``message``."""
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        axis_idx = message.get_axis_idx(axis)
        n_tail = message.data.ndim - axis_idx - 1
        _, coefs = _normalize_coefs(self.settings.coefs)

        if self.settings.coef_type == "ba":
            b, a = coefs
            is_fir = len(a) == 1 or np.allclose(a[1:], 0)

            if is_fir and not is_numpy_array(message.data):
                # FIR + non-numpy: use conv_general if available, else FFT
                xp = get_namespace(message.data)
                dev = array_device(message.data)
                M = len(b) - 1  # filter order
                zi_shape = list(message.data.shape)
                zi_shape[axis_idx] = M
                self.state.zi = xp_create(xp.zeros, tuple(zi_shape), dtype=message.data.dtype, device=dev)
                # 1D taps for conv path
                self.state.fir_b_1d = xp_asarray(xp, b, dtype=message.data.dtype, device=dev)
                # Reshape b to broadcast: (1, ..., M+1, ..., 1) for FFT path
                b_shape = [1] * message.data.ndim
                b_shape[axis_idx] = len(b)
                self.state.fir_b = xp.reshape(self.state.fir_b_1d, tuple(b_shape))
                # Choose method
                self.state.fir_method = "conv" if hasattr(xp, "conv_general") else "fft"
                return

            if is_fir:
                # FIR + numpy: use lfiltic with zero initial conditions
                zi = scipy.signal.lfiltic(b, a, [])
            else:
                # IIR filters...
                zi = scipy.signal.lfilter_zi(b, a)
        else:
            # For second-order sections (SOS) filters, use sosfilt_zi
            zi = scipy.signal.sosfilt_zi(*coefs)

        # Insert singleton dims around the state axis, then tile so that zi
        # matches the data shape everywhere except along the filter axis.
        zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
        n_tile = message.data.shape[:axis_idx] + (1,) + message.data.shape[axis_idx + 1 :]

        if self.settings.coef_type == "sos":
            # SOS state has a leading n_sections axis that must be preserved.
            zi_expand = (slice(None),) + zi_expand
            n_tile = (1,) + n_tile

        zi_tiled = np.tile(zi[zi_expand], n_tile)
        if not is_numpy_array(message.data):
            xp = get_namespace(message.data)
            zi_tiled = xp_asarray(xp, zi_tiled)
        self.state.zi = zi_tiled
        self.state.fir_method = None
        self.state.fir_b = None
        self.state.fir_b_1d = None

    def _refresh_fir_taps(
        self,
        coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
    ) -> bool:
        """Re-upload the cached FIR taps after a same-length coefficient update.

        The conv/FFT fast paths read ``state.fir_b_1d`` / ``state.fir_b``
        rather than ``settings.coefs``, so those caches must follow every
        coefficient change. The history held in ``state.zi`` is kept.

        Returns:
            True when the caches were refreshed; False when the new
            coefficients are not FIR "ba" coefficients, in which case the
            caller must fall back to a full state reset.
        """
        if self.settings.coef_type != "ba":
            return False
        _, ba = _normalize_coefs(coefs)
        b, a = ba
        if not (len(a) == 1 or np.allclose(a[1:], 0)):
            return False  # became IIR: the fast path no longer applies
        zi = self.state.zi
        xp = get_namespace(zi)
        taps = xp_asarray(xp, b, dtype=zi.dtype, device=array_device(zi))
        # The tap count is unchanged here, so the broadcast shape is reusable.
        self.state.fir_b = xp.reshape(taps, tuple(self.state.fir_b.shape))
        self.state.fir_b_1d = taps
        return True

    def update_coefficients(
        self,
        coefs: FilterCoefficients | tuple[npt.NDArray, npt.NDArray] | npt.NDArray,
        coef_type: str | None = None,
    ) -> None:
        """
        Update filter coefficients.

        If the new coefficients have the same length as the current ones, only the coefficients are updated.
        If the lengths differ, the filter state is also reset to handle the new filter order.
        A reset is also triggered when the old and new coefficients use different containers.

        Args:
            coefs: New filter coefficients
            coef_type: Optional new coefficient format ("ba" or "sos"). The
                current setting is kept when None.
        """
        old_coefs = self.settings.coefs

        # Update settings with new coefficients
        self.settings = replace(self.settings, coefs=coefs)
        if coef_type is not None:
            self.settings = replace(self.settings, coef_type=coef_type)

        # Nothing to reconcile until the state has been initialised.
        if self.state.zi is None:
            return

        reset_needed = False

        if self.settings.coef_type == "ba":
            if isinstance(old_coefs, FilterCoefficients) and isinstance(coefs, FilterCoefficients):
                if len(old_coefs.b) != len(coefs.b) or len(old_coefs.a) != len(coefs.a):
                    reset_needed = True
            elif isinstance(old_coefs, tuple) and isinstance(coefs, tuple):
                if len(old_coefs[0]) != len(coefs[0]) or len(old_coefs[1]) != len(coefs[1]):
                    reset_needed = True
            else:
                reset_needed = True
        elif self.settings.coef_type == "sos":
            if isinstance(old_coefs, np.ndarray) and isinstance(coefs, np.ndarray):
                if old_coefs.shape != coefs.shape:
                    reset_needed = True
            else:
                reset_needed = True

        if not reset_needed and self.state.fir_method is not None:
            # The FIR fast paths filter with cached taps; without this the
            # new coefficients would be silently ignored.
            reset_needed = not self._refresh_fir_taps(coefs)

        if reset_needed:
            self.state.zi = None  # This will trigger _reset_state on the next call

    def _process(self, message: AxisArray) -> AxisArray:
        """Filter one chunk along the configured axis and advance the state."""
        if message.data.size > 0:
            axis = message.dims[0] if self.settings.axis is None else self.settings.axis
            axis_idx = message.get_axis_idx(axis)
            if self.state.fir_method == "conv":
                xp = get_namespace(message.data)
                dat_out, self.state.zi = _fir_filt_conv(self.state.fir_b_1d, message.data, self.state.zi, axis_idx, xp)
            elif self.state.fir_method == "fft":
                xp = get_namespace(message.data)
                dat_out, self.state.zi = _fir_filt_fft(self.state.fir_b, message.data, self.state.zi, axis_idx, xp)
            else:
                _, coefs = _normalize_coefs(self.settings.coefs)
                filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[self.settings.coef_type]
                input_xp = None if is_numpy_array(message.data) else get_namespace(message.data)
                if input_xp is not None:
                    # Convert coefs and zi to the input namespace so scipy's
                    # array_namespace sees a single backend and converts back.
                    # NOTE: scipy 1.17 bundles an array_api_compat that does
                    # not recognize MLX, so we also convert the output below.
                    # When scipy's bundled copy gains MLX support, the manual
                    # conversion will become a no-op.
                    coefs = tuple(xp_asarray(input_xp, c) for c in coefs)
                dat_out, self.state.zi = filt_func(*coefs, message.data, axis=axis_idx, zi=self.state.zi)
                if input_xp is not None:
                    dat_out = xp_asarray(input_xp, dat_out)
                    self.state.zi = xp_asarray(input_xp, self.state.zi)
        else:
            # Empty chunk: nothing to filter, state is left untouched.
            dat_out = message.data

        return replace(message, data=dat_out)
class Filter(BaseTransformerUnit[FilterSettings, AxisArray, AxisArray, FilterTransformer]):
    """Unit wrapper around :obj:`FilterTransformer`, configured by :obj:`FilterSettings`."""

    SETTINGS = FilterSettings
def filtergen(axis: str, coefs: npt.NDArray | tuple[npt.NDArray] | None, coef_type: str) -> FilterTransformer:
    """
    Build a :obj:`FilterTransformer` from explicit arguments.

    Args:
        axis: Name of the axis to filter along.
        coefs: Pre-calculated filter coefficients, or None.
        coef_type: Coefficient format, "ba" or "sos".

    Returns:
        :obj:`FilterTransformer`.
    """
    settings = FilterSettings(axis=axis, coefs=coefs, coef_type=coef_type)
    return FilterTransformer(settings)
@processor_state
class FilterByDesignState:
    """State held by :obj:`FilterByDesignTransformer`."""

    # The wrapped filter; created on the first state reset.
    filter: FilterTransformer | None = None

    # Set when settings change after the filter exists, so that coefficients
    # are re-designed on the next message.
    needs_redesign: bool = False
class FilterByDesignTransformer(
    BaseStatefulTransformer[SettingsType, AxisArray, AxisArray, FilterByDesignState],
    ABC,
    typing.Generic[SettingsType, FilterCoefsType],
):
    """Abstract base class for filter design transformers."""

    @classmethod
    def get_message_type(cls, dir: str) -> type[AxisArray]:
        """Return the message type for direction ``dir`` ("in" or "out").

        Raises:
            ValueError: If ``dir`` is neither "in" nor "out".
        """
        if dir not in ("in", "out"):
            raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
        return AxisArray

    @abstractmethod
    def get_design_function(self) -> typing.Callable[[float], FilterCoefsType | None]:
        """Return a function that takes sampling frequency and returns filter coefficients."""
        ...

    def update_settings(self, new_settings: typing.Optional[SettingsType] = None, **kwargs) -> None:
        """
        Swap in new settings and flag the coefficients as stale.

        Args:
            new_settings: Complete new settings object to replace current settings
            **kwargs: Individual settings to update (used only when
                ``new_settings`` is None)
        """
        if new_settings is None:
            self.settings = replace(self.settings, **kwargs)
        else:
            self.settings = new_settings

        # Only an existing filter needs a redesign; otherwise the first
        # message will design it from the current settings anyway.
        if self.state.filter is not None:
            self.state.needs_redesign = True

    @staticmethod
    def _to_requested_format(coefs, coef_type: str):
        """Convert (b, a) coefficients to SOS when ``coef_type`` asks for it; otherwise pass through."""
        wants_sos = coefs is not None and coef_type == "sos"
        if wants_sos and isinstance(coefs, tuple) and len(coefs) == 2:
            # It's BA format, convert to SOS
            numerator, denominator = coefs
            return scipy.signal.tf2sos(numerator, denominator)
        return coefs

    def __call__(self, message: AxisArray) -> AxisArray:
        """Filter ``message``, re-designing the coefficients first if settings changed."""
        # Offer a shortcut when there is no design function or order is 0.
        if hasattr(self.settings, "order") and not self.settings.order:
            return message
        design_fun = self.get_design_function()
        if design_fun is None:
            return message

        # Check if filter exists but needs redesign due to settings change
        redesign = self.state.filter is not None and self.state.needs_redesign
        if redesign:
            filt_axis = self.state.filter.settings.axis
            sample_rate = 1 / message.axes[filt_axis].gain
            new_coefs = self._to_requested_format(design_fun(sample_rate), self.settings.coef_type)
            self.state.filter.update_coefficients(new_coefs, coef_type=self.settings.coef_type)
            self.state.needs_redesign = False

        return super().__call__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key, per-sample shape and the filter axis gain."""
        axis = self.settings.axis if self.settings.axis is not None else message.dims[0]
        axis_info = message.axes[axis]
        gain = axis_info.gain if hasattr(axis_info, "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        shape = message.data.shape
        samp_shape = shape[:axis_idx] + shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Design coefficients for this message's sample rate and create the wrapped filter."""
        design_fun = self.get_design_function()
        axis = self.settings.axis if self.settings.axis is not None else message.dims[0]
        sample_rate = 1 / message.axes[axis].gain
        coefs = self._to_requested_format(design_fun(sample_rate), self.settings.coef_type)
        filter_settings = FilterSettings(axis=axis, coef_type=self.settings.coef_type, coefs=coefs)
        self.state.filter = FilterTransformer(settings=filter_settings)

    def _process(self, message: AxisArray) -> AxisArray:
        """Delegate the actual filtering to the wrapped :obj:`FilterTransformer`."""
        return self.state.filter(message)
class BaseFilterByDesignTransformerUnit(
    BaseTransformerUnit[SettingsType, AxisArray, AxisArray, FilterByDesignTransformer],
    typing.Generic[SettingsType, TransformerType],
):
    """Unit base class for filter-by-design transformers that updates its processor in place on new settings."""

    @ez.subscriber(BaseConsumerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: SettingsType) -> None:
        """
        Receive a settings message and apply it.

        The settings are applied to the unit first. An existing processor is
        then updated through its ``update_settings`` method; a processor is
        only created when none exists yet.

        Child classes that wish to have fine-grained control over whether the
        core processor resets on settings changes should override this method.

        Args:
            msg: a settings message.
        """
        self.apply_settings(msg)

        processor = getattr(self, "processor", None)
        if processor is None:
            # Processor doesn't exist yet, create a new one
            self.create_processor()
            return

        # Update the existing processor with new settings
        processor.update_settings(self.SETTINGS)