# Source code for ezmsg.sigproc.concat

"""Concatenate two AxisArray streams along an existing or new axis."""

from __future__ import annotations

import asyncio
import logging
import typing
from dataclasses import dataclass, field

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace
from ezmsg.util.messages.axisarray import AxisArray, AxisBase, CoordinateAxis
from ezmsg.util.messages.util import replace

logger = logging.getLogger(__name__)

# Sentinel for "attr key was missing on this side". Distinct from any user value.
_MISSING = object()

# ---------------------------------------------------------------------------
# Shared helpers (also used by merge.py)
# ---------------------------------------------------------------------------


def _build_merged_coordinate_axis(
    axis_a: CoordinateAxis,
    axis_b: CoordinateAxis,
    relabel: bool,
    label_a: str,
    label_b: str,
) -> CoordinateAxis:
    """Combine the coordinate axes of inputs A and B into one CoordinateAxis.

    If either axis carries a structured (numpy struct) dtype the work is
    delegated to :func:`_merge_struct_axes`.  Otherwise the two ``.data``
    arrays are concatenated A-then-B; when *relabel* is True every entry is
    first stringified and suffixed with *label_a* / *label_b* respectively.

    The returned axis takes its ``dims`` and ``unit`` from *axis_a*.
    """
    coords_a, coords_b = axis_a.data, axis_b.data

    # Structured dtypes need field-aware merging.
    if any(coords.dtype.names is not None for coords in (coords_a, coords_b)):
        return _merge_struct_axes(coords_a, coords_b, relabel, label_a, label_b, axis_a)

    # Plain (string/numeric) coordinates.
    pieces = [coords_a, coords_b]
    if relabel:
        pieces = [
            np.array([str(entry) + suffix for entry in coords])
            for coords, suffix in zip(pieces, (label_a, label_b))
        ]

    return CoordinateAxis(
        data=np.concatenate(pieces),
        dims=axis_a.dims,
        unit=axis_a.unit,
    )


def _merge_struct_axes(
    data_a: np.ndarray,
    data_b: np.ndarray,
    relabel: bool,
    label_a: str,
    label_b: str,
    ref_axis: CoordinateAxis,
) -> CoordinateAxis:
    """Merge two structured-dtype coordinate arrays, preserving all fields.

    The output dtype is the union of both inputs' fields: A's fields first (in
    A's order), then fields that only B has.  Fields present on both sides are
    reconciled with :func:`_resolve_field_dtype`.  Rows of A come first,
    followed by rows of B; a field missing on one side is left zero-filled on
    that side's rows.  An input without a structured dtype contributes no
    fields at all.

    When *relabel* is True only the ``"label"`` field is rewritten: each
    entry becomes ``str(original) + suffix``, falling back to the row index
    when that side has no ``"label"`` field.  If neither side has a
    ``"label"`` field one is appended.  An existing unicode ``"label"`` field
    is widened as needed so the suffixed labels are never truncated by the
    fixed-width dtype.

    Args:
        data_a: Coordinate data of input A.
        data_b: Coordinate data of input B.
        relabel: Whether to suffix the ``"label"`` field.
        label_a: Suffix for A's labels.
        label_b: Suffix for B's labels.
        ref_axis: Axis whose ``dims`` and ``unit`` are copied to the result.

    Returns:
        A CoordinateAxis holding ``len(data_a) + len(data_b)`` merged rows.

    Raises:
        ValueError: If a shared field has incompatible dtypes on the two sides.
    """
    fields_a = data_a.dtype.names or ()
    fields_b = data_b.dtype.names or ()
    n_a, n_b = len(data_a), len(data_b)

    # Union dtype, in first-seen order (dicts preserve insertion order).
    union: dict[str, np.dtype] = {}
    for name in (*fields_a, *fields_b):
        if name in union:
            continue
        if name in fields_a and name in fields_b:
            union[name] = _resolve_field_dtype(name, data_a.dtype[name], data_b.dtype[name])
        elif name in fields_a:
            union[name] = data_a.dtype[name]
        else:
            union[name] = data_b.dtype[name]

    relabeled: list[str] = []
    if relabel:
        # Compute the final label strings up front so the field can be sized
        # to hold them.
        src_a = data_a["label"] if "label" in fields_a else range(n_a)
        src_b = data_b["label"] if "label" in fields_b else range(n_b)
        relabeled = [str(v) + label_a for v in src_a] + [str(v) + label_b for v in src_b]

        if "label" not in union:
            # No label on either side: create one wide enough for
            # "<row index><suffix>".
            index_width = max(len(str(max(n_a - 1, 0))), len(str(max(n_b - 1, 0))))
            suffix_width = max(len(label_a), len(label_b))
            union["label"] = np.dtype(f"U{index_width + suffix_width}")
        elif union["label"].kind == "U":
            # NumPy silently truncates strings assigned into a fixed-width
            # unicode field, which would drop the suffix (and can leave A and B
            # labels identical).  Widen the field to fit the longest new label.
            needed = max((len(text) for text in relabeled), default=0)
            if needed > union["label"].itemsize // 4:
                union["label"] = np.dtype(f"U{needed}")

    merged = np.zeros(n_a + n_b, dtype=np.dtype(list(union.items())))

    # Copy each side's own fields into its block of rows.
    for name in fields_a:
        merged[name][:n_a] = data_a[name]
    for name in fields_b:
        merged[name][n_a:] = data_b[name]

    if relabel:
        # merged["label"] is a view, so element writes land in `merged`.
        label_column = merged["label"]
        for idx, text in enumerate(relabeled):
            label_column[idx] = text

    return CoordinateAxis(data=merged, dims=ref_axis.dims, unit=ref_axis.unit)


