# Source code for ezmsg.blackrock.channel_map

"""Attach Blackrock ``.cmp`` channel-map metadata to an ``AxisArray``'s ``ch`` axis.

The output ``ch`` axis is a structured ``CoordinateAxis`` with fields
``x``, ``y``, ``size``, ``label``, ``bank``, ``elec``, ``headstage`` for every
input channel. ``x``/``y``/``size`` are in micrometers; ``headstage`` is the
1-based headstage id (``0`` = none/auto).

:class:`ChannelMapUnit` takes the *complete* set of per-headstage overlays in
one settings object (:class:`ChannelMapUnitSettings`, a tuple of
:class:`ChannelMapSettings`) and rebuilds the ``ch`` axis from scratch on each
reset. One settings push = the whole map, applied deterministically — there is
no cross-push accumulation that could coalesce if pushes aren't separated by a
data message. An empty tuple clears the map (pure auto-grid).

Each reset proceeds in three phases:

1. **Base layer** — labels are pulled from the incoming ``ch`` axis. When the
   incoming axis already carries structured geometry (e.g. a CereLink source
   that read it from device chaninfo), its ``x``/``y``/``size``/``bank``/
   ``elec``/``headstage`` are copied through verbatim, so a map already present
   upstream needs no ``.cmp`` file at all. A channel counts as positioned (and
   so is skipped by the auto-grid) when it has a non-origin coordinate, or it
   is the *first* channel sitting at the ``(0, 0)`` origin — a lone origin
   electrode is a legitimate corner, but the device parks every *unmapped*
   channel at the origin, so origin pile-ups beyond the first fall through to
   the auto-grid. A companion ``src_mask`` records the positioned indices.
2. **CMP overlays** — for each :class:`ChannelMapSettings` in ``cmp_configs``,
   entries from :func:`pycbsdk.cmp.parse_cmp` are written at their channel
   index, overriding any source geometry there. ``parse_cmp``
   (CerebusOSS/CereLink#184) returns entries keyed by device ``(bank, term)``
   with flat ``x``/``y``/``size``/``headstage`` fields (``x``/``y`` in
   micrometers) and verbatim labels; the channel index is
   ``(bank - 1) * 32 + (term - 1)`` — ``start_chan`` is already folded into
   ``bank`` via its ``// 32`` offset. A companion ``cmp_mask`` records which
   indices were set so the auto-grid pass can avoid them.
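3. **Auto-grid fill** — positions/size/bank/elec/headstage for indices covered
   by neither a CMP overlay nor a source position. They are laid out on a
   square grid that starts two pitches beyond the largest placed ``y`` (with
   ``x`` starting at 0), so they don't collide with the placed geometry. They
   receive bank letters following the highest placed bank and
   ``headstage = 0``. The grid step matches the placed electrode pitch
   (inferred from its coordinates), so auto-laid channels share the same
   micrometer scale.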

The same :class:`ChannelMapSettings` record is also used as a per-headstage
entry in :attr:`CereLinkSignalSettings.cmp_configs`.
"""

import logging
import math

import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis
from ezmsg.util.messages.util import replace
from pycbsdk.cmp import parse_cmp

logger = logging.getLogger(__name__)

# Per-channel record stored in the structured ``ch`` axis. Geometry fields are
# int32 micrometers, matching cbPKT_CHANINFO.position. Built from the
# names/formats mapping form, which yields the same packed layout as the
# list-of-tuples form.
CHANNEL_DTYPE = np.dtype(
    {
        "names": ["x", "y", "size", "label", "bank", "elec", "headstage"],
        "formats": [
            "i4",  # x: electrode x, µm
            "i4",  # y: electrode y, µm
            "i4",  # size: electrode size, µm (0 = unspecified)
            "U16",  # label: channel label, at most 16 characters
            "U1",  # bank: single-character bank letter
            "i4",  # elec: 1-based electrode number within the bank
            "i4",  # headstage: 1-based headstage id (0 = none/auto)
        ],
    }
)


class ChannelMapSettings(ez.Settings):
    """One ``.cmp`` overlay: a file plus the channel/headstage it maps onto.

    Instances are the elements of a ``cmp_configs`` tuple. Each one is parsed
    with :func:`pycbsdk.cmp.parse_cmp` and written over the ``ch`` axis.
    """

    filepath: str | None = None
    """Location of the ``.cmp`` file. ``None`` or an empty string disables this
    overlay; channels it would have covered fall back to source geometry or
    the auto-grid."""

    start_chan: int = 1
    """1-based channel ID given to the first sorted row of the CMP file and
    forwarded to :func:`pycbsdk.cmp.parse_cmp`. Mirrors
    :meth:`pycbsdk.Session.load_channel_map`."""

    hs_id: int = 0
    """Headstage identifier forwarded to :func:`pycbsdk.cmp.parse_cmp`, which
    stores it in each entry's ``headstage`` field. Labels stay verbatim (no
    ``"hs{hs_id}-"`` prefix); ``bank``/``elec`` distinguish channels that reuse
    a label across headstages. Use ``0`` for single-headstage rigs."""
class ChannelMapUnitSettings(ez.Settings):
    """Complete channel-map configuration for :class:`ChannelMapUnit`.

    Carries every per-headstage overlay at once, so one settings update
    describes the entire map rather than adding to a previous one.
    """

    cmp_configs: tuple[ChannelMapSettings, ...] = ()
    """Per-headstage overlays, applied in tuple order on every reset. The empty
    default means no CMP at all — the auto-grid positions every channel."""
@processor_state
class ChannelMapState:
    """Mutable state of :class:`ChannelMapProcessor`.

    The three fields are rebuilt together on every reset and remain ``None``
    until the first reset happens.
    """

    # Structured ``ch`` axis (records of CHANNEL_DTYPE) attached to each output.
    channel_axis: CoordinateAxis | None = None
    # Bool per channel: True where a CMP overlay wrote the record.
    cmp_mask: np.ndarray | None = None
    # Bool per channel: True where the incoming axis supplied the position
    # and no CMP overlay overrode it.
    src_mask: np.ndarray | None = None
class ChannelMapProcessor(
    BaseStatefulTransformer[ChannelMapUnitSettings, AxisArray, AxisArray, ChannelMapState]
):
    """Stateful transformer that attaches CMP-derived channel metadata.

    Each reset rebuilds the axis in full from ``settings.cmp_configs``:

    1. a base layer of labels (and any structured geometry) from the incoming
       ``ch`` axis;
    2. every CMP overlay applied in order;
    3. the auto-grid filling indices that neither a CMP nor the source placed.

    Reset fires on a channel-count change or any ``cmp_configs`` change (the
    latter via :class:`BaseProcessor.update_settings` → ``_request_reset``),
    so there is no cross-push state to coalesce. An empty ``cmp_configs``
    with an unstructured incoming axis yields a pure auto-grid.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Rebuild ``state.channel_axis`` and both placement masks for ``message``."""
        ch_dim_idx = message.dims.index("ch")
        n_total = message.data.shape[ch_dim_idx]

        # Phase 1a: labels from the incoming axis (defaults fill any gap).
        ch_data = np.zeros(n_total, dtype=CHANNEL_DTYPE)
        for i, label in enumerate(self._incoming_labels(message, n_total)):
            ch_data["label"][i] = label

        # Phase 1b: source geometry. Copy x/y/size/bank/elec/headstage straight
        # from the incoming axis and record which channels were positioned.
        # The auto-grid skips these; a CMP overlay (below) still overrides them.
        src_mask = self._apply_incoming_positions(message, ch_data, n_total)

        # Phase 2: CMP overlays, marked in cmp_mask so the auto-grid skips them.
        cmp_mask = self._apply_cmp_overlays(ch_data, n_total)

        self.state.channel_axis = CoordinateAxis(data=ch_data, dims=["ch"], unit="struct")
        self.state.cmp_mask = cmp_mask
        # CMP wins over source geometry: a CMP-claimed index is "placed" by the
        # overlay, not the source.
        self.state.src_mask = src_mask & ~cmp_mask

        # Phase 3: auto-grid for indices neither a CMP nor the source claimed,
        # offset beyond the placed geometry so they don't overlap.
        self._fill_auto_grid()

    def _apply_cmp_overlays(self, ch_data: np.ndarray, n_total: int) -> np.ndarray:
        """Write every ``settings.cmp_configs`` overlay into ``ch_data`` in order.

        Configs with an empty ``filepath`` are skipped, and files that fail to
        parse are logged and skipped. Entries whose channel index falls outside
        ``[0, n_total)`` are ignored. A later config overwrites an earlier one at
        the same index.

        Returns:
            Bool mask of length ``n_total``; True where some overlay wrote a record.
        """
        cmp_mask = np.zeros(n_total, dtype=bool)
        for cfg in self.settings.cmp_configs:
            if not cfg.filepath:
                continue
            try:
                parsed = parse_cmp(cfg.filepath, start_chan=cfg.start_chan, hs_id=cfg.hs_id)
            except Exception as exc:
                # _reset_state runs on every message via __acall__ until the
                # hash matches; a re-raise would loop forever. Log and skip.
                logger.warning(
                    "ChannelMapProcessor: could not load %r (start_chan=%d, hs_id=%d): %s; skipping.",
                    cfg.filepath,
                    cfg.start_chan,
                    cfg.hs_id,
                    exc,
                )
                continue
            for (bank, term), entry in parsed.items():
                # parse_cmp keys by device (bank, term); start_chan is already
                # folded into bank via its // 32 offset, so the channel index is
                # a direct (bank, term) → row mapping (32 terminals per bank).
                idx = (bank - 1) * 32 + (term - 1)
                if not 0 <= idx < n_total:
                    continue
                row = ch_data[idx]  # structured scalar: a writable view into ch_data
                row["x"] = int(entry.x)
                row["y"] = int(entry.y)
                row["size"] = int(entry.size)
                row["label"] = entry.label  # verbatim (no hs{N}- prefix)
                row["bank"] = chr(ord("A") + bank - 1)
                row["elec"] = term
                row["headstage"] = entry.headstage
                cmp_mask[idx] = True
        return cmp_mask

    @staticmethod
    def _apply_incoming_positions(message: AxisArray, ch_data: np.ndarray, n_total: int) -> np.ndarray:
        """Copy structured geometry from the incoming ``ch`` axis into ``ch_data``.

        Returns a bool ``src_mask`` of channels that carry a usable source
        position: any non-origin coordinate, plus the *first* channel at the
        ``(0, 0)`` origin. Origin pile-ups beyond the first are the device's
        "unmapped" sentinel. They are left ``False`` so the auto-grid claims
        them and overwrites their copied zero coordinates.

        When the incoming axis is unstructured or lacks ``x``/``y`` (e.g. a
        label-only or plain TimeSeries source), nothing is copied and the mask
        is all ``False`` — the original pure auto-grid behavior.
        """
        src_mask = np.zeros(n_total, dtype=bool)
        incoming = getattr(message.axes.get("ch"), "data", None)
        names = getattr(getattr(incoming, "dtype", None), "names", None)
        if incoming is None or not names or not {"x", "y"} <= set(names):
            return src_mask

        n_copy = min(n_total, incoming.shape[0])
        for field in ("x", "y", "size", "bank", "elec", "headstage"):
            if field in names:
                ch_data[field][:n_copy] = incoming[field][:n_copy]

        # Truncate toward zero like int() so the origin test matches a per-row check.
        at_origin = (incoming["x"][:n_copy].astype(np.int64) == 0) & (
            incoming["y"][:n_copy].astype(np.int64) == 0
        )
        src_mask[:n_copy] = ~at_origin
        origin_idx = np.flatnonzero(at_origin)
        if origin_idx.size:
            # A lone origin electrode is a legitimate corner; keep the first one.
            src_mask[origin_idx[0]] = True
        return src_mask

    def _fill_auto_grid(self) -> None:
        """Lay out every channel placed by neither a CMP overlay nor the source.

        Unplaced channels go on a square grid whose step is the placed electrode
        pitch. The grid starts two pitches beyond the largest placed ``y``, with
        ``x`` starting at 0. Each channel gets a bank letter following the
        highest placed bank (32 channels per bank), a 1-based ``elec`` within
        that bank, ``size`` equal to the step, and ``headstage = 0``.
        """
        ch_data = self.state.channel_axis.data
        # "Placed" = positioned by a CMP overlay or the incoming source axis.
        placed_mask = self.state.cmp_mask | self.state.src_mask
        auto_idx = np.flatnonzero(~placed_mask)
        if auto_idx.size == 0:
            return

        # Step matches the placed electrode pitch (≈400 µm) so the auto-grid
        # sits on the same scale as the real geometry. Without this, the µm
        # coordinates would dwarf a unit-spaced auto-grid. Falls back to 1 when
        # nothing is placed (pure auto-grid from the origin).
        step = self._placed_pitch(ch_data, placed_mask)
        if placed_mask.any():
            max_row = int(ch_data["y"][placed_mask].max())
            max_bank_ord = max(
                (ord(str(b)) for b in ch_data["bank"][placed_mask] if str(b)),
                default=ord("A") - 1,
            )
        else:
            # Nothing placed yet — start at the origin with bank A.
            # max_row = -2*step makes start_row = 0 below.
            max_row = -2 * step
            max_bank_ord = ord("A") - 1

        start_row = max_row + 2 * step
        next_bank_ord = max_bank_ord + 1
        grid_size = max(1, math.ceil(math.sqrt(auto_idx.size)))
        for i, idx in enumerate(auto_idx):
            row = ch_data[idx]  # structured scalar: a writable view into ch_data
            row["x"] = (i % grid_size) * step
            row["y"] = start_row + (i // grid_size) * step
            row["size"] = step  # synthetic electrodes sized to the grid pitch
            row["bank"] = chr(next_bank_ord + i // 32)
            row["elec"] = (i % 32) + 1
            row["headstage"] = 0  # auto-grid channels have no headstage

    @staticmethod
    def _placed_pitch(ch_data: np.ndarray, placed_mask: np.ndarray) -> int:
        """Smallest positive spacing among the placed channels' distinct x/y.

        This is the electrode pitch in micrometers (≈400 for a default Utah
        array). Defaults to ``1`` when nothing is placed or the geometry is
        degenerate (a single distinct value on both axes), giving the pure
        auto-grid unit spacing from the origin.
        """
        if not placed_mask.any():
            return 1
        deltas: list[int] = []
        for field in ("x", "y"):
            # np.unique returns sorted distinct values, so every diff is > 0.
            vals = np.unique(ch_data[field][placed_mask])
            if vals.size > 1:
                deltas.append(int(np.diff(vals).min()))
        return min(deltas, default=1)

    @staticmethod
    def _incoming_labels(message: AxisArray, n_total: int) -> list[str]:
        """Return exactly ``n_total`` labels taken from the incoming ``ch`` axis.

        A structured axis contributes its ``label`` field. A structured axis
        *without* ``label`` contributes nothing, because stringifying whole
        records would yield tuple reprs. Plain arrays and sequences are
        stringified element-wise, with ``bytes`` decoded as UTF-8 rather than
        rendered as ``"b'...'"``. Any shortfall is padded with 1-based
        ``ch{N}`` defaults.
        """
        defaults = [f"ch{i + 1}" for i in range(n_total)]
        data = getattr(message.axes.get("ch"), "data", None)
        if data is None:
            return defaults
        names = getattr(getattr(data, "dtype", None), "names", None)
        if names is not None:
            if "label" not in names:
                return defaults
            data = data["label"]
        labels = [
            x.decode("utf-8", errors="replace") if isinstance(x, bytes) else str(x)
            for x in data[:n_total]
        ]
        return labels + defaults[len(labels):]

    def _hash_message(self, message: AxisArray) -> int:
        """State key: only the channel count triggers a data-driven reset."""
        ch_dim_idx = message.dims.index("ch")
        return hash(message.data.shape[ch_dim_idx])

    def _process(self, message: AxisArray) -> AxisArray:
        """Return ``message`` with its ``ch`` axis replaced by the stored map."""
        return replace(message, axes={**message.axes, "ch": self.state.channel_axis})
class ChannelMapUnit(
    BaseTransformerUnit[ChannelMapUnitSettings, AxisArray, AxisArray, ChannelMapProcessor]
):
    """ezmsg Unit wrapping :class:`ChannelMapProcessor`.

    Configured with a :class:`ChannelMapUnitSettings` that holds every CMP
    overlay at once.
    """

    SETTINGS = ChannelMapUnitSettings