# Source code for ezmsg.sigproc.concat

"""Concatenate two AxisArray streams along an existing or new axis."""

from __future__ import annotations

import asyncio
import typing
from dataclasses import dataclass, field

import ezmsg.core as ez
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray, AxisBase, CoordinateAxis
from ezmsg.util.messages.util import replace

# ---------------------------------------------------------------------------
# Shared helpers (also used by merge.py)
# ---------------------------------------------------------------------------


def _build_merged_coordinate_axis(
    axis_a: CoordinateAxis,
    axis_b: CoordinateAxis,
    relabel: bool,
    label_a: str,
    label_b: str,
) -> CoordinateAxis:
    """Combine the coordinate axes of inputs A and B into one axis.

    The result holds A's coordinates followed by B's.  ``dims`` and ``unit``
    are taken from *axis_a*.

    Args:
        axis_a: Coordinate axis of input A.
        axis_b: Coordinate axis of input B.
        relabel: When True, every coordinate is converted to ``str`` and the
            per-input suffix is appended.
        label_a: Suffix for A's coordinates (used only when *relabel* is True).
        label_b: Suffix for B's coordinates (used only when *relabel* is True).

    Returns:
        A new ``CoordinateAxis``.  If either input carries a structured
        (named-field) dtype the work is delegated to ``_merge_struct_axes``.
    """
    coords_a, coords_b = axis_a.data, axis_b.data

    # A structured dtype on either side needs field-aware merging.
    is_structured = coords_a.dtype.names is not None or coords_b.dtype.names is not None
    if is_structured:
        return _merge_struct_axes(coords_a, coords_b, relabel, label_a, label_b, axis_a)

    def _suffixed(values: np.ndarray, suffix: str) -> np.ndarray:
        # Stringify each coordinate and tag it with its input's suffix.
        return np.array([str(value) + suffix for value in values])

    # Plain (string / numeric) coordinates.
    if relabel:
        coords_a = _suffixed(coords_a, label_a)
        coords_b = _suffixed(coords_b, label_b)

    return CoordinateAxis(
        data=np.concatenate([coords_a, coords_b]),
        dims=axis_a.dims,
        unit=axis_a.unit,
    )


def _merge_struct_axes(
    data_a: np.ndarray,
    data_b: np.ndarray,
    relabel: bool,
    label_a: str,
    label_b: str,
    ref_axis: CoordinateAxis,
) -> CoordinateAxis:
    """Merge two structured-dtype coordinate arrays, preserving all fields.

    The output dtype is the union of both inputs' fields: A's fields in their
    original order, followed by fields that exist only in B.  Rows that come
    from an input lacking a given field keep the zero-initialised value for
    that field.  An input without a structured dtype contributes no fields.

    Args:
        data_a: Coordinate array of input A (rows ``[0, len(data_a))`` of the
            result).
        data_b: Coordinate array of input B (the remaining rows).
        relabel: When True, the ``"label"`` field of every row is rewritten as
            ``str(original label) + suffix``.  Rows whose input has no
            ``"label"`` field use their row index within that input instead.
            The field is created if neither input has it.
        label_a: Suffix for rows from A.
        label_b: Suffix for rows from B.
        ref_axis: Axis whose ``dims`` and ``unit`` are copied to the result.

    Returns:
        A ``CoordinateAxis`` wrapping the merged structured array.

    Raises:
        ValueError: If a field shared by both inputs has incompatible dtypes
            (propagated from ``_resolve_field_dtype``).
    """
    fields_a = data_a.dtype.names or ()
    fields_b = data_b.dtype.names or ()
    n_a, n_b = len(data_a), len(data_b)

    # Union dtype.  A dict keeps insertion order, so A's fields come first and
    # B-only fields are appended afterwards.
    union_fields: dict[str, np.dtype] = {}
    for name in fields_a:
        field_dtype = data_a.dtype[name]
        if name in fields_b:
            # Shared fields must have compatible sub-dtypes.
            field_dtype = _resolve_field_dtype(name, field_dtype, data_b.dtype[name])
        union_fields[name] = field_dtype
    for name in fields_b:
        union_fields.setdefault(name, data_b.dtype[name])

    new_labels: list[str] = []
    if relabel:
        # Compute the relabelled values up front so the field can be sized to
        # hold them.
        src_a = data_a["label"] if "label" in fields_a else range(n_a)
        src_b = data_b["label"] if "label" in fields_b else range(n_b)
        new_labels = [str(value) + label_a for value in src_a]
        new_labels += [str(value) + label_b for value in src_b]

        if "label" not in union_fields:
            # Width = widest row index (at least one digit) + longest suffix.
            index_width = max(len(str(max(n_a - 1, 0))), len(str(max(n_b - 1, 0))))
            suffix_width = max(len(label_a), len(label_b))
            union_fields["label"] = np.dtype(f"U{index_width + suffix_width}")
        else:
            label_dtype = union_fields["label"]
            if label_dtype.kind == "U":
                # NumPy silently truncates strings that exceed a fixed-width
                # unicode field, which would strip the suffix and defeat the
                # purpose of relabelling.  Widen the field when needed.
                needed = max(map(len, new_labels), default=0)
                if needed > label_dtype.itemsize // 4:
                    union_fields["label"] = np.dtype(f"U{needed}")

    merged = np.zeros(n_a + n_b, dtype=np.dtype(list(union_fields.items())))

    # Copy each input's own fields into its slice of rows.
    for name in fields_a:
        merged[name][:n_a] = data_a[name]
    for name in fields_b:
        merged[name][n_a:] = data_b[name]

    # Overwrite only the "label" field with the relabelled values.
    if new_labels:
        merged["label"] = new_labels

    return CoordinateAxis(data=merged, dims=ref_axis.dims, unit=ref_axis.unit)