def _resolve_field_dtype(name: str, dt_a: np.dtype, dt_b: np.dtype) -> np.dtype:
    """Resolve a shared struct field's dtype.  String fields use the wider width."""
    if dt_a == dt_b:
        return dt_a
    if dt_a.kind == "U" and dt_b.kind == "U":
        return np.dtype(f"U{max(dt_a.itemsize // 4, dt_b.itemsize // 4)}")
    raise ValueError(f"Incompatible dtypes for shared struct field {name!r}: {dt_a} vs {dt_b}")


# ---------------------------------------------------------------------------
# Attrs merging + promotion
# ---------------------------------------------------------------------------

_ALLOWED_ATTR_SCALARS = (str, int, float, bool, np.integer, np.floating)


def _check_attr_type(key: str, value: typing.Any) -> None:
    if not isinstance(value, _ALLOWED_ATTR_SCALARS):
        raise TypeError(
            f"Cannot merge/promote attrs key {key!r}: unsupported value type "
            f"{type(value).__name__}; only scalar str/int/float/bool are allowed."
        )


def _attrs_values_equal(a: typing.Any, b: typing.Any) -> bool:
    try:
        return bool(a == b)
    except Exception:
        return a is b


def _classify_attrs(a_attrs: dict, b_attrs: dict) -> tuple[dict, dict, dict]:
    """Partition two attrs dicts into shared-equal keys and keys to promote.

    Returns ``(equal, promote_a, promote_b)``:
      * ``equal`` maps keys present on both sides with equal values to that
        value (A's copy); these stay on the output ``.attrs``.
      * ``promote_a`` / ``promote_b`` hold, for every other key, the value each
        side contributes to the concat-axis elements, or the ``_MISSING``
        sentinel when that side lacks the key.

    Every value that is kept or promoted is validated with
    ``_check_attr_type``, which raises ``TypeError`` for unsupported types.
    A falsy input (e.g. ``None``) is treated as an empty dict.
    """
    side_a = a_attrs or {}
    side_b = b_attrs or {}
    shared: dict = {}
    for_a: dict = {}
    for_b: dict = {}

    for key in set(side_a) | set(side_b):
        in_a = key in side_a
        in_b = key in side_b

        if in_a and in_b and _attrs_values_equal(side_a[key], side_b[key]):
            _check_attr_type(key, side_a[key])
            shared[key] = side_a[key]
            continue

        # Differing or one-sided key: record each side's contribution, A first.
        for present, source, target in ((in_a, side_a, for_a), (in_b, side_b, for_b)):
            if present:
                _check_attr_type(key, source[key])
                target[key] = source[key]
            else:
                target[key] = _MISSING

    return shared, for_a, for_b


