# Source code for ezmsg.baseproc.util.pipeline_settings

"""Pipeline-settings event producer and supporting helpers.

This module provides a generic way to capture every settings change happening
inside an ezmsg pipeline and emit each one as a self-contained message that
can travel from unit to unit on the graph.

Three layers:

1. Helpers — ``flatten_ez_settings`` / ``flatten_component_settings`` /
   ``sanitize_settings_value`` / ``sanitize_settings_column_name``. These
   project arbitrary settings objects (``ez.Settings``, ``dict``, dataclass,
   numpy array, enum, path, …) into either dotted-key/value pairs or
   NWB-compatible scalar/array values. Useful to any tabular sink, not just
   NWB.
2. Message — ``PipelineSettingsEvent`` carries one settings change. It mirrors
   the meaningful subset of ``ezmsg.core.SettingsChangedEvent`` (``seq``,
   ``timestamp``, ``component_address``, ``repr_value``, ``structured_value``)
   plus a ``table_name`` for tabular sinks and an ``event_type`` distinguishing
   the startup snapshot from later updates. ``flatten_for_table`` projects to
   the ``{"data": json_str}`` shape an ``NWBPointRow``-compatible sink expects.
3. Producer / Unit — ``PipelineSettingsProducer`` + ``PipelineSettingsUnit``
   open a ``GraphContext``, queue one ``INITIAL`` event per in-scope component
   (using the current settings snapshot), and stream subsequent ``UPDATED``
   events from the graph server's settings subscription.
"""

from __future__ import annotations

import asyncio
import json
import os
import re
import time
import typing
from collections.abc import Mapping, Sequence
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
from typing import Any, Optional

import ezmsg.core as ez
import numpy as np

from ..protocols import processor_state
from ..stateful import BaseStatefulProducer
from ..units import BaseProducerUnit

# ---------------------------------------------------------------------------
# Flatten / sanitize helpers (general-purpose; not NWB-specific)
# ---------------------------------------------------------------------------


def _flatten_value(value: Any, prefix: str) -> dict[str, Any]:
    """Recursively expand *value* into a flat ``{dotted_key: leaf}`` dict.

    ``ez.Settings`` objects are delegated to :func:`flatten_ez_settings`.
    Dataclass instances are converted with ``asdict`` and then handled like
    mappings, whose keys extend *prefix* as ``<prefix>.<key>`` (or just
    ``str(key)`` when *prefix* is empty). Any other value is a leaf stored
    under *prefix*, or under ``"value"`` when *prefix* is empty. An empty
    mapping contributes no keys at all.
    """
    if isinstance(value, ez.Settings):
        return flatten_ez_settings(value, prefix)

    # Only dataclass *instances* are converted; dataclass types are leaves.
    if is_dataclass(value) and not isinstance(value, type):
        value = asdict(value)

    if not isinstance(value, Mapping):
        return {prefix or "value": value}

    flat: dict[str, Any] = {}
    for key, child in value.items():
        child_prefix = f"{prefix}.{key}" if prefix else str(key)
        flat.update(_flatten_value(child, child_prefix))
    return flat


def flatten_ez_settings(settings: ez.Settings, prefix: str = "") -> dict[str, Any]:
    """Flatten an ``ez.Settings`` instance into ``{dotted_key: leaf}`` pairs.

    Each attribute found in ``settings.__dict__`` is expanded with
    :func:`_flatten_value` under ``<root>.<attribute>``. ``<root>`` is
    *prefix* when one is given, otherwise the settings class name. Later
    keys overwrite earlier ones on collision.
    """
    root = prefix or settings.__class__.__name__
    return {
        flat_key: leaf
        for attr_name, attr_value in settings.__dict__.items()
        for flat_key, leaf in _flatten_value(attr_value, f"{root}.{attr_name}").items()
    }


def sanitize_settings_column_name(name: str) -> str:
    """Turn a settings field path into a column name usable by tabular sinks.

    Every run of characters outside ``[0-9A-Za-z_]`` collapses to a single
    ``.``. Leading and trailing underscores are then trimmed. An empty result
    becomes ``"setting"``, and a result that starts with a digit is prefixed
    with ``"setting_"``.
    """
    column = re.sub(r"[^0-9A-Za-z_]+", ".", name).strip("_")
    if column == "":
        return "setting"
    # NOTE(review): only underscores are trimmed, so a name that begins or
    # ends with a separator keeps a leading/trailing "." — confirm intended.
    return f"setting_{column}" if column[0].isdigit() else column