def _resolve_field_dtype(name: str, dt_a: np.dtype, dt_b: np.dtype) -> np.dtype:
    """Pick the dtype for a struct field that exists in both inputs.

    Args:
        name: Field name (used only in the error message).
        dt_a: The field's dtype in input A.
        dt_b: The field's dtype in input B.

    Returns:
        *dt_a* when both dtypes are equal; otherwise, for two unicode string
        dtypes, a unicode dtype as wide as the wider of the two.

    Raises:
        ValueError: If the dtypes differ and are not both unicode strings.
    """
    if dt_a != dt_b:
        both_unicode = dt_a.kind == "U" and dt_b.kind == "U"
        if not both_unicode:
            raise ValueError(f"Incompatible dtypes for shared struct field {name!r}: {dt_a} vs {dt_b}")
        # NumPy stores unicode as 4 bytes per character.
        widest = max(dt_a.itemsize, dt_b.itemsize) // 4
        return np.dtype(f"U{widest}")
    return dt_a


def _validate_shared_axes(
    a: AxisArray,
    b: AxisArray,
    concat_dim: str,
    align_dim: str | None,
    assert_flag: bool,
) -> None:
    """Check that axes present in both inputs describe the same coordinates.

    Only axes found in both ``a.axes`` and ``b.axes`` are compared; the
    concatenation axis and the alignment axis are skipped.  For each remaining
    pair, ``.data`` is compared when both axes have it, and ``.gain`` is
    compared when both axes have it.

    Args:
        a: Message from input A.
        b: Message from input B.
        concat_dim: Name of the concatenation axis (never compared).
        align_dim: Name of the alignment axis (never compared), or None.
        assert_flag: When False the function returns without checking.

    Raises:
        ValueError: If a compared axis has differing ``.data`` or ``.gain``.
    """
    if assert_flag:
        exempt = frozenset((concat_dim, align_dim))
        for name, left in a.axes.items():
            if name in exempt or name not in b.axes:
                continue
            right = b.axes[name]

            both_have_data = hasattr(left, "data") and hasattr(right, "data")
            if both_have_data and not np.array_equal(left.data, right.data):
                raise ValueError(f"Shared axis {name!r} has different .data between inputs A and B")

            both_have_gain = hasattr(left, "gain") and hasattr(right, "gain")
            if both_have_gain and left.gain != right.gain:
                raise ValueError(f"Shared axis {name!r} has different gain: {left.gain} vs {right.gain}")


