Source code for ezmsg.sigproc.align

"""Time-align two AxisArray streams, outputting paired aligned chunks."""

from __future__ import annotations

import math
import typing

import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc.protocols import processor_state
from ezmsg.baseproc.stateful import BaseStatefulTransformer
from ezmsg.util.messages.axisarray import AxisArray

from .util.axisarray_buffer import HybridAxisArrayBuffer


class AlignAlongAxisSettings(ez.Settings):
    """Settings for aligning two AxisArray streams along a shared axis."""

    # Name of the axis to align along. If this is falsy, the processor falls
    # back to each message's first dimension (``message.dims[0]``).
    axis: str = "time"
    """Axis used for alignment (typically the time axis)."""

    # Passed as ``duration`` to each per-input HybridAxisArrayBuffer.
    buffer_dur: float = 10.0
    """Buffer duration in seconds for each input stream."""
@processor_state
class AlignAlongAxisState:
    """Mutable state for the two-input alignment processor."""

    # Step size of the alignment axis. After a reset it is taken from the
    # ``axis_gain`` of whichever buffer receives data first.
    gain: float | None = None
    # Axis name the buffers were created for. Set on full reset, or lazily
    # when input B arrives before any reset has happened.
    align_axis: str | None = None
    # True once the initial seek-to-overlap and offset check have completed.
    # Cleared on reset or when either input's non-alignment shape changes.
    aligned: bool = False
    # Per-input buffers along ``align_axis``; None until first created.
    buf_a: HybridAxisArrayBuffer | None = None
    buf_b: HybridAxisArrayBuffer | None = None
    # Per-input non-alignment shape for reset detection.
    a_shape_sig: tuple[int, ...] | None = None
    b_shape_sig: tuple[int, ...] | None = None
