"""Time-aligned merge of two AxisArray streams along a non-time axis."""
from __future__ import annotations
import math
import typing
import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace
from ezmsg.baseproc.protocols import processor_state
from ezmsg.baseproc.stateful import BaseStatefulTransformer
from ezmsg.baseproc.units import BaseProcessorUnit
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
from ezmsg.util.messages.util import replace
from .util.axisarray_buffer import HybridAxisArrayBuffer
[docs]
class MergeSettings(ez.Settings):
axis: str = "ch"
"""Axis along which to concatenate the two signals."""
align_axis: str | None = "time"
"""Axis used for alignment. If None, defaults to the first dimension."""
buffer_dur: float = 10.0
"""Buffer duration in seconds for each input stream."""
relabel_axis: bool = True
"""Whether to relabel coordinate axis labels to ensure uniqueness."""
label_a: str = "_a"
"""Suffix appended to signal A labels when relabel_axis is True."""
label_b: str = "_b"
"""Suffix appended to signal B labels when relabel_axis is True."""
new_key: str | None = None
"""Output AxisArray key. If None, uses the key from signal A."""
[docs]
@processor_state
class MergeState:
# Common state
gain: float | None = None
align_axis: str | None = None
aligned: bool = False
merged_concat_axis: CoordinateAxis | None = None
# A state
buf_a: HybridAxisArrayBuffer | None = None
concat_axis_a: CoordinateAxis | None = None
a_concat_dim: int | None = None
a_other_dims: tuple[int, ...] | None = None
# B state
buf_b: HybridAxisArrayBuffer | None = None
concat_axis_b: CoordinateAxis | None = None
b_concat_dim: int | None = None
b_other_dims: tuple[int, ...] | None = None
[docs]
class MergeProcessor(BaseStatefulTransformer[MergeSettings, AxisArray, AxisArray | None, MergeState]):
"""Processor that time-aligns two AxisArray streams and concatenates them.
Input A flows through the standard ``__call__`` / ``_process`` path,
getting automatic ``_hash_message`` / ``_reset_state`` handling from
:class:`BaseStatefulTransformer`. Input B flows through :meth:`push_b`,
which independently tracks its own structure.
Invalidation rules:
- Gain mismatch (either input vs stored common gain) → full reset.
- Concat-axis dimensionality change → per-input buffer reset +
alignment and merged-axis cache invalidation.
- Non-align/non-concat axis shape change → per-input buffer reset +
alignment invalidation.
"""
# -- Structural extraction helpers ---------------------------------------
def _extract_gain(self, message: AxisArray) -> float | None:
"""Extract the align-axis gain from a message."""
align_name = self.settings.align_axis or message.dims[0]
ax = message.axes.get(align_name)
if ax is not None and hasattr(ax, "gain"):
return ax.gain
if ax is not None and hasattr(ax, "data") and len(ax.data) > 1:
return float(ax.data[-1] - ax.data[0]) / (len(ax.data) - 1)
return None
# -- Reset helpers -------------------------------------------------------
def _full_reset(self, align_axis: str) -> None:
"""Reset all state — both inputs and common merge state."""
self._state.align_axis = align_axis
self._state.buf_a = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
self._state.gain = None
self._state.aligned = False
self._state.concat_axis_a = None
self._state.concat_axis_b = None
self._state.merged_concat_axis = None
self._state.a_concat_dim = None
self._state.a_other_dims = None
self._state.b_concat_dim = None
self._state.b_other_dims = None
def _reset_a_state(self) -> None:
"""Reset input-A buffer and concat-axis cache."""
self._state.buf_a = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=self._state.align_axis)
self._state.concat_axis_a = None
def _reset_b_state(self) -> None:
"""Reset input-B buffer and concat-axis cache."""
self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=self._state.align_axis)
self._state.concat_axis_b = None
# -- BaseStatefulTransformer interface ------------------------------------
def _hash_message(self, message: AxisArray) -> int:
"""Hash the align-axis gain only.
Gain changes trigger a full reset via ``_reset_state``. Concat-axis
and non-merge dimension changes are handled as partial resets inside
``_process`` and ``push_b``.
"""
return hash(self._extract_gain(message))
def _reset_state(self, message: AxisArray) -> None:
"""Full reset — called by the base class when gain changes."""
align_axis = self.settings.align_axis or message.dims[0]
self._full_reset(align_axis)
def _process(self, message: AxisArray) -> AxisArray | None:
"""Process input A: detect structural changes, buffer, try merge."""
# Detect per-input structural changes.
align_idx = message.dims.index(self._state.align_axis)
concat_idx = message.dims.index(self.settings.axis) if self.settings.axis in message.dims else None
concat_dim = message.data.shape[concat_idx] if concat_idx is not None else None
other_dims = tuple(s for i, s in enumerate(message.data.shape) if i != align_idx and i != concat_idx)
if self._state.a_concat_dim is not None and concat_dim != self._state.a_concat_dim:
self._reset_a_state()
self._state.aligned = False
self._state.merged_concat_axis = None
elif self._state.a_other_dims is not None and other_dims != self._state.a_other_dims:
self._reset_a_state()
self._state.aligned = False
self._state.a_concat_dim = concat_dim
self._state.a_other_dims = other_dims
self._state.buf_a.write(message)
if self._state.gain is None:
self._state.gain = self._state.buf_a.axis_gain
self._update_concat_axis(message, "a")
return self._try_merge()
# -- Input B entry point ------------------------------------------------
[docs]
def push_b(self, message: AxisArray) -> AxisArray | None:
"""Process input B: check gain, detect structural changes, buffer, try merge."""
align_axis = self.settings.align_axis or message.dims[0]
# Gain compatibility check.
b_gain = self._extract_gain(message)
if self._state.gain is not None and not math.isclose(b_gain, self._state.gain):
self._full_reset(align_axis)
# Set the base-class hash so the next compatible A goes straight
# to _process instead of triggering another full reset.
self._hash = self._hash_message(message)
# Lazy-create buf_b if B arrives before A.
if self._state.buf_b is None:
if self._state.align_axis is None:
self._state.align_axis = align_axis
self._state.buf_b = HybridAxisArrayBuffer(duration=self.settings.buffer_dur, axis=align_axis)
# Detect per-input structural changes.
align_idx = message.dims.index(align_axis)
concat_idx = message.dims.index(self.settings.axis) if self.settings.axis in message.dims else None
concat_dim = message.data.shape[concat_idx] if concat_idx is not None else None
other_dims = tuple(s for i, s in enumerate(message.data.shape) if i != align_idx and i != concat_idx)
if self._state.b_concat_dim is not None and concat_dim != self._state.b_concat_dim:
self._reset_b_state()
self._state.aligned = False
self._state.merged_concat_axis = None
elif self._state.b_other_dims is not None and other_dims != self._state.b_other_dims:
self._reset_b_state()
self._state.aligned = False
self._state.b_concat_dim = concat_dim
self._state.b_other_dims = other_dims
self._state.buf_b.write(message)
if self._state.gain is None:
self._state.gain = self._state.buf_b.axis_gain
self._update_concat_axis(message, "b")
return self._try_merge()
# -- Concat-axis caching ------------------------------------------------
def _update_concat_axis(self, message: AxisArray, which: str) -> None:
"""Track each input's concat-axis labels; invalidate cache on change."""
concat_dim = self.settings.axis
if concat_dim not in message.axes:
return
ax = message.axes[concat_dim]
if not hasattr(ax, "data"):
return
if which == "a":
if self._state.concat_axis_a is None or not np.array_equal(self._state.concat_axis_a.data, ax.data):
self._state.concat_axis_a = ax
self._state.merged_concat_axis = None
else:
if self._state.concat_axis_b is None or not np.array_equal(self._state.concat_axis_b.data, ax.data):
self._state.concat_axis_b = ax
self._state.merged_concat_axis = None
def _build_merged_concat_axis(self) -> CoordinateAxis | None:
"""Build the merged CoordinateAxis from the two cached per-input axes."""
if self._state.concat_axis_a is None or self._state.concat_axis_b is None:
return None
if self.settings.relabel_axis:
labels_a = np.array([str(lbl) + self.settings.label_a for lbl in self._state.concat_axis_a.data])
labels_b = np.array([str(lbl) + self.settings.label_b for lbl in self._state.concat_axis_b.data])
else:
labels_a = self._state.concat_axis_a.data
labels_b = self._state.concat_axis_b.data
return CoordinateAxis(
data=np.concatenate([labels_a, labels_b]),
dims=self._state.concat_axis_a.dims,
unit=self._state.concat_axis_a.unit,
)
# -- Core merge logic ---------------------------------------------------
def _try_merge(self) -> AxisArray | None:
"""Align and read from both buffers, returning the merged result.
Initial alignment is performed once. After the first successful
merge the two streams are assumed to share a common clock and
never drop samples, so we simply read
``min(available_a, available_b)`` on every subsequent call.
"""
if self._state.buf_a is None or self._state.buf_b is None:
return None
if self._state.buf_a.is_empty() or self._state.buf_b.is_empty():
return None
gain = self._state.gain
# --- Initial alignment (runs only until the first successful merge) ---
if not self._state.aligned:
first_a = self._state.buf_a.axis_first_value
final_a = self._state.buf_a.axis_final_value
first_b = self._state.buf_b.axis_first_value
final_b = self._state.buf_b.axis_final_value
overlap_start = max(first_a, first_b)
overlap_end = min(final_a, final_b)
if overlap_end < overlap_start - gain / 2:
if final_a < first_b:
self._state.buf_a.seek(self._state.buf_a.available())
elif final_b < first_a:
self._state.buf_b.seek(self._state.buf_b.available())
return None
if first_a < overlap_start - gain / 2:
self._state.buf_a.seek(int(round((overlap_start - first_a) / gain)))
if first_b < overlap_start - gain / 2:
self._state.buf_b.seek(int(round((overlap_start - first_b) / gain)))
# --- Read aligned samples ---
n_read = min(self._state.buf_a.available(), self._state.buf_b.available())
if n_read <= 0:
return None
aa_a = self._state.buf_a.read(n_read)
aa_b = self._state.buf_b.read(n_read)
if aa_a is None or aa_b is None:
return None
if not self._state.aligned:
axis_a = aa_a.axes.get(self._state.align_axis)
axis_b = aa_b.axes.get(self._state.align_axis)
if axis_a is not None and axis_b is not None:
off_a = axis_a.value(0) if hasattr(axis_a, "value") else None
off_b = axis_b.value(0) if hasattr(axis_b, "value") else None
if off_a is not None and off_b is not None:
if not np.isclose(off_a, off_b, atol=abs(gain) * 1e-6):
raise RuntimeError(
f"Offset mismatch after alignment: " f"off_a={off_a}, off_b={off_b}, gain={gain}"
)
self._state.aligned = True
return self._concat(aa_a, aa_b)
def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
"""Concatenate *a* and *b* along the configured merge axis."""
merge_dim = self.settings.axis
# If the merge dim doesn't exist in an input, add it as a trailing axis.
if merge_dim not in a.dims:
xp = get_namespace(a.data)
a = replace(a, data=xp.expand_dims(a.data, axis=-1), dims=[*a.dims, merge_dim])
if merge_dim not in b.dims:
xp = get_namespace(b.data)
b = replace(b, data=xp.expand_dims(b.data, axis=-1), dims=[*b.dims, merge_dim])
# Use the cached merged axis (rebuilt lazily when labels change).
if self._state.merged_concat_axis is None:
self._state.merged_concat_axis = self._build_merged_concat_axis()
key = self.settings.new_key if self.settings.new_key is not None else a.key
result = AxisArray.concatenate(a, b, dim=merge_dim, axis=self._state.merged_concat_axis)
if key != result.key:
result = replace(result, key=key)
return result
[docs]
class Merge(BaseProcessorUnit[MergeSettings]):
"""Merge two AxisArray streams by time-aligning and concatenating along a non-time axis.
Input A routes through the processor's ``__acall__`` (triggering
hash-based reset when the stream structure changes). Input B
routes through ``push_b`` which independently tracks its own structure.
Inherits ``INPUT_SETTINGS`` and ``on_settings`` → ``create_processor``
from :class:`BaseProcessorUnit`.
"""
SETTINGS = MergeSettings
INPUT_SIGNAL_A = ez.InputStream(AxisArray)
INPUT_SIGNAL_B = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
[docs]
def create_processor(self) -> None:
self.processor = MergeProcessor(settings=self.SETTINGS)
[docs]
@ez.subscriber(INPUT_SIGNAL_A, zero_copy=True)
@ez.publisher(OUTPUT_SIGNAL)
async def on_a(self, msg: AxisArray) -> typing.AsyncGenerator:
result = await self.processor.__acall__(msg)
if result is not None:
yield self.OUTPUT_SIGNAL, result
[docs]
@ez.subscriber(INPUT_SIGNAL_B, zero_copy=True)
@ez.publisher(OUTPUT_SIGNAL)
async def on_b(self, msg: AxisArray) -> typing.AsyncGenerator:
result = self.processor.push_b(msg)
if result is not None:
yield self.OUTPUT_SIGNAL, result