Source code for ezmsg.sigproc.adaptive_lattice_notch
import numpy as np
import numpy.typing as npt
import scipy.signal
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
from ezmsg.util.messages.util import replace
from .base import processor_state, BaseStatefulTransformer
class AdaptiveLatticeNotchFilterSettings(ez.Settings):
"""Settings for the Adaptive Lattice Notch Filter."""
    gamma: float = 0.995
    """Pole-zero contraction factor. Values closer to 1.0 yield a narrower notch."""
    mu: float = 0.99
    """Smoothing factor for the reflection-coefficient update."""
    eta: float = 0.99
    """Forgetting factor for the accumulated statistics used to estimate the reflection coefficient."""
    axis: str = "time"
    """Axis along which the filter is applied."""
    init_notch_freq: float | None = None
    """Initial notch frequency in Hz. Should be below the Nyquist frequency (fs / 2).
    If None, a default of ~0.0718 * fs (an initial reflection coefficient of -0.9) is used."""
    chunkwise: bool = False
    """Speed up processing by updating the tracked frequency only once per chunk
    instead of once per sample."""
@processor_state
class AdaptiveLatticeNotchFilterState:
"""State for the Adaptive Lattice Notch Filter."""
s_history: npt.NDArray | None = None
"""Historical `s` values for the adaptive filter."""
    p: npt.NDArray | None = None
    """Accumulated numerator (cross-product) term for the reflection coefficient update."""
    q: npt.NDArray | None = None
    """Accumulated denominator (energy) term for the reflection coefficient update."""
k1: npt.NDArray | None = None
"""Reflection coefficient"""
freq_template: CoordinateAxis | None = None
"""Template for the frequency axis on the output"""
zi: npt.NDArray | None = None
"""Initial conditions for the filter, updated after every chunk"""
class AdaptiveLatticeNotchFilterTransformer(
BaseStatefulTransformer[
AdaptiveLatticeNotchFilterSettings,
AxisArray,
AxisArray,
AdaptiveLatticeNotchFilterState,
]
):
"""
Adaptive Lattice Notch Filter implementation as a stateful transformer.
https://biomedical-engineering-online.biomedcentral.com/articles/10.1186/1475-925X-13-170
The filter automatically tracks and removes frequency components from the input signal.
It outputs the estimated frequency (in Hz) and the filtered sample.
"""
def _hash_message(self, message: AxisArray) -> int:
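        """Hash the message key, sampling rate (axis gain), and non-time sample shape.
        A change in any of these triggers a state reset."""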
ax_idx = message.get_axis_idx(self.settings.axis)
sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
return hash((message.key, message.axes[self.settings.axis].gain, sample_shape))
    def _reset_state(self, message: AxisArray) -> None:
        """Reset filter state to initial values."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        fs = 1 / message.axes[self.settings.axis].gain
        init_f = (
            self.settings.init_notch_freq
            if self.settings.init_notch_freq is not None
            # The default of ~0.0718 * fs corresponds to an initial reflection
            # coefficient of k1 = -cos(2 * pi * init_f / fs) = -0.9.
            else 0.07178314656435313 * fs
        )
        init_omega = init_f * (2 * np.pi) / fs
        init_k1 = -np.cos(init_omega)
        self._state = AdaptiveLatticeNotchFilterState()
        self._state.s_history = np.zeros((2,) + sample_shape, dtype=float)
        self._state.p = np.zeros(sample_shape, dtype=float)
        self._state.q = np.zeros(sample_shape, dtype=float)
        self._state.k1 = init_k1 + np.zeros(sample_shape, dtype=float)
self._state.freq_template = CoordinateAxis(
data=np.zeros((0,) + sample_shape, dtype=float),
dims=[self.settings.axis]
+ message.dims[:ax_idx]
+ message.dims[ax_idx + 1 :],
unit="Hz",
)
# Initialize the initial conditions for the filter
        self._state.zi = np.zeros((2, int(np.prod(sample_shape))), dtype=float)
# Note: we could calculate it properly, but as long as we are initializing s_history with zeros,
# it will always be zero.
# a = [1, init_k1 * (1 + self.settings.gamma), self.settings.gamma]
# b = [1]
# s = np.reshape(self._state.s_history, (2, -1))
# for feat_ix in range(np.prod(sample_shape)):
# self._state.zi[:, feat_ix] = scipy.signal.lfiltic(b, a, s[::-1, feat_ix], x=None)
def _process(self, message: AxisArray) -> AxisArray:
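        """Filter one chunk of data and attach the tracked notch frequency (in Hz) as a "freq" coordinate axis."""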
x_data = message.data
ax_idx = message.get_axis_idx(self.settings.axis)
# TODO: Time should be moved to -1th axis, not the 0th axis
if message.dims[0] != self.settings.axis:
x_data = np.moveaxis(x_data, ax_idx, 0)
# Access settings once
gamma = self.settings.gamma
eta = self.settings.eta
mu = self.settings.mu
fs = 1 / message.axes[self.settings.axis].gain
# Pre-compute constants
one_minus_eta = 1 - eta
one_minus_mu = 1 - mu
gamma_plus_1 = 1 + gamma
omega_scale = fs / (2 * np.pi)
# For the lattice filter with constant k1:
# s_n = x_n - k1*(1+gamma)*s_n_1 - gamma*s_n_2
# This is equivalent to an IIR filter with b=1, a=[1, k1*(1+gamma), gamma]
# For the output filter:
# y_n = s_n + 2*k1*s_n_1 + s_n_2
# We can treat this as a direct-form FIR filter applied to s_out
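        # Worked out, the cascade above is the standard lattice notch section:
        #   H(z) = (1 + 2*k1*z^-1 + z^-2) / (1 + k1*(1+gamma)*z^-1 + gamma*z^-2)
        # Its zeros lie on the unit circle at angles +/- omega with cos(omega) = -k1,
        # so the notch center frequency is f_notch = arccos(-k1) * fs / (2*pi),
        # which is the mapping used below when reporting the tracked frequency.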
if self.settings.chunkwise:
# Process each chunk using current filter parameters
# Reshape input and prepare output arrays
_s = self._state.s_history.reshape((2, -1))
_x = x_data.reshape((x_data.shape[0], -1))
s_n = np.zeros_like(_x)
y_out = np.zeros_like(_x)
# Apply static filter for each feature dimension
for ix, k in enumerate(self._state.k1.flatten()):
# Filter to get s_n (notch filter state)
a_s = [1, k * gamma_plus_1, gamma]
s_n[:, ix], self._state.zi[:, ix] = scipy.signal.lfilter(
[1], a_s, _x[:, ix], zi=self._state.zi[:, ix]
)
# Apply output filter to get y_out
b_y = [1, 2 * k, 1]
y_out[:, ix] = scipy.signal.lfilter(b_y, [1], s_n[:, ix])
# Update filter parameters using final values from the chunk
s_n_reshaped = s_n.reshape((s_n.shape[0],) + x_data.shape[1:])
s_final = s_n_reshaped[-1] # Current s_n
s_final_1 = s_n_reshaped[-2] # s_n_1
            s_final_2 = (
                s_n_reshaped[-3] if len(s_n_reshaped) > 2 else self._state.s_history[-1]
            )  # s_n_2
# Update p and q using final values
self._state.p = eta * self._state.p + one_minus_eta * (
s_final_1 * (s_final + s_final_2)
)
self._state.q = eta * self._state.q + one_minus_eta * (
2 * (s_final_1 * s_final_1)
)
# Update reflection coefficient
new_k1 = -self._state.p / (self._state.q + 1e-8) # Avoid division by zero
new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
# Calculate frequency from updated k1 value
omega_n = np.arccos(-self._state.k1)
freq = omega_n * omega_scale
            freq_out = np.full_like(x_data, freq)
# Update s_history for next chunk
self._state.s_history = s_n_reshaped[-2:].reshape((2,) + x_data.shape[1:])
# Reshape y_out back to original dimensions
y_out = y_out.reshape(x_data.shape)
else:
# Perform filtering, sample-by-sample
y_out = np.zeros_like(x_data)
freq_out = np.zeros_like(x_data)
for sample_ix, x_n in enumerate(x_data):
s_n_1 = self._state.s_history[-1]
s_n_2 = self._state.s_history[-2]
s_n = x_n - self._state.k1 * gamma_plus_1 * s_n_1 - gamma * s_n_2
y_out[sample_ix] = s_n + 2 * self._state.k1 * s_n_1 + s_n_2
# Update filter parameters
self._state.p = eta * self._state.p + one_minus_eta * (
s_n_1 * (s_n + s_n_2)
)
self._state.q = eta * self._state.q + one_minus_eta * (
2 * (s_n_1 * s_n_1)
)
# Update reflection coefficient
new_k1 = -self._state.p / (
self._state.q + 1e-8
) # Avoid division by zero
new_k1 = np.clip(new_k1, -1, 1) # Clip to prevent instability
self._state.k1 = mu * self._state.k1 + one_minus_mu * new_k1 # Smoothed
# Compute normalized angular frequency using equation 13 from the paper
omega_n = np.arccos(-self._state.k1)
freq_out[sample_ix] = omega_n * omega_scale # As Hz
# Update for next iteration
self._state.s_history[-2] = s_n_1
self._state.s_history[-1] = s_n
        # Restore the original axis order if the time axis was not already first.
        if message.dims[0] != self.settings.axis:
            y_out = np.moveaxis(y_out, 0, ax_idx)
        return replace(
            message,
            data=y_out,
            axes={
                **message.axes,
                "freq": replace(self._state.freq_template, data=freq_out),
            },
        )
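

# The block below is a minimal, illustrative usage sketch rather than part of the module.
# It assumes that ezmsg-sigproc transformers can be called directly on AxisArray messages
# for offline processing, that the settings object may be passed positionally, and that
# `AxisArray.Axis.TimeAxis` is available in the installed ezmsg version; adjust as needed.
if __name__ == "__main__":
    fs = 500.0
    t = np.arange(int(2 * fs)) / fs
    # Two channels of broadband noise contaminated by a 60 Hz sinusoid.
    data = np.random.randn(t.size, 2) + np.sin(2 * np.pi * 60.0 * t)[:, None]
    msg = AxisArray(
        data=data,
        dims=["time", "ch"],
        axes={"time": AxisArray.Axis.TimeAxis(fs=fs)},
        key="alnf_demo",
    )
    alnf = AdaptiveLatticeNotchFilterTransformer(
        AdaptiveLatticeNotchFilterSettings(init_notch_freq=55.0)
    )
    out = alnf(msg)
    # The filtered signal is in out.data; the per-sample frequency estimate (Hz)
    # rides along as the "freq" coordinate axis.
    print("Output shape:", out.data.shape)
    print("Tracked frequency at end of chunk (Hz):", out.axes["freq"].data[-1])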