Source code for ezmsg.sigproc.filterbank
import functools
import math
import typing
import numpy as np
import scipy.signal as sps
import scipy.fft as sp_fft
from scipy.special import lambertw
import numpy.typing as npt
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from .base import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from .spectrum import OptionsEnum
from .window import WindowTransformer
[docs]
class FilterbankMode(OptionsEnum):
    """The mode of operation for the filterbank."""

    # Direct (time-domain) convolution via np.convolve, with overlap-add across chunks.
    CONV = "Direct Convolution"
    # FFT convolution over non-overlapping windows, with overlap-add across windows and chunks.
    FFT = "FFT Convolution"
    # Resolved to CONV or FFT at state initialization using scipy.signal.choose_conv_method.
    AUTO = "Automatic"
[docs]
class MinPhaseMode(OptionsEnum):
    """
    Whether (and how) to convert filterbank kernels to minimum phase.

    Options other than NONE select the ``method`` argument passed to
    `scipy.signal.minimum_phase` when the filterbank state is initialized.
    """

    NONE = "No kernel modification"
    HILBERT = "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions"
    HOMOMORPHIC = "Works best with filters with an odd number of taps, and the resulting minimum phase filter will have a magnitude response that approximates the square root of the original filter’s magnitude response using half the number of taps"
    # HOMOMORPHICFULL = "Like HOMOMORPHIC, but uses the full number of taps and same magnitude"
[docs]
class FilterbankSettings(ez.Settings):
    # Sequence of 1-D convolution kernels. The output gains one entry along
    # `new_axis` per kernel; kernels may have different lengths.
    kernels: list[npt.NDArray] | tuple[npt.NDArray, ...]

    mode: FilterbankMode = FilterbankMode.CONV
    """
    FilterbankMode.CONV, FilterbankMode.FFT, or FilterbankMode.AUTO.
    If AUTO, the mode is chosen when state is (re)initialized, by asking
    `scipy.signal.choose_conv_method` which method is faster for the concatenated
    kernels and an assumed ~100 ms chunk (derived from the axis' sample period).
    FFT mode is more efficient for long kernels. However, FFT mode uses non-overlapping windows and will
    incur a delay equal to the window length, which is at least as long as the largest kernel.
    CONV mode is less efficient but will return data for every incoming chunk regardless of how small it is
    and thus can provide shorter latency updates.
    """

    min_phase: MinPhaseMode = MinPhaseMode.NONE
    """
    If not MinPhaseMode.NONE, convert each kernel to its minimum-phase equivalent with
    `scipy.signal.minimum_phase`. Valid options are MinPhaseMode.HILBERT ('hilbert') and
    MinPhaseMode.HOMOMORPHIC ('homomorphic'). Complex filters not supported.
    See `scipy.signal.minimum_phase` for details.
    """

    axis: str = "time"
    """
    The name of the axis to operate on. This should usually be "time".
    If empty, the first dimension of the incoming message is used.
    """

    new_axis: str = "kernel"
    """The name of the new axis corresponding to the kernel index."""
[docs]
@processor_state
class FilterbankState:
    """Mutable state of :obj:`FilterbankTransformer`, populated by its ``_reset_state``."""

    # Overlap-add carry-over: the last `overlap` samples of each kernel's output from the
    # previous chunk. In FFT mode it has an extra length-1 window axis before the sample axis.
    tail: npt.NDArray | None = None
    # Output AxisArray template: kernel axis inserted just before the (last) target axis.
    template: AxisArray | None = None
    # CONV mode only: preallocated convolution / overlap-add buffer (grown on demand).
    dest_arr: npt.NDArray | None = None
    # CONV mode: the (optionally minimum-phase) kernels.
    # FFT mode: their spectra, stacked with a length-1 window axis.
    prep_kerns: npt.NDArray | list[npt.NDArray] | None = None
    # FFT mode only: splits the input into non-overlapping windows.
    windower: WindowTransformer | None = None
    # FFT mode only: forward / inverse transforms (rfft/irfft for real, fft/ifft for complex).
    fft: typing.Callable | None = None
    ifft: typing.Callable | None = None
    # FFT mode only: forward and inverse transform lengths (currently equal).
    nfft: int | None = None
    infft: int | None = None
    # Number of samples carried between chunks: max kernel length - 1.
    overlap: int | None = None
    # Resolved mode (never AUTO after initialization).
    mode: FilterbankMode | None = None