def _promoted_field_dtype(values: list) -> np.dtype:
    """Choose a numpy dtype able to store the given promoted values.

    ``_MISSING`` entries are ignored.  Precedence: any string -> unicode wide
    enough for the longest ``str(value)``; any float -> ``f8``; all booleans
    -> ``?``; any integer -> ``i8``.  With nothing usable the result is ``U1``.
    """
    present = [v for v in values if v is not _MISSING]
    fallback = np.dtype("U1")
    if not present:
        return fallback

    def _any_instance(types) -> bool:
        return any(isinstance(v, types) for v in present)

    if _any_instance(str):
        widest = max(len(str(v)) for v in present)
        return np.dtype(f"U{max(widest, 1)}")
    if _any_instance((float, np.floating)):
        return np.dtype("f8")
    # bool is a subclass of int, so the all-bool case must be tested before ints.
    if all(isinstance(v, (bool, np.bool_)) for v in present):
        return np.dtype("?")
    if _any_instance((int, np.integer)):
        return np.dtype("i8")
    return fallback


def _sentinel_for_dtype(dt: np.dtype) -> typing.Any:
    if dt.kind == "U":
        return ""
    if dt.kind == "f":
        return float("nan")
    if dt.kind == "i":
        return 0
    if dt.kind == "b":
        return False
    return None


def _extend_struct_with_fields(
    existing: np.ndarray,
    new_fields: list[tuple[str, np.dtype, np.ndarray]],
) -> np.ndarray:
    """Append new columns to a structured array, preserving existing columns.

    ``new_fields`` is a list of ``(name, dtype, values)`` triples where
    ``values`` has length ``len(existing)``.
    """
    union: list[tuple[str, np.dtype]] = [(n, existing.dtype[n]) for n in (existing.dtype.names or ())]
    union.extend((n, dt) for (n, dt, _) in new_fields)
    union_dtype = np.dtype(union)
    out = np.zeros(len(existing), dtype=union_dtype)
    for n in existing.dtype.names or ():
        out[n] = existing[n]
    for n, _, vals in new_fields:
        out[n] = vals
    return out


def _apply_promoted_attrs(
    merged_axis: CoordinateAxis | None,
    n_a: int,
    n_b: int,
    promote_a: dict,
    promote_b: dict,
    ref_axis: CoordinateAxis | None,
    concat_dim: str,
) -> CoordinateAxis | None:
    """Attach promoted attrs to the concat axis as per-element struct fields.

    Each promoted key becomes one column of length ``n_a + n_b``: the first
    ``n_a`` entries carry A's value and the rest carry B's, with a
    dtype-specific placeholder wherever a side is ``_MISSING``.

    * If *merged_axis* has plain (non-structured) data, that data is first
      wrapped into a structured array with a single ``"label"`` field.
    * A promoted key that matches an existing struct field is skipped with a
      warning; the existing per-element field wins.
    * If there is no axis data at all, a new axis is built from the promoted
      columns alone with ``dims=[concat_dim]``.

    Returns *merged_axis* unchanged when there is nothing to promote or every
    promoted key collided.
    """
    if not (promote_a or promote_b):
        return merged_axis

    keys = sorted(set(promote_a) | set(promote_b))

    # Bring any existing axis data into structured form.
    struct_data: np.ndarray | None = None
    if merged_axis is not None and merged_axis.data is not None:
        raw = merged_axis.data
        if raw.dtype.names is None:
            struct_data = np.zeros(len(raw), dtype=np.dtype([("label", raw.dtype)]))
            struct_data["label"] = raw
        else:
            struct_data = raw

    taken = set() if struct_data is None else set(struct_data.dtype.names or ())
    total = n_a + n_b

    columns: list[tuple[str, np.dtype, np.ndarray]] = []
    for key in keys:
        if key in taken:
            logger.warning(
                "concat: attrs key %r collides with existing struct field on %r "
                "axis; dropping promoted attr (per-element field is authoritative).",
                key,
                concat_dim,
            )
            continue
        per_element = [promote_a.get(key, _MISSING)] * n_a + [promote_b.get(key, _MISSING)] * n_b
        col_dtype = _promoted_field_dtype(per_element)
        filler = _sentinel_for_dtype(col_dtype)
        column = np.empty(total, dtype=col_dtype)
        for pos, value in enumerate(per_element):
            column[pos] = filler if value is _MISSING else value
        columns.append((key, col_dtype, column))

    if not columns:
        return merged_axis

    if struct_data is None:
        # Nothing to extend: the promoted columns are the whole axis.
        synthesized = np.zeros(total, dtype=np.dtype([(name, dt) for name, dt, _ in columns]))
        for name, _, column in columns:
            synthesized[name] = column
        return CoordinateAxis(data=synthesized, dims=[concat_dim])

    extended = _extend_struct_with_fields(struct_data, columns)
    if ref_axis is None:
        dims, unit = [concat_dim], None
    else:
        dims, unit = ref_axis.dims, ref_axis.unit
    return CoordinateAxis(data=extended, dims=dims, unit=unit)