# Processor output: the aligned chunk from input A, then the one from input B.
_AlignPair: typing.TypeAlias = tuple[AxisArray, AxisArray]
class AlignAlongAxisProcessor(
    BaseStatefulTransformer[
        AlignAlongAxisSettings,
        AxisArray,
        _AlignPair | None,
        AlignAlongAxisState,
    ]
):
    """Processor that time-aligns two AxisArray streams.

    Input A flows through ``__call__`` / ``_process`` with automatic
    hash-based reset. Input B flows through :meth:`push_b`.
    Returns ``(aligned_a, aligned_b)`` when alignment succeeds, else ``None``.
    """

    # -- Helpers -------------------------------------------------------------

    def _extract_gain(self, message: AxisArray) -> float | None:
        """Return the step size of *message*'s alignment axis, or ``None``.

        Uses the axis' ``gain`` attribute when it has one. Otherwise, if the
        axis exposes ``data`` with more than one entry, returns the mean
        spacing ``(data[-1] - data[0]) / (len(data) - 1)``. Returns ``None``
        when the axis is missing or neither form applies.
        """
        align_name = self.settings.axis or message.dims[0]
        ax = message.axes.get(align_name)
        if ax is None:
            return None
        if hasattr(ax, "gain"):
            return ax.gain
        if hasattr(ax, "data") and len(ax.data) > 1:
            return float(ax.data[-1] - ax.data[0]) / (len(ax.data) - 1)
        return None

    @staticmethod
    def _non_align_shape(message: AxisArray, align_axis: str) -> tuple[int, ...]:
        """Return ``message.data.shape`` with the *align_axis* dimension dropped.

        Raises:
            ValueError: if *align_axis* is not in ``message.dims``.
        """
        align_idx = message.dims.index(align_axis)
        return tuple(s for i, s in enumerate(message.data.shape) if i != align_idx)

    def _new_buffer(self, axis: str | None) -> HybridAxisArrayBuffer:
        """Create an empty per-input buffer along *axis*, sized from settings."""
        return HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=axis)

    # -- Reset helpers -------------------------------------------------------

    def _full_reset(self, align_axis: str) -> None:
        """Reset all state and start both inputs over with empty buffers.

        Called from :meth:`_reset_state` (input A's hash changed) and from
        :meth:`push_b` (input B's gain disagrees with the established gain).

        Args:
            align_axis: Name of the axis both buffers are created along.
        """
        self._state.align_axis = align_axis
        self._state.buf_a = self._new_buffer(align_axis)
        self._state.buf_b = self._new_buffer(align_axis)
        self._state.gain = None
        self._state.aligned = False
        self._state.a_shape_sig = None
        self._state.b_shape_sig = None

    def _reset_a_state(self) -> None:
        """Discard input A's buffered data; input B's buffer is kept."""
        self._state.buf_a = self._new_buffer(self._state.align_axis)

    def _reset_b_state(self) -> None:
        """Discard input B's buffered data; input A's buffer is kept."""
        self._state.buf_b = self._new_buffer(self._state.align_axis)

    # -- BaseStatefulTransformer interface ------------------------------------

    def _hash_message(self, message: AxisArray) -> int:
        """Hash input A by its alignment-axis gain; a change triggers a reset.

        NOTE(review): when the gain is estimated from coordinate ``data``,
        float round-off can make it differ slightly between chunks. That would
        change this hash and cause a full reset on every chunk. Confirm whether
        such inputs are expected.
        """
        return hash(self._extract_gain(message))

    def _reset_state(self, message: AxisArray) -> None:
        """Full reset driven by input A, along the configured or first axis."""
        align_axis = self.settings.axis or message.dims[0]
        self._full_reset(align_axis)

    def _process(self, message: AxisArray) -> _AlignPair | None:
        """Process input A: detect shape changes, buffer, try align."""
        shape_sig = self._non_align_shape(message, self._state.align_axis)
        if self._state.a_shape_sig is not None and shape_sig != self._state.a_shape_sig:
            # A's channel layout changed: its old samples can't be paired anymore.
            self._reset_a_state()
            self._state.aligned = False
        self._state.a_shape_sig = shape_sig

        self._state.buf_a.write(message)
        if self._state.gain is None:
            self._state.gain = self._state.buf_a.axis_gain
        return self._try_align()

    # -- Input B entry point ------------------------------------------------

    def push_b(self, message: AxisArray) -> _AlignPair | None:
        """Process input B: check gain, detect shape changes, buffer, try align.

        Args:
            message: A chunk of input B.

        Returns:
            ``(aligned_a, aligned_b)`` if enough overlapping data is buffered,
            otherwise ``None``.
        """
        align_axis = self.settings.axis or message.dims[0]

        # Gain compatibility check. This is skipped when B's gain cannot be
        # determined; math.isclose(None, ...) would raise TypeError.
        b_gain = self._extract_gain(message)
        if (
            b_gain is not None
            and self._state.gain is not None
            and not math.isclose(b_gain, self._state.gain)
        ):
            self._full_reset(align_axis)
            # Store B's gain hash as the current hash. Presumably the base class
            # compares the next A message against this, so A with the old gain
            # triggers its own reset.
            self._hash = self._hash_message(message)

        # Lazy-create buf_b if B arrives before A.
        if self._state.buf_b is None:
            if self._state.align_axis is None:
                self._state.align_axis = align_axis
            self._state.buf_b = self._new_buffer(self._state.align_axis)

        # Use the axis B's buffer was built with (the same one _reset_b_state
        # uses), not a per-message fallback name.
        buf_axis = self._state.align_axis
        shape_sig = self._non_align_shape(message, buf_axis)
        if self._state.b_shape_sig is not None and shape_sig != self._state.b_shape_sig:
            self._reset_b_state()
            self._state.aligned = False
        self._state.b_shape_sig = shape_sig

        self._state.buf_b.write(message)
        if self._state.gain is None:
            self._state.gain = self._state.buf_b.axis_gain
        return self._try_align()

    # -- Core alignment logic -----------------------------------------------

    def _seek_to_overlap(self, gain: float) -> bool:
        """Advance both buffers to the start of their common axis range.

        If the buffered ranges don't overlap (within half a step), all data in
        the stream that lies entirely earlier is discarded. Returns ``False``
        in that case and ``True`` once both buffers start at the overlap.
        """
        buf_a, buf_b = self._state.buf_a, self._state.buf_b
        first_a, final_a = buf_a.axis_first_value, buf_a.axis_final_value
        first_b, final_b = buf_b.axis_first_value, buf_b.axis_final_value
        half_step = gain / 2

        overlap_start = max(first_a, first_b)
        overlap_end = min(final_a, final_b)
        if overlap_end < overlap_start - half_step:
            if final_a < first_b:
                buf_a.seek(buf_a.available())
            elif final_b < first_a:
                buf_b.seek(buf_b.available())
            return False

        # Skip whole samples that precede the overlap; sub-half-step
        # differences are treated as already aligned.
        if first_a < overlap_start - half_step:
            buf_a.seek(int(round((overlap_start - first_a) / gain)))
        if first_b < overlap_start - half_step:
            buf_b.seek(int(round((overlap_start - first_b) / gain)))
        return True

    def _check_offsets(self, aa_a: AxisArray, aa_b: AxisArray, gain: float) -> None:
        """Raise RuntimeError if the first aligned chunks start at different offsets.

        NOTE(review): np.isclose also applies its default relative tolerance
        (rtol=1e-05 of ``|off_b|``). For large absolute offsets, such as
        wall-clock timestamps, the effective tolerance is therefore far looser
        than ``atol``.
        """
        axis_a = aa_a.axes.get(self._state.align_axis)
        axis_b = aa_b.axes.get(self._state.align_axis)
        if axis_a is None or axis_b is None:
            return
        off_a = axis_a.value(0) if hasattr(axis_a, "value") else None
        off_b = axis_b.value(0) if hasattr(axis_b, "value") else None
        if off_a is None or off_b is None:
            return
        if not np.isclose(off_a, off_b, atol=abs(gain) * 1e-6):
            raise RuntimeError(
                f"Offset mismatch after alignment: "
                f"off_a={off_a}, off_b={off_b}, gain={gain}"
            )

    def _try_align(self) -> _AlignPair | None:
        """Align and read from both buffers, returning the pair ``(a, b)``.

        On the first successful call after a reset, both buffers are first
        seeked to their overlap and the resulting offsets are verified.
        Afterwards the same number of samples is read from each buffer.

        Raises:
            RuntimeError: if the first aligned chunks' offsets disagree.
        """
        buf_a, buf_b = self._state.buf_a, self._state.buf_b
        if buf_a is None or buf_b is None:
            return None
        if buf_a.is_empty() or buf_b.is_empty():
            return None

        gain = self._state.gain
        # --- Initial alignment (runs once per reset) ---
        if not self._state.aligned and not self._seek_to_overlap(gain):
            return None

        # --- Read aligned samples ---
        n_read = min(buf_a.available(), buf_b.available())
        if n_read <= 0:
            return None
        aa_a = buf_a.read(n_read)
        aa_b = buf_b.read(n_read)
        if aa_a is None or aa_b is None:
            return None

        if not self._state.aligned:
            self._check_offsets(aa_a, aa_b, gain)
            self._state.aligned = True
        return aa_a, aa_b
