Source code for ezmsg.sigproc.filterbankdesign

import typing

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt

from ezmsg.util.messages.util import replace
from ezmsg.util.messages.axisarray import AxisArray

from .base import (
    BaseStatefulTransformer,
    processor_state,
)

from .filterbank import (
    FilterbankTransformer,
    FilterbankSettings,
    FilterbankMode,
    MinPhaseMode,
)

from .kaiser import KaiserFilterSettings, kaiser_design_fun


[docs] class FilterbankDesignSettings(ez.Settings): filters: typing.Iterable[KaiserFilterSettings] mode: FilterbankMode = FilterbankMode.CONV """ "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data. fft mode is more efficient for long kernels. However, fft mode uses non-overlapping windows and will incur a delay equal to the window length, which is larger than the largest kernel. conv mode is less efficient but will return data for every incoming chunk regardless of how small it is and thus can provide shorter latency updates. """ min_phase: MinPhaseMode = MinPhaseMode.NONE """ If not None, convert the kernels to minimum-phase equivalents. Valid options are 'hilbert', 'homomorphic', and 'homomorphic-full'. Complex filters not supported. See `scipy.signal.minimum_phase` for details. """ axis: str = "time" """The name of the axis to operate on. This should usually be "time".""" new_axis: str = "kernel" """The name of the new axis corresponding to the kernel index."""
[docs] @processor_state class FilterbankDesignState: filterbank: FilterbankTransformer | None = None needs_redesign: bool = False
[docs] class FilterbankDesignTransformer( BaseStatefulTransformer[ FilterbankDesignSettings, AxisArray, AxisArray, FilterbankDesignState ], ): """ Transformer that designs and applies a filterbank based on Kaiser windowed FIR filters. """
[docs] @classmethod def get_message_type(cls, dir: str) -> type[AxisArray]: if dir in ("in", "out"): return AxisArray else: raise ValueError(f"Invalid direction: {dir}. Must be 'in' or 'out'.")
[docs] def update_settings( self, new_settings: typing.Optional[FilterbankDesignSettings] = None, **kwargs ) -> None: """ Update settings and mark that filter coefficients need to be recalculated. Args: new_settings: Complete new settings object to replace current settings **kwargs: Individual settings to update """ # Update settings if new_settings is not None: self.settings = new_settings else: self.settings = replace(self.settings, **kwargs) # Set flag to trigger recalculation on next message if self.state.filterbank is not None: self.state.needs_redesign = True
def _calculate_kernels(self, fs: float) -> list[npt.NDArray]: kernels = [] for filter in self.settings.filters: output = kaiser_design_fun( fs, cutoff=filter.cutoff, ripple=filter.ripple, width=filter.width, pass_zero=filter.pass_zero, wn_hz=filter.wn_hz, ) kernels.append(np.array([1.0]) if output is None else output[0]) return kernels def __call__(self, message: AxisArray) -> AxisArray: if self.state.filterbank is not None and self.state.needs_redesign: self._reset_state(message) self.state.needs_redesign = False return super().__call__(message) def _hash_message(self, message: AxisArray) -> int: axis = message.dims[0] if self.settings.axis is None else self.settings.axis gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1 axis_idx = message.get_axis_idx(axis) samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :] return hash((message.key, samp_shape, gain)) def _reset_state(self, message: AxisArray) -> None: axis_obj = message.axes[self.settings.axis] assert isinstance(axis_obj, AxisArray.LinearAxis) fs = 1 / axis_obj.gain kernels = self._calculate_kernels(fs) new_settings = FilterbankSettings( kernels=kernels, mode=self.settings.mode, min_phase=self.settings.min_phase, axis=self.settings.axis, new_axis=self.settings.new_axis, ) self.state.filterbank = FilterbankTransformer(settings=new_settings) def _process(self, message: AxisArray) -> AxisArray: return self.state.filterbank(message)