def _validate_shared_axes(
    a: AxisArray,
    b: AxisArray,
    concat_dim: str,
    align_dim: str | None,
    assert_flag: bool,
) -> None:
    """Check that axes shared by A and B agree; no-op unless *assert_flag*.

    Every axis name present on both inputs, other than *concat_dim* and
    *align_dim*, is compared: ``.data`` arrays must be equal when both axes
    have one, and ``.gain`` values must be equal when both axes have one.

    Raises:
        ValueError: On the first mismatch found.
    """
    if assert_flag:
        exempt = {concat_dim, align_dim}
        shared_names = (name for name in a.axes if name not in exempt and name in b.axes)
        for name in shared_names:
            left, right = a.axes[name], b.axes[name]

            both_have_data = hasattr(left, "data") and hasattr(right, "data")
            if both_have_data and not np.array_equal(left.data, right.data):
                raise ValueError(f"Shared axis {name!r} has different .data between inputs A and B")

            both_have_gain = hasattr(left, "gain") and hasattr(right, "gain")
            if both_have_gain and left.gain != right.gain:
                raise ValueError(f"Shared axis {name!r} has different gain: {left.gain} vs {right.gain}")


def _build_cached_axes(
    a: AxisArray,
    concat_dim: str,
    align_dim: str | None,
    merged_concat_axis: CoordinateAxis | None,
) -> dict[str, AxisBase]:
    """Assemble the output axes dict from A's axes, minus the alignment axis.

    A's entry for *concat_dim* is replaced by *merged_concat_axis* when one is
    supplied; if A has no such entry, the merged axis is added.
    """
    have_merged = merged_concat_axis is not None

    out: dict[str, AxisBase] = {
        name: (merged_concat_axis if have_merged and name == concat_dim else ax)
        for name, ax in a.axes.items()
        if name != align_dim
    }
    if have_merged and concat_dim not in out:
        out[concat_dim] = merged_concat_axis
    return out


# ---------------------------------------------------------------------------
# ConcatProcessor / Concat unit
# ---------------------------------------------------------------------------


class ConcatSettings(ez.Settings):
    """Settings for :class:`ConcatProcessor` and the :class:`Concat` unit."""

    axis: str = "ch"
    """Axis along which to concatenate the two signals."""

    align_axis: str | None = None
    """Axis along which to validate alignment between the two signals."""

    relabel_axis: bool = True
    """Whether to relabel coordinate axis labels to ensure uniqueness."""

    label_a: str = "_a"
    """Per-side label for signal A. Used in two distinct ways depending on whether
    ``axis`` is an existing or new dimension on the inputs:

    * **Existing axis** (``axis`` is in both inputs' ``.dims``): ``label_a`` is a
      *suffix* appended to each entry of A's existing coordinate-axis labels when
      ``relabel_axis`` is True. Defaults to ``"_a"``.
    * **New axis** (``axis`` is not in either input's ``.dims``): ``label_a`` is
      used as the single ``data`` entry on the merged axis's CoordinateAxis at
      index 0. E.g. setting ``label_a="spk", label_b="sbp"`` on a Merge of two
      ``(time, ch)`` streams produces a ``(time, ch, feature)`` output whose
      ``feature`` axis has ``data=["spk", "sbp"]``.
    """

    label_b: str = "_b"
    """Per-side label for signal B. See :attr:`label_a`. Defaults to ``"_b"``;
    used as the new-axis label at index 1 in the new-axis case.
    """

    assert_identical_shared_axes: bool = False
    """If True, raise ValueError when shared CoordinateAxis .data arrays differ."""

    new_key: str | None = None
    """Output AxisArray key. If None, uses the key from signal A."""

    auto_coerce_backend: bool = False
    """If True, silently coerce signal B to signal A's array namespace when the
    two inputs are on mismatched backends (e.g. MLX vs numpy). Defaults to False
    (strict): a backend mismatch raises a clear error instead, since it is almost
    always an upstream bug and silent coercion hides device<->host copies."""
