Source code for ezmsg.blackrock.sampling_delay_alignment
"""Align channels sampled at different instants by a sequential A/D.
The Gemini front-end samples channels in banks of ``bank_size`` (32), one every
``channel_sample_interval`` (~969.7 ns), so channel ``c``'s sample ``n`` is the
signal at ``t_n + tau_c``, ``tau_c = (c % bank_size) * channel_sample_interval``.
For any cross-channel operation (CAR, whitening, beamforming) this misalignment
smears the common-mode at high frequency: the phase spread across a bank is
``2*pi*f*T_bank`` -- negligible at 60 Hz (0.65 deg) but ~81 deg at 7.5 kHz, so
e.g. CAR's common-mode rejection collapses toward Nyquist.
This transformer removes that by delaying each channel by ``tau_c`` with a
per-slot windowed-sinc fractional-delay filter, bringing every channel onto a
common time grid (the bank start). A windowed-sinc is used rather than linear
interpolation on purpose: linear interpolation is a delay-dependent low-pass
that would impose a *different* high-frequency rolloff per channel -- coloring
the band exactly where the misalignment mattered. There are only ``bank_size``
distinct delays, so only that many distinct filters.
The within-bank slot defaults to acquisition order (``c % bank_size``). If the
channel axis carries per-channel ``bank``/``elec`` metadata (e.g. attached by
:class:`~ezmsg.blackrock.ChannelMapUnit`), the slot is taken from ``elec``
(``elec - 1``) instead, so each channel's delay is correct even when channels
are reordered relative to hardware acquisition.
Cost / caveats:

* **Latency:** the causal FIR has a common bulk delay of ``(filter_len-1)//2``
  samples (the per-channel fractional delays ride on top). The output time
  axis offset is shifted to keep timestamps physically correct.
* **It resamples the raw data** -- downstream sees interpolated samples. Fine
  for cross-channel cleaning; be deliberate if a step needs raw waveforms.
* **Railing:** clipped (rail) samples are corrupt and a fractional-delay
  filter would spread that corruption over its support. With
  ``rail_threshold`` set, railed samples are held at the last valid value
  before filtering (a basic mitigation). A production version should also
  emit a reliability mask so downstream can discount the ~``filter_len``
  samples around each rail. FIR (used here) localizes the damage; an IIR
  all-pass (e.g. Thiran) would ring across it.

Array-API compatible: it detects the input's namespace and runs on the working
backend (numpy, MLX, torch, jax, cupy, ...). The sinc taps are designed in numpy
and moved to the backend; everything else -- the FIR tap-sum, concat/state
handling, and the rail forward-fill -- runs on the backend using only standard
Array-API ops (the forward-fill's cumulative max is built from ``maximum`` +
shifts, since the standard lacks one). Only the MLX ``concatenate``-vs-``concat``
spelling is special-cased.
"""
from typing import Any
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from array_api_compat import array_namespace
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
try: # pragma: no cover - exercised only when mlx is installed
import mlx.core as _mx
except Exception: # pragma: no cover
_mx = None
def _is_mlx(arr: object) -> bool:
    """Tell whether ``arr`` is an MLX array.

    Always ``False`` when the optional ``mlx`` import did not succeed
    (``_mx`` is ``None`` in that case).
    """
    if _mx is None:
        return False
    return isinstance(arr, _mx.array)
def _namespace(arr: object) -> tuple[Any, bool]:
    """Pick the array module to compute with for ``arr``.

    Returns:
        ``(xp, is_mlx)`` -- ``xp`` is the MLX module when ``arr`` is an MLX
        array, otherwise whatever ``array_namespace`` reports for it (numpy,
        torch, jax, cupy, ...); ``is_mlx`` flags which of the two it was.
    """
    mlx_input = _is_mlx(arr)
    # array_namespace is only consulted for non-MLX inputs.
    xp = _mx if mlx_input else array_namespace(arr)
    return xp, mlx_input
def _concat(xp: Any, is_mlx: bool, arrays: list, axis: int = 0) -> Any:
    """Join ``arrays`` along ``axis`` on the active backend.

    MLX names the operation ``concatenate`` while Array-API namespaces name it
    ``concat``; this wrapper hides that one spelling difference.

    Args:
        xp: Array namespace used for the non-MLX path.
        is_mlx: When true, the MLX module is used instead of ``xp``.
        arrays: Arrays to join.
        axis: Axis along which to join them.

    Returns:
        The concatenated array, as produced by the chosen backend.
    """
    if is_mlx:
        return _mx.concatenate(arrays, axis=axis)
    return xp.concat(arrays, axis=axis)
