Source code for ezmsg.sigproc.fir_hilbert

import functools
import typing

import ezmsg.core as ez
import numpy as np
import scipy.signal as sps
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.filter import (
    BACoeffs,
    BaseFilterByDesignTransformerUnit,
    BaseTransformerUnit,
    FilterBaseSettings,
    FilterByDesignTransformer,
)


[docs] class FIRHilbertFilterSettings(FilterBaseSettings): """Settings for :obj:`FIRHilbertFilter`.""" # axis inherited from FilterBaseSettings coef_type: str = "ba" """ Coefficient type. Must be 'ba' for FIR. """ order: int = 170 """ Filter order (taps = order + 1). Hilbert (type-III) filters require even order (odd taps). If odd order (even taps), order will be incremented by 1. """ f_lo: float = 1.0 """ Lower corner of Hilbert “pass” band (Hz). Transition starts at f_lo. """ f_hi: float | None = None """ Upper corner of Hilbert “pass” band (Hz). Transition starts at f_hi. If None, highpass from f_lo to Nyquist. """ trans_lo: float = 1.0 """ Transition width (Hz) below f_lo. Decrease to sharpen transition. """ trans_hi: float = 1.0 """ Transition width (Hz) at high end. Decrease to sharpen transition. """ weight_pass: float = 1.0 """ Weight for Hilbert pass region. """ weight_stop_lo: float = 1.0 """ Weight for low stop band. """ weight_stop_hi: float = 1.0 """ Weight for high stop band. """ norm_band: tuple[float, float] | None = None """ Optional normalization band (f_lo, f_hi) in Hz for gain normalization. If None, no normalization is applied. """ norm_freq: float | None = None """ Optional normalization frequency in Hz for gain normalization. If None, no normalization is applied. """
[docs] def fir_hilbert_design_fun( fs: float, order: int = 170, f_lo: float = 1.0, f_hi: float | None = None, trans_lo: float = 1.0, trans_hi: float = 1.0, weight_pass: float = 1.0, weight_stop_lo: float = 1.0, weight_stop_hi: float = 1.0, norm_band: tuple[float, float] | None = None, norm_freq: float | None = None, ) -> BACoeffs | None: """ Hilbert FIR filter design using the Remez exchange algorithm. Design an `order`th-order FIR Hilbert filter and return the filter coefficients. See :obj:`FIRHilbertFilterSettings` for argument description. Returns: The filter coefficients as a tuple of (b, a). """ if order <= 0: return None if order % 2 == 1: order += 1 nyq = fs / 2.0 taps = order + 1 f1 = max(f_lo, 0.0) + trans_lo f2 = (nyq - trans_hi) if (f_hi is None) else min(f_hi, nyq - trans_hi) if not (0.0 < f1 < f2 < nyq): raise ValueError( f"Hilbert passband collapsed or invalid: " f"f_lo={f_lo}, f_hi={f_hi}, trans_lo={trans_lo}, trans_hi={trans_hi}, fs={fs}" ) # Bands: [0, f1-trans_lo] stop ; [f1, f2] pass (Hilbert) ; [f2+trans_hi, nyq] stop bands = [0.0, max(f1 - trans_lo, 0.0), f1, f2, min(f2 + trans_hi, nyq), nyq] desired = [0.0, 1.0, 0.0] weight = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)] for i in range(1, len(bands) - 1): if bands[i] <= bands[i - 1]: bands[i] = np.nextafter(bands[i - 1], np.inf) if bands[-2] >= nyq: ez.logger.warning("Hilbert upper stopband collapsed; using 2-band (stop/pass) design.") bands = bands[:-3] + [nyq] desired = desired[:-1] weight = weight[:-1] b = sps.remez(taps, bands, desired, weight=weight, type="hilbert", fs=fs) a = np.array([1.0]) g = None if norm_freq is not None: if norm_freq < f1 or norm_freq > f2: ez.logger.warning("Invalid normalization frequency specifications. Skipping normalization.") else: f0 = float(norm_freq) w = 2.0 * np.pi * (np.asarray([f0], dtype=np.float64) / fs) _, H = sps.freqz(b, a, worN=w) g = float(np.abs(H[0])) elif norm_band is not None: lo, hi = norm_band if lo < f1 or hi > f2: lo = max(lo, f1) hi = min(hi, f2) ez.logger.warning("Normalization band outside passband. Clipping to passband for normalization.") if lo >= hi: ez.logger.warning("Invalid normalization band specifications. Skipping normalization.") else: freqs = np.linspace(lo, hi, 2048, dtype=np.float64) w = 2.0 * np.pi * (np.asarray(freqs, dtype=np.float64) / fs) _, H = sps.freqz(b, a, worN=w) g = float(np.median(np.abs(H))) if g is not None and g > 0: b = b / g return (b, a)
[docs] class FIRHilbertFilterTransformer(FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]):
[docs] def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]: if self.settings.coef_type != "ba": ez.logger.error("FIRHilbert only supports coef_type='ba'.") raise ValueError("FIRHilbert only supports coef_type='ba'.") return functools.partial( fir_hilbert_design_fun, order=self.settings.order, f_lo=self.settings.f_lo, f_hi=self.settings.f_hi, trans_lo=self.settings.trans_lo, trans_hi=self.settings.trans_hi, weight_pass=self.settings.weight_pass, weight_stop_lo=self.settings.weight_stop_lo, weight_stop_hi=self.settings.weight_stop_hi, norm_band=self.settings.norm_band, norm_freq=self.settings.norm_freq, )
[docs] def get_taps(self) -> int | None: if self._state.filter is None: return None b, _ = self._state.filter.settings.coefs return b.size if b is not None else None
[docs] class FIRHilbertFilterUnit(BaseFilterByDesignTransformerUnit[FIRHilbertFilterSettings, FIRHilbertFilterTransformer]): SETTINGS = FIRHilbertFilterSettings
[docs] @processor_state class FIRHilbertEnvelopeState: filter: FIRHilbertFilterTransformer | None = None delay_buf: np.ndarray | None = None dly: int | None = None
[docs] class FIRHilbertEnvelopeTransformer( BaseStatefulTransformer[FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState] ): """ Processor for computing the envelope of a signal using the Hilbert transform. This processor applies a Hilbert FIR filter to the input signal to obtain the analytic signal, from which the envelope is computed. The processor expects and outputs `AxisArray` messages with a `"time"` (time) axis. Settings: --------- order : int Filter order (taps = order + 1). Hilbert (type-III) filters require even order (odd taps). If odd order (even taps), order will be incremented by 1. f_lo : float Lower corner of Hilbert “pass” band (Hz). Transition starts at f_lo. f_hi : float, optional Upper corner of Hilbert “pass” band (Hz). Transition starts at f_hi. If None, highpass from f_lo to Nyquist. trans_lo : float Transition width (Hz) below f_lo. Decrease to sharpen transition. trans_hi : float Transition width (Hz) above f_hi. Decrease to sharpen transition. weight_pass : float Weight for Hilbert pass region. weight_stop_lo : float Weight for low stop band. weight_stop_hi : float Weight for high stop band. norm_band : tuple(float, float), optional Optional normalization band (f_lo, f_hi) in Hz for gain normalization. If None, no normalization is applied. norm_freq : float, optional Optional normalization frequency in Hz for gain normalization. If None, no normalization is applied. Example: ----------------------------- ```python processor = FIRHilbertEnvelopeTransformer( settings=FIRHilbertFilterSettings( order=170, f_lo=1.0, f_hi=50.0, ) ) ``` """ def _hash_message(self, message: AxisArray) -> int: axis = self.settings.axis or message.dims[0] gain = getattr(self._state.filter, "gain", 0.0) axis_idx = message.get_axis_idx(axis) samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :] return hash((message.key, samp_shape, gain)) def _reset_state(self, message: AxisArray) -> None: self._state.filter = FIRHilbertFilterTransformer(settings=self.settings) self._state.delay_buf = None self._state.dly = None def _process(self, message: AxisArray) -> AxisArray: y_imag_msg = self._state.filter(message) y_imag = y_imag_msg.data axis_name = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis_name) if self._state.dly is None: taps = self._state.filter.get_taps() self._state.dly = (taps - 1) // 2 x = message.data move_axis = False if axis_idx != x.ndim - 1: x = np.moveaxis(x, axis_idx, -1) y_imag = np.moveaxis(y_imag, axis_idx, -1) move_axis = True if self._state.delay_buf is None: lead_shape = x.shape[:-1] self._state.delay_buf = np.zeros(lead_shape + (self._state.dly,), dtype=x.dtype) x_cat = np.concatenate([self._state.delay_buf, x], axis=-1) x_delayed_full = x_cat[..., : -self._state.dly] y_real = x_delayed_full[..., -x.shape[-1] :] self._state.delay_buf = x_cat[..., -self._state.dly :].copy() analytic = y_real.astype(np.complex64) + 1j * y_imag.astype(np.complex64) out = np.abs(analytic) if move_axis: out = np.moveaxis(out, -1, axis_idx) return replace(message, data=out, axes=message.axes)
[docs] class FIRHilbertEnvelopeUnit( BaseTransformerUnit[ FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeTransformer, ] ): """ Unit wrapper for the `FIRHilbertEnvelopeTransformer`. This unit provides a plug-and-play interface for calculating the envelope using the FIR Hilbert transform on a signal in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs processed data in the same format. Example: -------- ```python unit = FIRHilbertEnvelopeUnit( settings=FIRHilbertFilterSettings( order=170, f_lo=1.0, f_hi=50.0, ) ) ``` """ SETTINGS = FIRHilbertFilterSettings