@dataclass
class ConcatState:
    """Mutable state of a :class:`ConcatProcessor`: input queues plus caches."""

    # FIFO queues holding not-yet-paired messages from each input.
    queue_a: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)
    queue_b: "asyncio.Queue[AxisArray]" = field(default_factory=asyncio.Queue)

    # Results of the last cache rebuild.
    merged_concat_axis: CoordinateAxis | None = None
    cached_axes: dict[str, AxisBase] | None = None
    merged_attrs: dict | None = None

    # Fingerprints for cache invalidation.
    a_fingerprint: tuple | None = None
    b_fingerprint: tuple | None = None
class ConcatProcessor:
    """Concatenate paired AxisArray messages from two input queues.

    Uses FIFO queue pairing (like :class:`~ezmsg.sigproc.math.add.AddProcessor`).
    No time-alignment or buffering — inputs are assumed pre-synchronized.

    The merged concat axis and merged attrs are cached and rebuilt only when an
    input's fingerprint (dims, data shape, concat-axis data, attrs) changes.
    All other axes are taken from signal A's current message.
    """

    def __init__(self, settings: ConcatSettings):
        self.settings = settings
        self._state = ConcatState()

    @property
    def state(self) -> ConcatState:
        return self._state

    @state.setter
    def state(self, state: ConcatState | bytes | None) -> None:
        # ``None`` is ignored so the current state is kept.
        if state is not None:
            self._state = state

    def push_a(self, msg: AxisArray) -> None:
        """Enqueue a message from signal A."""
        self._state.queue_a.put_nowait(msg)

    def push_b(self, msg: AxisArray) -> None:
        """Enqueue a message from signal B."""
        self._state.queue_b.put_nowait(msg)

    async def __acall__(self) -> AxisArray:
        """Wait for the next message from A, then from B, and concatenate them."""
        a = await self._state.queue_a.get()
        b = await self._state.queue_b.get()
        return self._concat(a, b)

    def _concat(self, a: AxisArray, b: AxisArray) -> AxisArray:
        """Concatenate *a* and *b* along the configured axis.

        Raises:
            TypeError: If the inputs use different array namespaces and
                ``auto_coerce_backend`` is False.
        """
        concat_dim = self.settings.axis

        # Rebuild the cached merged axis / attrs only when an input changed.
        fp_a = self._fingerprint(a)
        fp_b = self._fingerprint(b)
        if fp_a != self._state.a_fingerprint or fp_b != self._state.b_fingerprint:
            self._rebuild_cache(a, b)
            self._state.a_fingerprint = fp_a
            self._state.b_fingerprint = fp_b

        new_axis = concat_dim not in a.dims

        xp = get_namespace(a.data)
        xp_b = get_namespace(b.data)
        if xp_b is not xp:
            if self.settings.auto_coerce_backend:
                b = replace(b, data=xp.asarray(b.data))
            else:
                raise TypeError(
                    f"Concat received inputs on mismatched backends: "
                    f"a.data namespace={xp.__name__}, b.data namespace={xp_b.__name__}. "
                    f"Coerce both inputs to one backend upstream "
                    f"(e.g., via ezmsg.sigproc.asarray.AsArrayTransformer) before merging, "
                    f"or set ConcatSettings(auto_coerce_backend=True) to coerce B to A's backend."
                )

        # expand_dims for new-axis concatenation.
        if new_axis:
            a = replace(a, data=xp.expand_dims(a.data, axis=-1), dims=[*a.dims, concat_dim])
            b = replace(b, data=xp.expand_dims(b.data, axis=-1), dims=[*b.dims, concat_dim])

        concat_idx = a.dims.index(concat_dim)
        data = xp.concat([a.data, b.data], axis=concat_idx)

        # Start from A's *live* axes.  The cache is only rebuilt when the
        # fingerprint changes, and the fingerprint ignores non-concat axes, so
        # serving those from the cache would freeze per-message axes (e.g. a
        # time axis whose offset advances) at their first-seen value.  Only the
        # merged concat axis is taken from the cache.
        axes = dict(a.axes)
        merged_axis = self._state.merged_concat_axis
        if merged_axis is not None:
            axes[concat_dim] = merged_axis

        key = self.settings.new_key if self.settings.new_key is not None else a.key
        attrs = dict(self._state.merged_attrs) if self._state.merged_attrs else {}

        return AxisArray(data, dims=list(a.dims), axes=axes, key=key, attrs=attrs)

    def _fingerprint(self, msg: AxisArray) -> tuple:
        """Summarize the parts of *msg* that the cached merge results depend on."""
        concat_dim = self.settings.axis
        ax = msg.axes.get(concat_dim)
        ax_hash = hash(ax.data.tobytes()) if ax is not None and hasattr(ax, "data") else None
        attrs_fp = frozenset((k, type(v).__name__, repr(v)) for k, v in (msg.attrs or {}).items())
        return (tuple(msg.dims), msg.data.shape, ax_hash, attrs_fp)

    def _rebuild_cache(self, a: AxisArray, b: AxisArray) -> None:
        """Recompute the merged concat axis, merged attrs and cached axes.

        Raises:
            ValueError: If shared axes differ (when
                ``assert_identical_shared_axes`` is set) or if, for a new-axis
                concat, a dimension has different sizes in A and B.
            TypeError: If an attrs value has an unsupported type.
        """
        concat_dim = self.settings.axis

        # Validate shared axes.
        _validate_shared_axes(
            a,
            b,
            concat_dim,
            align_dim=self.settings.align_axis,
            assert_flag=self.settings.assert_identical_shared_axes,
        )

        # New-axis validation: all other dims must match.
        if concat_dim not in a.dims or concat_dim not in b.dims:
            for d, sa, sb in zip(a.dims, a.data.shape, b.data.shape):
                if sa != sb:
                    raise ValueError(
                        f"Cannot concatenate along new axis {concat_dim!r}: "
                        f"dimension {d!r} has size {sa} in A but {sb} in B"
                    )

        # Build merged concat axis.
        ax_a = a.axes.get(concat_dim)
        ax_b = b.axes.get(concat_dim)

        if ax_a is not None and ax_b is not None and hasattr(ax_a, "data") and hasattr(ax_b, "data"):
            self._state.merged_concat_axis = _build_merged_coordinate_axis(
                ax_a,
                ax_b,
                relabel=self.settings.relabel_axis,
                label_a=self.settings.label_a,
                label_b=self.settings.label_b,
            )
        elif concat_dim not in a.dims and concat_dim not in b.dims:
            # Brand-new axis: one entry per input, named by the side labels.
            self._state.merged_concat_axis = CoordinateAxis(
                data=np.asarray([self.settings.label_a, self.settings.label_b]),
                dims=[concat_dim],
            )
        else:
            self._state.merged_concat_axis = None

        # Merge .attrs across A and B. Equal-shared keys stay in attrs; differing
        # or partially-present keys are promoted to per-element fields on the
        # concat axis.
        equal_attrs, promote_a, promote_b = _classify_attrs(a.attrs, b.attrs)
        if promote_a or promote_b:
            # Each side contributes one element when the axis is new to it.
            if concat_dim in a.dims:
                n_a = a.data.shape[a.dims.index(concat_dim)]
            else:
                n_a = 1
            if concat_dim in b.dims:
                n_b = b.data.shape[b.dims.index(concat_dim)]
            else:
                n_b = 1
            ref_axis = ax_a if ax_a is not None else ax_b
            self._state.merged_concat_axis = _apply_promoted_attrs(
                self._state.merged_concat_axis,
                n_a,
                n_b,
                promote_a,
                promote_b,
                ref_axis,
                concat_dim,
            )
        self._state.merged_attrs = equal_attrs

        self._state.cached_axes = _build_cached_axes(
            a,
            concat_dim,
            align_dim=self.settings.align_axis,
            merged_concat_axis=self._state.merged_concat_axis,
        )
class Concat(ez.Unit):
    """Concatenate two AxisArray streams along an axis.

    Pairs messages by arrival order (FIFO). No time-alignment.
    """

    SETTINGS = ConcatSettings

    INPUT_SIGNAL_A = ez.InputStream(AxisArray)
    INPUT_SIGNAL_B = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    async def initialize(self) -> None:
        """Create the processor from this unit's settings."""
        self.processor = ConcatProcessor(self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL_A)
    async def on_a(self, msg: AxisArray) -> None:
        """Queue a message arriving on input A."""
        self.processor.push_a(msg)

    @ez.subscriber(INPUT_SIGNAL_B)
    async def on_b(self, msg: AxisArray) -> None:
        """Queue a message arriving on input B."""
        self.processor.push_b(msg)

    @ez.publisher(OUTPUT_SIGNAL)
    async def output(self) -> typing.AsyncGenerator:
        """Publish each concatenated pair as it becomes available."""
        while True:
            merged = await self.processor.__acall__()
            yield self.OUTPUT_SIGNAL, merged