class AlignAlongAxis(ez.Unit):
    """ezmsg unit that pairs up time-aligned chunks from two AxisArray inputs.

    Either input may complete a pair. Whenever the processor produces one,
    the A half goes to ``OUTPUT_SIGNAL_A`` and the B half to
    ``OUTPUT_SIGNAL_B``, so both subscribers publish to both outputs.
    """

    SETTINGS = AlignAlongAxisSettings

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL_A = ez.OutputStream(AxisArray)
    OUTPUT_SIGNAL_B = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Create the alignment processor from this unit's settings."""
        self.processor = AlignAlongAxisProcessor(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL_A)
    @ez.publisher(OUTPUT_SIGNAL_A)
    @ez.publisher(OUTPUT_SIGNAL_B)
    async def on_a(self, msg: AxisArray) -> typing.AsyncGenerator:
        """Feed input A through the processor and publish any aligned pair."""
        result = await self.processor.__acall__(msg)
        if result is None:
            return
        chunk_a, chunk_b = result
        yield self.OUTPUT_SIGNAL_A, chunk_a
        yield self.OUTPUT_SIGNAL_B, chunk_b

    @ez.subscriber(INPUT_SIGNAL_B)
    @ez.publisher(OUTPUT_SIGNAL_A)
    @ez.publisher(OUTPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> typing.AsyncGenerator:
        """Feed input B through the processor and publish any aligned pair."""
        result = self.processor.push_b(msg)
        if result is None:
            return
        chunk_a, chunk_b = result
        yield self.OUTPUT_SIGNAL_A, chunk_a
        yield self.OUTPUT_SIGNAL_B, chunk_b