[docs]
class FilterbankTransformer(
    BaseStatefulTransformer[FilterbankSettings, AxisArray, AxisArray, FilterbankState]
):
    """
    Convolve a streaming signal with every kernel of a filterbank.

    Both CONV (direct) and FFT modes use overlap-add, carrying the tail of each
    chunk's convolution into the next call. The output has the target axis moved
    to the last position and ``settings.new_axis`` (the kernel index) inserted
    immediately before it.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """
        Summarize the message properties whose change requires rebuilding state.

        The target axis' gain (sample period) is only included for FFT / AUTO
        modes, because only those modes use it when building state.
        """
        axis = self.settings.axis or message.dims[0]
        gain = message.axes[axis].gain if axis in message.axes else 1.0
        targ_ax_ix = message.get_axis_idx(axis)
        in_shape = (
            message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
        )
        return hash(
            (
                message.key,
                gain
                if self.settings.mode in [FilterbankMode.FFT, FilterbankMode.AUTO]
                else None,
                message.data.dtype.kind,
                in_shape,
            )
        )

    def _reset_state(self, message: AxisArray) -> None:
        """
        (Re)build all filterbank state from the first message of a new stream.

        Raises:
            ValueError: if ``settings.kernels`` is empty.
        """
        axis = self.settings.axis or message.dims[0]
        gain = message.axes[axis].gain if axis in message.axes else 1.0
        targ_ax_ix = message.get_axis_idx(axis)
        in_shape = (
            message.data.shape[:targ_ax_ix] + message.data.shape[targ_ax_ix + 1 :]
        )

        # Accept array-likes (e.g. plain lists) as kernels.
        kernels = [np.asarray(k) for k in self.settings.kernels]
        if not kernels:
            raise ValueError("Filterbank requires at least one kernel.")
        if self.settings.min_phase != MinPhaseMode.NONE:
            method = {
                MinPhaseMode.HILBERT: "hilbert",
                MinPhaseMode.HOMOMORPHIC: "homomorphic",
                # MinPhaseMode.HOMOMORPHICFULL is currently disabled.
            }[self.settings.min_phase]
            kernels = [sps.minimum_phase(k, method=method) for k in kernels]

        # Complex output if either the data or any kernel is complex.
        b_complex = message.data.dtype.kind == "c" or any(
            k.dtype.kind == "c" for k in kernels
        )
        out_dtype = "complex" if b_complex else "float"
        n_kernels = len(kernels)

        # From sps._calc_oa_lens, where s2=max_kernel_len:
        # overlap-add carries max_kernel_len - 1 samples between blocks.
        max_kernel_len = max(k.size for k in kernels)
        self._state.overlap = max_kernel_len - 1

        # Previous iteration's overlap tail -- all zeros at the start of a stream.
        self._state.tail = np.zeros(
            in_shape + (n_kernels, self._state.overlap), dtype=out_dtype
        )

        # Output template -- kernels axis immediately before the target axis.
        self._state.template = AxisArray(
            data=np.zeros(in_shape + (n_kernels, 0), dtype=out_dtype),
            dims=message.dims[:targ_ax_ix]
            + message.dims[targ_ax_ix + 1 :]
            + [self.settings.new_axis, axis],
            axes=message.axes.copy(),
            key=message.key,
        )

        # Determine optimal mode. Assumes 100 msec chunks.
        self._state.mode = self.settings.mode
        if self._state.mode == FilterbankMode.AUTO:
            # Concatenate kernels into 1 mega kernel then check what's faster.
            # Will typically return fft when combined kernel length is > 1500.
            concat_kernel = np.concatenate(kernels)
            n_dummy = max(2 * len(concat_kernel), int(0.1 / gain))
            dummy_arr = np.zeros(n_dummy)
            self._state.mode = (
                FilterbankMode.CONV
                if sps.choose_conv_method(dummy_arr, concat_kernel, mode="full")
                == "direct"
                else FilterbankMode.FFT
            )

        if self._state.mode == FilterbankMode.CONV:
            # Preallocate memory for convolution result and overlap-add; grown in _process if needed.
            self._state.dest_arr = np.zeros(
                in_shape
                + (n_kernels, self._state.overlap + message.data.shape[targ_ax_ix]),
                dtype=out_dtype,
            )
            self._state.prep_kerns = kernels
            return

        # FFT mode.
        overlap = self._state.overlap
        if overlap > 0:
            # Block length minimizing overlap-add cost (same formula as sps._calc_oa_lens).
            opt_size = -overlap * lambertw(-1 / (2 * math.e * overlap), k=-1).real
        else:
            # Length-1 kernels need no overlap, and the formula above would divide by zero.
            # One-sample blocks make FFT convolution an exact per-sample multiplication.
            opt_size = 1
        self._state.nfft = sp_fft.next_fast_len(math.ceil(opt_size))
        win_len = self._state.nfft - overlap
        # infft same as nfft. Kept as a separate variable in case they must diverge.
        self._state.infft = win_len + overlap

        # Windowing is delegated to the heavily-tested WindowTransformer rather than done manually.
        win_dur = win_len * gain
        self._state.windower = WindowTransformer(
            axis=axis,
            newaxis="win",
            window_dur=win_dur,
            window_shift=win_dur,
            zero_pad_until="none",
        )
        # Windowing output has an extra "win" dimension, so the tail needs one too.
        self._state.tail = np.expand_dims(self._state.tail, -2)

        # Plain scipy.fft calls avoid the message overhead of the `spectrum` transformer.
        if b_complex:
            fft_func, ifft_func = sp_fft.fft, sp_fft.ifft
        else:
            fft_func, ifft_func = sp_fft.rfft, sp_fft.irfft
        self._state.fft = functools.partial(fft_func, n=self._state.nfft, norm="backward")
        self._state.ifft = functools.partial(
            ifft_func, n=self._state.infft, norm="backward"
        )

        # Kernel spectra with a length-1 window axis so they broadcast over windows.
        self._state.prep_kerns = np.expand_dims(
            np.array([self._state.fft(k) for k in kernels]), -2
        )
        # TODO: If fft_kernels have significant stretches of zeros, convert to sparse array.

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Filter one chunk with every kernel.

        In CONV mode the output has as many samples as the input chunk. In FFT
        mode, output is produced one window at a time (see the windower), so
        it may be delayed relative to the input.
        """
        axis = self.settings.axis or message.dims[0]
        targ_ax_ix = message.get_axis_idx(axis)

        # Make sure target axis is in -1th position.
        if targ_ax_ix != (message.data.ndim - 1):
            in_dat = np.moveaxis(message.data, targ_ax_ix, -1)
            if self._state.mode == FilterbankMode.FFT:
                # Fix message.dims because we will pass it to the windower.
                move_dims = (
                    message.dims[:targ_ax_ix] + message.dims[targ_ax_ix + 1 :] + [axis]
                )
                message = replace(message, data=in_dat, dims=move_dims)
        else:
            in_dat = message.data

        overlap = self._state.overlap
        if self._state.mode == FilterbankMode.CONV:
            n_in = in_dat.shape[-1]
            n_dest = n_in + overlap
            dest = self._state.dest_arr
            if dest.shape[-1] < n_dest:
                pad = np.zeros(
                    dest.shape[:-1] + (n_dest - dest.shape[-1],), dtype=dest.dtype
                )
                dest = np.concatenate([dest, pad], axis=-1)
                self._state.dest_arr = dest
            dest.fill(0)
            # np.convolve rejects empty inputs and np.apply_along_axis rejects zero-size
            # iteration dims, so skip convolution when there is nothing to convolve.
            if in_dat.size > 0:
                # Note: alternatives (numba.jit; stride_tricks + np.einsum; threading)
                # were all slower than this loop.
                for k_ix, k in enumerate(self._state.prep_kerns):
                    n_out = n_in + k.shape[-1] - 1
                    dest[..., k_ix, :n_out] = np.apply_along_axis(
                        np.convolve, -1, in_dat, k, mode="full"
                    )
            # Overlap-add the previous chunk's tail.
            dest[..., :overlap] += self._state.tail
            new_tail = dest[..., n_in:n_dest]
            if new_tail.size > 0:
                # COPY overlap for next iteration; dest is reused.
                self._state.tail = new_tail.copy()
            res = dest[..., :n_in].copy()
        else:  # FFT mode
            # Slice into non-overlapping windows.
            win_msg = self._state.windower.send(message)
            # Calculate spectrum of each window.
            spec_dat = self._state.fft(win_msg.data, axis=-1)
            # Insert axis for filters, then do the FFT convolution.
            spec_dat = np.expand_dims(spec_dat, -3)
            # TODO: handle fft_kernels being sparse. Maybe need np.dot.
            conv_spec = spec_dat * self._state.prep_kerns
            overlapped = self._state.ifft(conv_spec, axis=-1)
            win_len = overlapped.shape[-1] - overlap
            if overlap > 0:
                # Overlap-add on the target axis. Previous iteration's tail:
                overlapped[..., :1, :overlap] += self._state.tail
                # Window-to-window:
                overlapped[..., 1:, :overlap] += overlapped[..., :-1, win_len:]
                new_tail = overlapped[..., -1:, win_len:]
                if new_tail.size > 0:
                    # Don't replace the tail with a zero-size one when no windows were produced.
                    self._state.tail = new_tail
            # Concat over win axis, without overlap. The explicit length avoids an
            # ambiguous -1 reshape on zero-size data.
            res = overlapped[..., :win_len].reshape(
                overlapped.shape[:-2] + (overlapped.shape[-2] * win_len,)
            )

        # Carry the input's target-axis descriptor when present; _reset_state
        # tolerates messages without one, so don't raise KeyError here.
        out_axes = dict(self._state.template.axes)
        if axis in message.axes:
            out_axes[axis] = message.axes[axis]
        return replace(self._state.template, data=res, axes=out_axes)