def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback used by the settings sanitizers in this module.

    The sanitizers turn sequences into numpy arrays. Those arrays can end up
    nested inside a mapping or a ragged sequence that is then serialized to
    JSON. With ``default=str`` they would be written as numpy's display
    string (e.g. ``"[1 2 3]"``). That string is not JSON, and numpy
    truncates it with ``...`` once an array exceeds its print threshold.
    Arrays therefore become JSON lists and numpy scalars become Python
    scalars. Anything else still falls back to ``str``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return str(obj)


def _sanitize_sequence_value(value: Sequence[Any]) -> Any:
    """Sanitize each item of *value*, keeping an array shape when possible.

    Returns a numpy array when the sanitized items can be represented with a
    non-``object`` dtype. Otherwise returns a JSON string of the sanitized
    items, with nested arrays encoded as JSON lists.
    """
    sanitized = [sanitize_settings_value(item) for item in value]
    try:
        array_value = np.asarray(sanitized)
    except Exception:
        # e.g. ragged nesting, which recent numpy versions refuse to build.
        return json.dumps(sanitized, default=_json_default)
    if array_value.dtype != object:
        return array_value
    if all(isinstance(item, str) for item in sanitized):
        return np.asarray(sanitized, dtype=str)
    if all(isinstance(item, bytes) for item in sanitized):
        return np.asarray(sanitized, dtype="S")
    try:
        stacked = np.stack(sanitized)
    except Exception:
        return json.dumps(sanitized, default=_json_default)
    if stacked.dtype != object:
        return stacked
    return json.dumps(sanitized, default=_json_default)


def sanitize_settings_value(value: Any) -> Any:
    """Convert a settings value into a tabular-friendly scalar or array value.

    Conversion rules, checked in order:

    * ``None`` -> ``"None"``.
    * ``Enum`` members -> their ``.value``, sanitized recursively.
    * path-like objects -> ``os.fspath(value)``.
    * numpy scalars -> the equivalent Python scalar (``.item()``).
    * numpy arrays -> returned unchanged unless their dtype is ``object``;
      object arrays are sanitized element-wise as a sequence.
    * ``bool``/``int``/``float``/``str``/``bytes`` -> returned unchanged.
    * dataclass instances and mappings -> a JSON object string. Keys are
      coerced with ``str`` and values are sanitized recursively.
    * sets/frozensets (sorted by ``repr``) and other non-string sequences ->
      :func:`_sanitize_sequence_value` (numpy array or JSON string).
    * anything else -> ``json.dumps(value)``, or ``str(value)`` if that
      raises ``TypeError``.
    """
    if value is None:
        return "None"
    if isinstance(value, Enum):
        return sanitize_settings_value(value.value)
    if isinstance(value, os.PathLike):
        return os.fspath(value)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value if value.dtype != object else _sanitize_sequence_value(value.tolist())
    if isinstance(value, (bool, int, float, str, bytes)):
        return value
    if is_dataclass(value) and not isinstance(value, type):
        value = asdict(value)
    if isinstance(value, Mapping):
        # Values may sanitize to numpy arrays; _json_default keeps them as
        # real JSON lists instead of numpy's (possibly truncated) repr.
        return json.dumps(
            {str(k): sanitize_settings_value(v) for k, v in value.items()},
            default=_json_default,
        )
    if isinstance(value, (set, frozenset)):
        # Sets are unordered; sort by repr for a deterministic result.
        return _sanitize_sequence_value(sorted(value, key=repr))
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return _sanitize_sequence_value(value)
    try:
        return json.dumps(value)
    except TypeError:
        return str(value)


def flatten_component_settings(component_address: str, value: Any) -> dict[str, Any]:
    """Flatten one component's settings into sanitized ``{column: value}`` pairs.

    *value* may be:

    * an ``ez.Settings`` instance (flattened with its class name as root), or
    * an object exposing a non-``None`` ``structured_value`` (that is used), or
    * an object whose ``repr_value`` is a ``dict`` (that is used), or
    * any other value accepted by :func:`_flatten_value`.

    Column names are ``<component_address>.<field path>``, passed through
    :func:`sanitize_settings_column_name`. Values go through
    :func:`sanitize_settings_value`.
    """
    if isinstance(value, ez.Settings):
        flat = flatten_ez_settings(value)
    else:
        structured = getattr(value, "structured_value", None)
        if structured is not None:
            flat = _flatten_value(structured, "")
        else:
            repr_value = getattr(value, "repr_value", None)
            source = repr_value if isinstance(repr_value, dict) else value
            flat = _flatten_value(source, "")

    columns: dict[str, Any] = {}
    for field_path, field_value in flat.items():
        column = sanitize_settings_column_name(f"{component_address}.{field_path}")
        columns[column] = sanitize_settings_value(field_value)
    return columns