def _build_cached_axes(
    a: AxisArray,
    concat_dim: str,
    align_dim: str | None,
    merged_concat_axis: CoordinateAxis | None,
) -> dict[str, AxisBase]:
    """Assemble the axes mapping that is reused for every output message.

    Args:
        a: Message from input A; its ``axes`` provide the entries and order.
        concat_dim: Name of the concatenation axis.
        align_dim: Name of the axis to leave out of the result, or None.
        merged_concat_axis: Replacement for the *concat_dim* entry, or None to
            keep A's own entry (if it has one).

    Returns:
        A new dict with A's axes minus *align_dim*.  When *merged_concat_axis*
        is given it replaces A's *concat_dim* entry in place, or is appended
        at the end if no such entry was kept.
    """
    use_merged = merged_concat_axis is not None
    axes: dict[str, AxisBase] = {
        name: merged_concat_axis if (use_merged and name == concat_dim) else ax
        for name, ax in a.axes.items()
        if name != align_dim
    }
    if use_merged:
        # Covers a brand-new concat axis, and concat_dim == align_dim.
        axes.setdefault(concat_dim, merged_concat_axis)
    return axes


# ---------------------------------------------------------------------------
# ConcatProcessor / Concat unit
# ---------------------------------------------------------------------------


class ConcatSettings(ez.Settings):
    """Settings for :class:`ConcatProcessor` and the :class:`Concat` unit."""

    axis: str = "ch"
    """Axis along which to concatenate the two signals.  If signal A does not
    have this dimension, a new trailing dimension with this name is created."""

    align_axis: str | None = None
    """Axis that is taken from signal A on every message rather than cached.
    It is also skipped by the shared-axis check."""

    relabel_axis: bool = True
    """Whether to relabel coordinate axis labels to ensure uniqueness."""

    label_a: str = "_a"
    """Suffix appended to signal A labels when relabel_axis is True."""

    label_b: str = "_b"
    """Suffix appended to signal B labels when relabel_axis is True."""

    assert_identical_shared_axes: bool = False
    """If True, raise ValueError when an axis shared by both signals differs in
    its ``.data`` array or its ``gain``."""

    new_key: str | None = None
    """Output AxisArray key. If None, uses the key from signal A."""
@dataclass
class ConcatState:
    """Mutable state held by a :class:`ConcatProcessor`."""

    # Pending messages for each input, consumed in arrival order.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)

    # Merged concat-axis coordinates from the last cache rebuild (None when
    # the inputs' concat axes could not be merged).
    merged_concat_axis: CoordinateAxis | None = None

    # Output axes reused across messages; None until the first rebuild.
    cached_axes: dict[str, AxisBase] | None = None

    # Fingerprints of the messages the cache was built from; a change in
    # either one triggers a rebuild.
    a_fingerprint: tuple | None = None
    b_fingerprint: tuple | None = None