[docs]
class Filterbank(
    BaseTransformerUnit[FilterbankSettings, AxisArray, AxisArray, FilterbankTransformer]
):
    """Unit wrapper that runs :obj:`FilterbankTransformer`, configured by :obj:`FilterbankSettings`."""

    SETTINGS = FilterbankSettings
[docs]
def filterbank(
    kernels: list[npt.NDArray] | tuple[npt.NDArray, ...],
    mode: FilterbankMode = FilterbankMode.CONV,
    min_phase: MinPhaseMode = MinPhaseMode.NONE,
    axis: str = "time",
    new_axis: str = "kernel",
) -> FilterbankTransformer:
    """
    Build a transformer that convolves a signal with each kernel of a bank.

    Direct and FFT convolution both use overlap-add, so the returned transformer
    can be fed a stream chunk by chunk during online processing.

    Args:
        kernels: The convolution kernels; one output entry along ``new_axis`` per kernel.
        mode: Direct, FFT, or automatically selected convolution strategy.
        min_phase: Optional conversion of the kernels to minimum phase.
        axis: Name of the axis to convolve along.
        new_axis: Name of the output axis that indexes the kernels.

    Returns:
        :obj:`FilterbankTransformer` configured with the given settings.
    """
    bank_settings = FilterbankSettings(
        kernels=kernels,
        mode=mode,
        min_phase=min_phase,
        axis=axis,
        new_axis=new_axis,
    )
    transformer = FilterbankTransformer(settings=bank_settings)
    return transformer