# Default number of channels per A/D bank (default of the ``bank_size`` setting).
_DEFAULT_BANK_SIZE = 32
# Default spacing, in seconds, between successive channels within a bank:
# 64 / 66e6 s ~= 969.7 ns. NOTE(review): presumably 64 cycles of a 66 MHz
# clock -- confirm against the hardware documentation.
_DEFAULT_CHANNEL_SAMPLE_INTERVAL = 64.0 / 66.0e6
[docs]
class SamplingDelayAlignmentSettings(ez.Settings):
    """Configuration for :class:`SamplingDelayAlignmentTransformer`."""

    bank_size: int = _DEFAULT_BANK_SIZE
    """Number of channels in one A/D bank (channels whose sweep starts
    together). Only a fallback: it yields the sweep slot ``c % bank_size`` when
    the channel axis has no ``bank``/``elec`` metadata to take the slot from."""

    channel_sample_interval: float = _DEFAULT_CHANNEL_SAMPLE_INTERVAL
    """Time, in seconds, from one channel's sample to the next channel's sample
    inside a bank."""

    filter_len: int = 33
    """Number of taps of the sinc FIR (meant to be odd). The filter adds a bulk
    delay of ``(filter_len-1)//2`` samples. More taps give a flatter passband
    and better behaviour near Nyquist, paid for in latency and compute. ``0``
    switches alignment off: the transformer then hands its input back
    unchanged."""

    rail_threshold: float | None = None
    """Clipping level. When given, any sample with
    ``abs(value) >= rail_threshold`` counts as railed and is replaced by the
    last valid value before filtering; ``None`` turns rail handling off. (For
    Blackrock int16 at 0.25 uV/count, the rail is ~8191 uV.)"""
[docs]
@processor_state
class SamplingDelayAlignmentState:
    """Filter design and streaming history kept by
    :class:`SamplingDelayAlignmentTransformer` between chunks."""

    fir: npt.NDArray | None = None
    """Sinc FIR taps, one column per channel: shape ``(filter_len, n_ch)``.
    ``None`` until the filters have been designed."""

    hist: npt.NDArray | None = None
    """Tail of the previously seen input, prepended to the next chunk: shape
    ``(filter_len-1, *sample_shape)``."""

    bulk_delay: int = 0
    """Delay shared by all channels, ``(filter_len-1)//2`` samples; used to
    shift the output time-axis offset."""