class ConcatProcessor:
    """Concatenate paired AxisArray messages from two input queues.

    Uses FIFO queue pairing (like :class:`~ezmsg.sigproc.math.add.AddProcessor`).
    No time-alignment or buffering — inputs are assumed pre-synchronized.
    """

    def __init__(self, settings: ConcatSettings):
        """Store *settings* and start with a fresh, empty state."""
        self.settings = settings
        self._state = ConcatState()

    @property
    def state(self) -> ConcatState:
        """The processor's current state object."""
        return self._state

    @state.setter
    def state(self, state: ConcatState | bytes | None) -> None:
        # ``None`` is ignored; any other value replaces the state as-is.
        if state is not None:
            self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Enqueue a message from input A (non-blocking)."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Enqueue a message from input B (non-blocking)."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Wait for one message from A, then one from B, and concatenate them."""
        msg_a = await self._state.queue_a.get()
        msg_b = await self._state.queue_b.get()
        return self._concat(msg_a, msg_b)

    def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
        """Concatenate *a* and *b* along the configured axis.

        The cached axes are rebuilt first whenever either message's
        fingerprint differs from the one the cache was built from.
        """
        concat_dim = self.settings.axis
        state = self._state

        fingerprints = (self._fingerprint(a), self._fingerprint(b))
        if fingerprints != (state.a_fingerprint, state.b_fingerprint):
            self._rebuild_cache(a, b)
            # Recorded only after a successful rebuild, so a failed rebuild
            # is retried on the next pair.
            state.a_fingerprint, state.b_fingerprint = fingerprints

        def _with_trailing_axis(msg: AxisArray) -> AxisArray:
            # Append a length-1 dimension named after the concat axis.
            return replace(
                msg,
                data=np.expand_dims(msg.data, axis=-1),
                dims=[*msg.dims, concat_dim],
            )

        # Concatenating along an axis that A lacks: give both inputs a new
        # trailing dimension to join on.
        if concat_dim not in a.dims:
            a = _with_trailing_axis(a)
            b = _with_trailing_axis(b)

        stacked = np.concatenate([a.data, b.data], axis=a.dims.index(concat_dim))

        # Start from the cached axes, then fill in anything the cache leaves
        # out (the alignment axis) from the live message A.
        # NOTE(review): every other axis comes from the cache, which is only
        # rebuilt when a fingerprint (dims, data shape, concat-axis data hash)
        # changes.  An axis that varies per message is therefore refreshed
        # only if it is the configured ``align_axis`` — confirm this is the
        # intended contract.
        cached = state.cached_axes
        axes = dict(a.axes) if cached is None else dict(cached)
        for name, live_axis in a.axes.items():
            axes.setdefault(name, live_axis)

        out_key = a.key if self.settings.new_key is None else self.settings.new_key
        return AxisArray(stacked, dims=list(a.dims), axes=axes, key=out_key)

    def _fingerprint(self, msg: AxisArray) -> tuple:
        """Return ``(dims, data shape, hash of concat-axis data or None)``."""
        concat_axis = msg.axes.get(self.settings.axis)
        if concat_axis is None or not hasattr(concat_axis, "data"):
            coords_hash = None
        else:
            coords_hash = hash(concat_axis.data.tobytes())
        return (tuple(msg.dims), msg.data.shape, coords_hash)

    def _rebuild_cache(self, a: AxisArray, b: AxisArray) -> None:
        """Validate the pair and recompute the merged concat axis and cached axes.

        Raises:
            ValueError: From the shared-axis check, or when concatenating
                along a new axis and a dimension size differs between A and B.
        """
        cfg = self.settings
        concat_dim = cfg.axis

        # Validate shared axes.
        _validate_shared_axes(
            a,
            b,
            concat_dim,
            align_dim=cfg.align_axis,
            assert_flag=cfg.assert_identical_shared_axes,
        )

        # New-axis validation: all other dims must match.
        if not (concat_dim in a.dims and concat_dim in b.dims):
            for dim_name, size_a, size_b in zip(a.dims, a.data.shape, b.data.shape):
                if size_a != size_b:
                    raise ValueError(
                        f"Cannot concatenate along new axis {concat_dim!r}: "
                        f"dimension {dim_name!r} has size {size_a} in A but {size_b} in B"
                    )

        # Build merged concat axis; possible only when both inputs carry
        # coordinate data on it.
        ax_a = a.axes.get(concat_dim)
        ax_b = b.axes.get(concat_dim)
        mergeable = (
            ax_a is not None
            and ax_b is not None
            and hasattr(ax_a, "data")
            and hasattr(ax_b, "data")
        )
        if mergeable:
            self._state.merged_concat_axis = _build_merged_coordinate_axis(
                ax_a,
                ax_b,
                relabel=cfg.relabel_axis,
                label_a=cfg.label_a,
                label_b=cfg.label_b,
            )
        else:
            self._state.merged_concat_axis = None

        self._state.cached_axes = _build_cached_axes(
            a,
            concat_dim,
            align_dim=cfg.align_axis,
            merged_concat_axis=self._state.merged_concat_axis,
        )
class Concat(ez.Unit):
    """Concatenate two AxisArray streams along an axis.

    Pairs messages by arrival order (FIFO). No time-alignment.
    """

    SETTINGS = ConcatSettings

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Create the processor that pairs and concatenates the inputs."""
        self.processor = ConcatProcessor(self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Queue a message received on input A."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Queue a message received on input B."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Publish one concatenated message per (A, B) pair, indefinitely."""
        while True:
            merged = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, merged