Source code for ezmsg.sigproc.fir_hilbert

"""FIR Hilbert transform filter for analytic signal and envelope extraction."""

import functools
import typing

import ezmsg.core as ez
import numpy as np
import scipy.signal as sps
from ezmsg.baseproc import BaseStatefulTransformer, processor_state
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.filter import (
    BACoeffs,
    BaseFilterByDesignTransformerUnit,
    BaseTransformerUnit,
    FilterBaseSettings,
    FilterByDesignTransformer,
)


[docs] class FIRHilbertFilterSettings(FilterBaseSettings): """Settings for :obj:`FIRHilbertFilter`.""" # axis inherited from FilterBaseSettings coef_type: str = "ba" """ Coefficient type. Must be 'ba' for FIR. """ order: int = 170 """ Filter order (taps = order + 1). Hilbert (type-III) filters require even order (odd taps). If odd order (even taps), order will be incremented by 1. """ f_lo: float = 1.0 """ Lower corner of Hilbert “pass” band (Hz). Transition starts at f_lo. """ f_hi: float | None = None """ Upper corner of Hilbert “pass” band (Hz). Transition starts at f_hi. If None, highpass from f_lo to Nyquist. """ trans_lo: float = 1.0 """ Transition width (Hz) below f_lo. Decrease to sharpen transition. """ trans_hi: float = 1.0 """ Transition width (Hz) at high end. Decrease to sharpen transition. """ weight_pass: float = 1.0 """ Weight for Hilbert pass region. """ weight_stop_lo: float = 1.0 """ Weight for low stop band. """ weight_stop_hi: float = 1.0 """ Weight for high stop band. """ norm_band: tuple[float, float] | None = None """ Optional normalization band (f_lo, f_hi) in Hz for gain normalization. If None, no normalization is applied. """ norm_freq: float | None = None """ Optional normalization frequency in Hz for gain normalization. If None, no normalization is applied. """
[docs] def fir_hilbert_design_fun( fs: float, order: int = 170, f_lo: float = 1.0, f_hi: float | None = None, trans_lo: float = 1.0, trans_hi: float = 1.0, weight_pass: float = 1.0, weight_stop_lo: float = 1.0, weight_stop_hi: float = 1.0, norm_band: tuple[float, float] | None = None, norm_freq: float | None = None, ) -> BACoeffs | None: """ Hilbert FIR filter design using the Remez exchange algorithm. Design an `order`th-order FIR Hilbert filter and return the filter coefficients. See :obj:`FIRHilbertFilterSettings` for argument description. Returns: The filter coefficients as a tuple of (b, a). """ if order <= 0: return None if order % 2 == 1: order += 1 nyq = fs / 2.0 taps = order + 1 f1 = max(f_lo, 0.0) + trans_lo f2 = (nyq - trans_hi) if (f_hi is None) else min(f_hi, nyq - trans_hi) if not (0.0 < f1 < f2 < nyq): raise ValueError( f"Hilbert passband collapsed or invalid: " f"f_lo={f_lo}, f_hi={f_hi}, trans_lo={trans_lo}, trans_hi={trans_hi}, fs={fs}" ) # Bands: [0, f1-trans_lo] stop ; [f1, f2] pass (Hilbert) ; [f2+trans_hi, nyq] stop bands = [0.0, max(f1 - trans_lo, 0.0), f1, f2, min(f2 + trans_hi, nyq), nyq] desired = [0.0, 1.0, 0.0] weight = [max(weight_stop_lo, 0.0), max(weight_pass, 0.0), max(weight_stop_hi, 0.0)] for i in range(1, len(bands) - 1): if bands[i] <= bands[i - 1]: bands[i] = np.nextafter(bands[i - 1], np.inf) if bands[-2] >= nyq: ez.logger.warning("Hilbert upper stopband collapsed; using 2-band (stop/pass) design.") bands = bands[:-3] + [nyq] desired = desired[:-1] weight = weight[:-1] b = sps.remez(taps, bands, desired, weight=weight, type="hilbert", fs=fs) a = np.array([1.0]) g = None if norm_freq is not None: if norm_freq < f1 or norm_freq > f2: ez.logger.warning("Invalid normalization frequency specifications. Skipping normalization.") else: f0 = float(norm_freq) w = 2.0 * np.pi * (np.asarray([f0], dtype=np.float64) / fs) _, H = sps.freqz(b, a, worN=w) g = float(np.abs(H[0])) elif norm_band is not None: lo, hi = norm_band if lo < f1 or hi > f2: lo = max(lo, f1) hi = min(hi, f2) ez.logger.warning("Normalization band outside passband. Clipping to passband for normalization.") if lo >= hi: ez.logger.warning("Invalid normalization band specifications. Skipping normalization.") else: freqs = np.linspace(lo, hi, 2048, dtype=np.float64) w = 2.0 * np.pi * (np.asarray(freqs, dtype=np.float64) / fs) _, H = sps.freqz(b, a, worN=w) g = float(np.median(np.abs(H))) if g is not None and g > 0: b = b / g return (b, a)
[docs] class FIRHilbertFilterTransformer(FilterByDesignTransformer[FIRHilbertFilterSettings, BACoeffs]):
[docs] def get_design_function(self) -> typing.Callable[[float], BACoeffs | None]: if self.settings.coef_type != "ba": ez.logger.error("FIRHilbert only supports coef_type='ba'.") raise ValueError("FIRHilbert only supports coef_type='ba'.") return functools.partial( fir_hilbert_design_fun, order=self.settings.order, f_lo=self.settings.f_lo, f_hi=self.settings.f_hi, trans_lo=self.settings.trans_lo, trans_hi=self.settings.trans_hi, weight_pass=self.settings.weight_pass, weight_stop_lo=self.settings.weight_stop_lo, weight_stop_hi=self.settings.weight_stop_hi, norm_band=self.settings.norm_band, norm_freq=self.settings.norm_freq, )
[docs] def get_taps(self) -> int | None: if self._state.filter is None: return None b, _ = self._state.filter.settings.coefs return b.size if b is not None else None
[docs] class FIRHilbertFilterUnit(BaseFilterByDesignTransformerUnit[FIRHilbertFilterSettings, FIRHilbertFilterTransformer]): SETTINGS = FIRHilbertFilterSettings
[docs] @processor_state class FIRHilbertEnvelopeState: filter: FIRHilbertFilterTransformer | None = None delay_buf: np.ndarray | None = None dly: int | None = None
[docs] class FIRHilbertEnvelopeTransformer( BaseStatefulTransformer[FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeState] ): """ Processor for computing the envelope of a signal using the Hilbert transform. This processor applies a Hilbert FIR filter to the input signal to obtain the analytic signal, from which the envelope is computed. The processor expects and outputs `AxisArray` messages with a `"time"` (time) axis. Settings: --------- order : int Filter order (taps = order + 1). Hilbert (type-III) filters require even order (odd taps). If odd order (even taps), order will be incremented by 1. f_lo : float Lower corner of Hilbert “pass” band (Hz). Transition starts at f_lo. f_hi : float, optional Upper corner of Hilbert “pass” band (Hz). Transition starts at f_hi. If None, highpass from f_lo to Nyquist. trans_lo : float Transition width (Hz) below f_lo. Decrease to sharpen transition. trans_hi : float Transition width (Hz) above f_hi. Decrease to sharpen transition. weight_pass : float Weight for Hilbert pass region. weight_stop_lo : float Weight for low stop band. weight_stop_hi : float Weight for high stop band. norm_band : tuple(float, float), optional Optional normalization band (f_lo, f_hi) in Hz for gain normalization. If None, no normalization is applied. norm_freq : float, optional Optional normalization frequency in Hz for gain normalization. If None, no normalization is applied. Example:: processor = FIRHilbertEnvelopeTransformer( settings=FIRHilbertFilterSettings( order=170, f_lo=1.0, f_hi=50.0, ) ) """ def _hash_message(self, message: AxisArray) -> int: axis = self.settings.axis or message.dims[0] gain = getattr(self._state.filter, "gain", 0.0) axis_idx = message.get_axis_idx(axis) samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :] return hash((message.key, samp_shape, gain)) def _reset_state(self, message: AxisArray) -> None: self._state.filter = FIRHilbertFilterTransformer(settings=self.settings) self._state.delay_buf = None self._state.dly = None def _process(self, message: AxisArray) -> AxisArray: y_imag_msg = self._state.filter(message) y_imag = y_imag_msg.data axis_name = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis_name) if self._state.dly is None: taps = self._state.filter.get_taps() self._state.dly = (taps - 1) // 2 x = message.data move_axis = False if axis_idx != x.ndim - 1: x = np.moveaxis(x, axis_idx, -1) y_imag = np.moveaxis(y_imag, axis_idx, -1) move_axis = True if self._state.delay_buf is None: lead_shape = x.shape[:-1] self._state.delay_buf = np.zeros(lead_shape + (self._state.dly,), dtype=x.dtype) x_cat = np.concatenate([self._state.delay_buf, x], axis=-1) x_delayed_full = x_cat[..., : -self._state.dly] y_real = x_delayed_full[..., -x.shape[-1] :] self._state.delay_buf = x_cat[..., -self._state.dly :].copy() analytic = y_real.astype(np.complex64) + 1j * y_imag.astype(np.complex64) out = np.abs(analytic) if move_axis: out = np.moveaxis(out, -1, axis_idx) return replace(message, data=out, axes=message.axes)
[docs] class FIRHilbertEnvelopeUnit( BaseTransformerUnit[ FIRHilbertFilterSettings, AxisArray, AxisArray, FIRHilbertEnvelopeTransformer, ] ): """ Unit wrapper for the `FIRHilbertEnvelopeTransformer`. This unit provides a plug-and-play interface for calculating the envelope using the FIR Hilbert transform on a signal in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs processed data in the same format. Example:: unit = FIRHilbertEnvelopeUnit( settings=FIRHilbertFilterSettings( order=170, f_lo=1.0, f_hi=50.0, ) ) """ SETTINGS = FIRHilbertFilterSettings