[docs]
class SamplingDelayAlignmentTransformer(
    BaseStatefulTransformer[
        SamplingDelayAlignmentSettings,
        AxisArray,
        AxisArray,
        SamplingDelayAlignmentState,
    ]
):
    """Per-channel fractional-delay alignment (see module docstring).

    Each channel is filtered with its own windowed-sinc FIR whose delay is a
    common bulk delay plus that channel's fractional sampling delay. The last
    ``filter_len - 1`` input samples are carried between chunks so the output
    is continuous across chunk boundaries.

    The taps follow the input dtype when it is floating (real or complex). For
    any other input dtype (e.g. integer counts) the taps stay in floating
    point, so the output is promoted to a floating dtype by the backend rather
    than having the taps truncated to integers.
    """

    # The rail threshold only gates the forward-fill in _process; it doesn't
    # alter the designed filters, so changing it needn't reset the state.
    NONRESET_SETTINGS_FIELDS = frozenset({"rail_threshold"})

    @property
    def _passthrough(self) -> bool:
        """``filter_len <= 0`` disables alignment: the transformer returns its
        input unchanged and skips building the FIR (undefined for ``n_taps`` 0)."""
        return self.settings.filter_len < 1

    def _channel_slots(self, message: AxisArray) -> npt.NDArray:
        """Within-bank A/D sweep position (0-based) for each channel on the
        ``ch`` axis.

        Prefers channel metadata: when the ``ch`` axis carries a structured
        ``.data`` with ``bank`` and ``elec`` fields (as produced by
        :class:`~ezmsg.blackrock.ChannelMapUnit`) whose length matches the
        channel count, the slot is ``elec - 1`` -- the channel's physical
        position in its bank's sequential sweep, so the delay is correct even
        when channels are not in hardware-acquisition order. Falls back to
        acquisition-order banks of ``bank_size`` (``arange(n_ch) % bank_size``)
        when that metadata is absent.

        Args:
            message: Message whose ``ch`` axis is inspected.

        Returns:
            Integer numpy array of length ``n_ch`` holding each channel's slot.
        """
        n_ch = message.data.shape[message.get_axis_idx("ch")]
        data = getattr(message.axes.get("ch"), "data", None)
        # ``names`` is only non-None for structured (record) dtypes.
        names = getattr(getattr(data, "dtype", None), "names", None)
        if names is not None and "bank" in names and "elec" in names and len(data) == n_ch:
            return data["elec"].astype(np.int64) - 1
        return np.arange(n_ch) % self.settings.bank_size

    def _hash_message(self, message: AxisArray) -> int:
        """Fingerprint everything the filter design depends on.

        Combines the message key, the time-axis gain, the per-sample shape and
        the per-channel slot layout.
        """
        time_idx = message.get_axis_idx("time")
        sample_shape = message.data.shape[:time_idx] + message.data.shape[time_idx + 1 :]
        # Include the slot layout so a metadata change (e.g. a new channel map)
        # re-designs the filters even when shape/key/gain are unchanged.
        slot = self._channel_slots(message)
        return hash((message.key, message.axes["time"].gain, sample_shape, slot.tobytes()))

    def _reset_state(self, message: AxisArray) -> None:
        """Design the per-channel FIR taps and allocate zeroed history.

        Does nothing in pass-through mode. Otherwise sets ``fir``
        (``(filter_len, n_ch)``), ``hist`` (``(filter_len-1, *sample_shape)``,
        zeros in the input dtype) and ``bulk_delay`` on the state.

        Args:
            message: Message providing shape, dtype, backend, sample rate and
                channel layout for the design.
        """
        if self._passthrough:
            return  # no filters to design; _process returns the input as-is
        time_idx = message.get_axis_idx("time")
        sample_shape = message.data.shape[:time_idx] + message.data.shape[time_idx + 1 :]
        dtype = message.data.dtype
        xp, is_mlx = _namespace(message.data)
        fs = 1.0 / message.axes["time"].gain
        slot = self._channel_slots(message)
        # Fractional-sample delay that brings each channel back to its bank
        # start (at most ~0.9 sample with the default bank settings at 30 kHz).
        d = slot * self.settings.channel_sample_interval * fs
        n_taps = int(self.settings.filter_len)
        m = (n_taps - 1) // 2
        self._state.bulk_delay = m
        # Design the per-channel windowed sinc in numpy (total delay m + d_c, DC
        # gain 1), then move the taps onto the working backend.
        k = np.arange(n_taps)[:, None]
        h = np.sinc(k - m - d[None, :]) * np.blackman(n_taps)[:, None]
        h = h / h.sum(axis=0, keepdims=True)
        if is_mlx:
            self._state.fir = _mx.array(h.astype(np.float32))
            self._state.hist = _mx.zeros((n_taps - 1,) + sample_shape, dtype=dtype)
        else:
            # h is numpy; convert to the backend first, then to the input dtype
            # (dtype may be a non-numpy dtype, e.g. torch.float32, that
            # numpy.astype rejects).
            taps = xp.asarray(h)
            # Only follow the input dtype when it can represent the fractional
            # taps. Casting to an integer dtype would truncate nearly every tap
            # to 0 and silently wipe out the signal.
            if xp.isdtype(dtype, ("real floating", "complex floating")):
                taps = xp.astype(taps, dtype)
            self._state.fir = taps
            self._state.hist = xp.zeros((n_taps - 1,) + sample_shape, dtype=dtype)

    @staticmethod
    def _fill_rails(x: npt.NDArray, thresh: float, xp: Any, is_mlx: bool) -> npt.NDArray:
        """Forward-fill (hold last valid) over railed samples, per channel.

        Backend-portable: for every position, find the index along axis 0 of
        the most recent valid sample at or before it, then gather. Because the
        Array-API standard lacks a cumulative max, it is built from standard ops
        (``maximum`` + shifts) as a doubling-shift prefix scan -- valid
        positions carry their (increasing) index and railed ones carry ``-1``,
        so the running max is exactly the last valid index. O(n log n) but
        fully vectorized.

        Args:
            x: Data with the time axis first.
            thresh: Samples with ``abs(value) >= thresh`` count as railed.
            xp: Array namespace of ``x``.
            is_mlx: Whether ``x`` is an MLX array (selects the concat spelling).

        Returns:
            Array shaped like ``x`` with railed samples replaced. Rails at the
            very start of the chunk, with no valid sample before them, take the
            chunk's first sample.
        """
        n = x.shape[0]
        sample_shape = x.shape[1:]
        ar = xp.reshape(xp.arange(n), (n,) + (1,) * (x.ndim - 1))
        idx = xp.where(xp.abs(x) >= thresh, -1, ar)  # index, or -1 where railed
        shift = 1
        while shift < n:
            # Rows shifted in from "before the chunk" hold -1 so they never win.
            sentinel = xp.full((shift,) + sample_shape, -1, dtype=idx.dtype)
            shifted = _concat(xp, is_mlx, [sentinel, idx[: n - shift]], axis=0)
            idx = xp.maximum(idx, shifted)
            shift *= 2
        idx = xp.where(idx < 0, 0, idx)  # leading rails -> first sample
        return xp.take_along_axis(x, idx, axis=0)

    def _process(self, message: AxisArray) -> AxisArray:
        """Filter one chunk and return the aligned message.

        Args:
            message: Input chunk with ``time`` and ``ch`` axes.

        Returns:
            ``message`` itself in pass-through mode; otherwise a copy whose data
            is the filtered chunk (same shape) and whose ``time`` offset is
            moved back by ``bulk_delay`` samples.
        """
        if self._passthrough:
            return message
        ax_idx = message.get_axis_idx("time")
        x = message.data
        xp, is_mlx = _namespace(x)
        moved = ax_idx != 0
        if moved:
            x = xp.moveaxis(x, ax_idx, 0)
        if self.settings.rail_threshold is not None:
            x = self._fill_rails(x, self.settings.rail_threshold, xp, is_mlx)
        st = self._state
        fir = st.fir
        n_taps = fir.shape[0]
        n = x.shape[0]
        # Shape the taps so they broadcast along the channel axis wherever it
        # sits once "time" is first (axes ahead of "time" shift right by one).
        # For plain (time, ch) data this reshape is the identity.
        ch_idx = message.get_axis_idx("ch")
        ch_pos = ch_idx + 1 if ch_idx < ax_idx else ch_idx
        tap_shape = [1] * x.ndim
        tap_shape[0] = n_taps
        tap_shape[ch_pos] = fir.shape[1]
        taps = xp.reshape(fir, tuple(tap_shape))
        # FIR via tap-sum, carrying n_taps-1 samples of history across chunks:
        #   y[i] = sum_k fir[k] * xext[(n_taps-1) - k + i],  xext = [hist, x]
        xext = _concat(xp, is_mlx, [st.hist, x], axis=0)
        y = xp.zeros_like(x)
        for k in range(n_taps):
            start = n_taps - 1 - k
            y = y + taps[k] * xext[start : start + n]
        # xext holds n_taps-1 history rows plus n new rows, so dropping the
        # first n rows leaves exactly the n_taps-1 most recent samples. A
        # negative-index slice would be ``[-0:]`` -- the whole buffer -- when
        # n_taps == 1, growing the history and replaying stale samples.
        st.hist = xext[n:]
        if moved:
            y = xp.moveaxis(y, 0, ax_idx)
        # Output sample i carries the bank-start signal delayed by bulk_delay
        # samples; shift the time-axis offset so timestamps stay physical.
        time_axis = message.axes["time"]
        new_axis = replace(
            time_axis,
            offset=time_axis.offset - st.bulk_delay * time_axis.gain,
        )
        return replace(message, data=y, axes={**message.axes, "time": new_axis})
[docs]
class SamplingDelayAlignment(
    BaseTransformerUnit[
        SamplingDelayAlignmentSettings,
        AxisArray,
        AxisArray,
        SamplingDelayAlignmentTransformer,
    ]
):
    """Unit wrapper around :class:`SamplingDelayAlignmentTransformer`.

    Takes :class:`SamplingDelayAlignmentSettings` and maps ``AxisArray`` input
    to ``AxisArray`` output, as declared by the generic parameters; the unit
    plumbing itself comes from ``BaseTransformerUnit``.
    """

    SETTINGS = SamplingDelayAlignmentSettings