# ---------------------------------------------------------------------------
# Event message
# ---------------------------------------------------------------------------


class PipelineSettingsEventType(str, Enum):
    """Kind of a :class:`PipelineSettingsEvent`.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    # Row from the startup snapshot (including the end-of-snapshot sentinel).
    INITIAL = "INITIAL"
    # Live settings change observed after startup.
    UPDATED = "UPDATED"


# Default table/series name for emitted events. It is the default of both
# ``PipelineSettingsEvent.table_name`` and
# ``PipelineSettingsProducerSettings.target_table``.
_DEFAULT_TABLE_NAME = "settings_annotations"

INIT_FINAL_COMPONENT_ADDRESS: typing.Final[str] = "__init_final__"
"""Sentinel ``component_address`` marking the end of the startup snapshot.

After queuing the last per-component ``INITIAL`` event, the producer queues
one more :class:`PipelineSettingsEvent` with this address. That event has
``event_type=INITIAL``, ``structured_value=None`` and ``repr_value=""``. It
is only queued when at least one per-component ``INITIAL`` event was.

Some sinks aggregate the per-component startup snapshot, for example a
typed-column sink that wants one merged row instead of N rows plus N-1
schema-driven rotations. Such sinks can watch for this address and flush
their pending buffer when it arrives. Other sinks should treat it as a
control message. :meth:`PipelineSettingsEvent.flatten_for_table` returns
``None`` for it, so JSON-row sinks skip it automatically."""


@dataclass
class PipelineSettingsEvent:
    """One settings-change observation, ready to ship across the graph.

    Mirrors the JSON-friendly portion of ``ezmsg.core.SettingsChangedEvent``
    so the message survives any ezmsg transport without custom encoders.
    Sinks that want a tabular row call :meth:`flatten_for_table`. Sinks that
    need native column types can read ``structured_value`` directly and
    project it with the helpers in this module.
    """

    seq: int
    """Monotonic sequence number from the graph server."""
    timestamp: float
    """``time.monotonic()`` seconds when the change was observed. The graph
    server stamps settings events with ``time.time()``. The producer rebases
    each event onto the local monotonic clock at emit time, so downstream
    consumers can compare these against other monotonic timestamps in the
    pipeline."""
    component_address: str
    """Address of the component whose settings changed."""
    event_type: PipelineSettingsEventType
    """``INITIAL`` for startup-snapshot rows, ``UPDATED`` for live changes."""
    repr_value: typing.Union[dict[str, Any], str]
    """Human-readable settings dump: a ``dict`` for dataclass-shaped settings,
    otherwise a ``repr()`` string."""
    structured_value: Optional[dict[str, Any]] = None
    """Parsed structured settings dump. ``None`` if no structure could be
    derived, in which case consumers fall back to ``repr_value``."""
    source_session_id: Optional[str] = None
    """Session identifier copied from the upstream settings event, or
    ``None`` if no upstream event supplied one."""
    table_name: str = _DEFAULT_TABLE_NAME
    """Name of the target table/series the consuming sink should write into."""

    def flatten_for_table(self) -> Optional[dict[str, Any]]:
        """Project this event into ``NWBPointRow``-compatible columns.

        Returns ``{"data": json_string}`` so the event can be appended to a
        ``pynwb.misc.AnnotationSeries`` (one string per timestamp). The JSON
        payload self-describes the change with ``component``, ``event_type``,
        ``seq``, and the ``settings`` snapshot. ``settings`` is
        ``structured_value`` when present, otherwise ``repr_value``.

        Returns ``None`` for the :data:`INIT_FINAL_COMPONENT_ADDRESS`
        sentinel. That event is a control marker for typed-column sinks and
        has no row to write in a JSON-row sink.
        """
        if self.component_address == INIT_FINAL_COMPONENT_ADDRESS:
            return None
        # repr_value may be a dict or a plain repr() string; either way it is
        # serialized as-is, so a single fallback covers both shapes.
        payload: typing.Union[dict[str, Any], str] = (
            self.structured_value if self.structured_value is not None else self.repr_value
        )
        return {
            "data": json.dumps(
                {
                    "component": self.component_address,
                    "event_type": self.event_type.value,
                    "seq": self.seq,
                    "settings": payload,
                },
                default=str,
            ),
        }


# ---------------------------------------------------------------------------
# Producer
# ---------------------------------------------------------------------------


class PipelineSettingsProducerSettings(ez.Settings):
    """Configuration for :class:`PipelineSettingsProducer` and :class:`PipelineSettingsUnit`."""

    target_table: str = _DEFAULT_TABLE_NAME
    """Copied into the ``table_name`` field of every event the producer emits."""

    scope_components: Optional[tuple[str, ...]] = None
    """Component addresses to report. Events for any other address are dropped.

    ``None`` watches every component the graph server reports. Unless this is
    set explicitly, ``PipelineSettingsUnit`` fills it in on initialize with
    the components of its own session when discovery finds that session."""

    graph_address: Optional[tuple[str, int]] = None
    """``(host, port)`` of the ``GraphServer`` to connect to.

    ``None`` lets ``GraphContext`` resolve the default (``EZMSG_GRAPHSERVER_ADDR``
    env var, falling back to ``127.0.0.1:25978``). Set it explicitly when
    running several ezmsg systems in parallel, or in tests where ``ez.run``
    was called with a non-default ``graph_address``."""


@processor_state
class PipelineSettingsProducerState:
    """Runtime state owned by :class:`PipelineSettingsProducer`."""

    # Open GraphContext, or None while disconnected.
    ctx: Optional[ez.GraphContext] = None
    # Events waiting to be handed out by ``_produce``.
    queue: Optional[asyncio.Queue] = None
    # Background task forwarding live settings events into ``queue``.
    watch_task: Optional[asyncio.Task] = None
    # Most recent settings-event seq the producer has accounted for.
    last_seq: int = 0
    # Component addresses to keep, or None to keep every component.
    scope: Optional[frozenset[str]] = None
    # True once the snapshot was queued and the watch task started.
    initialized: bool = False


def _map_event_type(upstream: typing.Any) -> PipelineSettingsEventType:
    """Map an upstream ``SettingsEventType`` onto :class:`PipelineSettingsEventType`.

    The upstream value's ``name`` attribute is compared, or ``str(upstream)``
    when it has none. Only ``"INITIAL_SETTINGS"`` maps to ``INITIAL``. Every
    other value, including unknown ones, maps to ``UPDATED``.
    """
    label = getattr(upstream, "name", str(upstream))
    is_initial = label == "INITIAL_SETTINGS"
    return PipelineSettingsEventType.INITIAL if is_initial else PipelineSettingsEventType.UPDATED


class PipelineSettingsProducer(
    BaseStatefulProducer[
        PipelineSettingsProducerSettings,
        PipelineSettingsEvent,
        PipelineSettingsProducerState,
    ]
):
    """Stream the graph server's settings events out as :class:`PipelineSettingsEvent`.

    On reset (which happens on first ``__acall__``), the producer:

    1. Opens a ``GraphContext`` against the running graph server.
    2. Pulls the current settings snapshot and queues one ``INITIAL`` event
       per in-scope component. If at least one was queued, it then queues an
       :data:`INIT_FINAL_COMPONENT_ADDRESS` sentinel event.
    3. Subscribes to subsequent settings events and queues each in-scope one.
       Its upstream type is mapped by ``_map_event_type``
       (``INITIAL_SETTINGS`` -> ``INITIAL``, anything else -> ``UPDATED``).

    ``_produce`` returns one event per call, awaiting the queue when it is
    empty. If the GraphContext can't be opened (no server running, network
    issue), the producer logs a warning and the queue stays empty. The unit
    then simply never publishes anything, matching the "best-effort"
    semantics of the previous in-sink approach.
    """

    def _reset_state(self) -> None:
        # Producer needs an event loop to open the GraphContext; the real
        # work lives in _areset_state, which the framework calls from the
        # async __acall__ path.
        raise NotImplementedError("PipelineSettingsProducer requires async setup; call via __acall__.")

    async def _areset_state(self) -> None:
        """(Re)connect to the graph server and prime the event queue.

        Any previous connection or watch task is torn down first. If the
        ``GraphContext`` cannot be opened, the fresh empty queue is left in
        place. If seeding fails, everything is torn down again (queue
        included), so ``_produce`` blocks indefinitely.
        """
        # Tear down anything left over from a prior incarnation (settings-
        # driven reset).
        await self._teardown()
        self._state.queue = asyncio.Queue()
        self._state.last_seq = 0
        self._state.scope = (
            frozenset(self.settings.scope_components) if self.settings.scope_components is not None else None
        )
        self._state.initialized = False

        try:
            ctx = ez.GraphContext(self.settings.graph_address, auto_start=False)
            await ctx.__aenter__()
        except Exception as exc:
            ez.logger.warning(f"PipelineSettingsProducer could not open GraphContext: {exc}")
            return
        self._state.ctx = ctx

        try:
            await self._seed_initial_events()
            # The live stream resumes right after the last seq accounted for
            # by the seeding step.
            self._state.watch_task = asyncio.create_task(
                self._watch(after_seq=self._state.last_seq),
                name="pipeline-settings-watch",
            )
            self._state.initialized = True
        except Exception as exc:
            ez.logger.warning(f"PipelineSettingsProducer initial snapshot failed: {exc}")
            await self._teardown()

    async def _seed_initial_events(self) -> None:
        """Fill the queue with one ``INITIAL`` event per in-scope component.

        After the per-component events, queue one final sentinel event
        (``component_address=INIT_FINAL_COMPONENT_ADDRESS``,
        ``structured_value=None``). This gives consumers that aggregate the
        startup snapshot a clean batch boundary. JSON-row sinks skip the
        sentinel because :meth:`PipelineSettingsEvent.flatten_for_table`
        returns ``None`` for it.
        """
        ctx = self._state.ctx
        if ctx is None or self._state.queue is None:
            return

        # Read the event log *before* the snapshot. ``last_seq``, and hence
        # where ``_watch`` resumes, is derived from this log. Reading the
        # snapshot first would let a change that lands between the two calls
        # advance ``last_seq`` without appearing in the snapshot, so it would
        # be silently dropped. In this order, that change is instead delivered
        # by the live subscription (assuming ``subscribe_settings_events``
        # yields every event with ``seq > after_seq``).
        seed_events = await ctx.settings_events(after_seq=0)
        settings_snapshot = await ctx.settings_snapshot()

        # For each component in the snapshot, find the most recent seed event
        # so the INITIAL row carries an accurate seq + timestamp.
        latest_per_addr: dict[str, typing.Any] = {}
        for ev in seed_events:
            cur = latest_per_addr.get(ev.component_address)
            if cur is None or ev.seq > cur.seq:
                latest_per_addr[ev.component_address] = ev

        scope = self._state.scope
        emitted_any = False
        max_initial_ts = 0.0
        for addr in sorted(settings_snapshot.keys()):
            if scope is not None and addr not in scope:
                continue
            value = settings_snapshot[addr]
            latest = latest_per_addr.get(addr)
            seq = latest.seq if latest is not None else 0
            # Wall-clock, like the server's own stamps; _produce rebases it.
            ts = latest.timestamp if latest is not None else time.time()
            session_id = latest.source_session_id if latest is not None else None
            self._state.queue.put_nowait(
                PipelineSettingsEvent(
                    seq=seq,
                    timestamp=ts,
                    component_address=addr,
                    event_type=PipelineSettingsEventType.INITIAL,
                    repr_value=value.repr_value,
                    structured_value=value.structured_value,
                    source_session_id=session_id,
                    table_name=self.settings.target_table,
                )
            )
            self._state.last_seq = max(self._state.last_seq, seq)
            max_initial_ts = max(max_initial_ts, ts)
            emitted_any = True

        # Advance past every logged event, including out-of-scope ones.
        self._state.last_seq = max(
            self._state.last_seq,
            max((ev.seq for ev in seed_events), default=0),
        )

        if emitted_any:
            self._state.queue.put_nowait(
                PipelineSettingsEvent(
                    seq=self._state.last_seq,
                    timestamp=max_initial_ts or time.time(),
                    component_address=INIT_FINAL_COMPONENT_ADDRESS,
                    event_type=PipelineSettingsEventType.INITIAL,
                    repr_value="",
                    structured_value=None,
                    table_name=self.settings.target_table,
                )
            )

    async def _watch(self, after_seq: int) -> None:
        """Forward live settings events into the queue until cancelled."""
        ctx = self._state.ctx
        queue = self._state.queue
        if ctx is None or queue is None:
            return
        scope = self._state.scope
        try:
            async for event in ctx.subscribe_settings_events(after_seq=after_seq):
                if scope is not None and event.component_address not in scope:
                    continue
                await queue.put(
                    PipelineSettingsEvent(
                        seq=event.seq,
                        timestamp=event.timestamp,
                        component_address=event.component_address,
                        event_type=_map_event_type(event.event_type),
                        repr_value=event.value.repr_value,
                        structured_value=event.value.structured_value,
                        source_session_id=event.source_session_id,
                        table_name=self.settings.target_table,
                    )
                )
                self._state.last_seq = event.seq
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            # Best-effort: a broken subscription ends the live stream and is
            # not retried; events already queued are still handed out.
            ez.logger.warning(f"PipelineSettingsProducer watch terminated: {exc}")

    async def _produce(self) -> PipelineSettingsEvent:
        # If reset failed to bring up a queue, block indefinitely so the unit
        # doesn't spin emitting Nones. The framework will cancel us on
        # shutdown.
        if self._state.queue is None:
            await asyncio.Event().wait()
            raise RuntimeError("unreachable")
        event = await self._state.queue.get()
        # Upstream stamps with time.time(); rebase onto the local monotonic
        # clock here so consumers see a monotonic timestamp. Sampled per
        # event — accurate enough given the low rate of settings changes.
        event.timestamp = event.timestamp - time.time() + time.monotonic()
        return event

    async def _teardown(self) -> None:
        """Cancel the watch task, close the GraphContext, and drop the queue."""
        task = self._state.watch_task
        self._state.watch_task = None
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                # NOTE(review): this also swallows a cancellation of the
                # *calling* task that arrives while awaiting the watch task.
                # Consider re-raising when current_task().cancelling() is set
                # (Python 3.11+).
                pass
        ctx = self._state.ctx
        self._state.ctx = None
        if ctx is not None:
            try:
                await ctx.__aexit__(None, None, None)
            except Exception:
                pass
        self._state.queue = None
        self._state.initialized = False


# ---------------------------------------------------------------------------
# Unit
# ---------------------------------------------------------------------------


class PipelineSettingsUnit(
    BaseProducerUnit[
        PipelineSettingsProducerSettings,
        PipelineSettingsEvent,
        PipelineSettingsProducer,
    ]
):
    """Producer unit publishing one :class:`PipelineSettingsEvent` per settings change.

    Connect its ``OUTPUT_SIGNAL`` to a sink that consumes
    :class:`PipelineSettingsEvent`. More loosely, any sink that accepts
    objects implementing the ``NWBPointRow`` protocol also works, since these
    events satisfy it through :meth:`PipelineSettingsEvent.flatten_for_table`.

    When ``scope_components`` is left unset, ``initialize`` finds the session
    this unit belongs to in the graph snapshot. It then restricts the
    producer to that session's components, so parallel pipelines sharing a
    graph server don't report each other's settings. If discovery fails or
    finds nothing, every component is watched.
    """

    SETTINGS = PipelineSettingsProducerSettings

    async def initialize(self) -> None:
        # Auto-scope only when the user didn't pre-supply a scope. The new
        # settings must be applied before the producer's reset runs.
        if self.SETTINGS.scope_components is None:
            try:
                discovered = await self._discover_session_scope()
            except Exception as exc:
                ez.logger.warning(f"{self.address} could not discover session scope; will watch all components: {exc}")
                discovered = None
            if discovered:
                self.apply_settings(
                    PipelineSettingsProducerSettings(
                        target_table=self.SETTINGS.target_table,
                        scope_components=tuple(sorted(discovered)),
                        graph_address=self.SETTINGS.graph_address,
                    )
                )
        await super().initialize()

    async def shutdown(self) -> None:
        producer = getattr(self, "producer", None)
        if producer is not None:
            try:
                await producer._teardown()
            except Exception:
                # Best-effort: always continue to the base-class shutdown.
                pass
        await super().shutdown()

    async def _discover_session_scope(self) -> Optional[set[str]]:
        """Return the component addresses of the session that contains this unit.

        Returns ``None`` when no session's metadata lists ``self.address``.
        The temporary ``GraphContext`` is always closed before returning.
        """
        ctx = ez.GraphContext(self.SETTINGS.graph_address, auto_start=False)
        await ctx.__aenter__()
        try:
            snapshot = await ctx.snapshot()
            for session in snapshot.sessions.values():
                meta = session.metadata
                if meta is not None and self.address in meta.components:
                    return set(meta.components.keys())
            return None
        finally:
            await ctx.__aexit__(None, None, None)