"""Self-supervised regression framework and LRR implementation.

This module provides a general framework for self-supervised channel
regression via :class:`SelfSupervisedRegressionTransformer`, and a
concrete implementation — Linear Regression Rereferencing (LRR) — via
:class:`LRRTransformer`.

**Framework.**  The base class accumulates the channel covariance
``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
matrix *W*.  Subclasses define what to *do* with *W* by implementing
:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
:meth:`~SelfSupervisedRegressionTransformer._process`.
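
For example, a minimal subclass sketch (the class name and the use of the
``replace`` helper from ``ezmsg.util.messages.util`` are illustrative, not
part of this module)::

    from ezmsg.util.messages.util import replace

    class ProjectTransformer(
        SelfSupervisedRegressionTransformer[
            SelfSupervisedRegressionSettings, SelfSupervisedRegressionState
        ]
    ):
        def _on_weights_updated(self) -> None:
            # Cache whatever _process needs; here, W itself.
            self._w = self._state.weights

        def _process(self, message: AxisArray) -> AxisArray:
            # Assumes the channel axis is last.
            return replace(message, data=message.data @ self._w)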

**LRR.**  For each channel *c*, predict it from the other channels in its
cluster via ridge regression, then subtract the prediction::

    y = X - X @ W = X @ (I - W)

The effective weight matrix ``I - W`` is passed to
:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
automatically exploits block-diagonal structure when ``channel_clusters``
are provided.
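
A typical offline usage sketch (shapes are illustrative; the positional
settings argument and the callable-transformer pattern mirror how this
module invokes :class:`AffineTransformTransformer` internally)::

    import numpy as np

    lrr = LRRTransformer(LRRSettings(ridge_lambda=1e-3))
    lrr.fit(np.random.default_rng(0).standard_normal((1000, 32)))
    cleaned = lrr(msg)   # msg: AxisArray carrying new (samples, channels) data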

**Fitting.**  Given data matrix *X* of shape ``(samples, channels)``, the
sufficient statistic is the channel covariance ``C = X^T X``.  When
``incremental=True`` (default), *C* is accumulated across
:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
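
For instance, two chunked calls accumulate the same statistic as a single
batch call (sketch; ``msg_a`` and ``msg_b`` are successive AxisArray chunks)::

    lrr.partial_fit(msg_a)   # C  = A^T A
    lrr.partial_fit(msg_b)   # C += B^T B; weights re-solved after each call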

**Solving.**  Within each cluster the weight matrix *W* is obtained from
the inverse of the (ridge-regularised) cluster covariance
``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::

    W[:, c] = -C_inv[:, c] / C_inv[c, c],    diag(W) = 0

This replaces the naive per-channel Cholesky loop with a single matrix
inverse per cluster, keeping the linear algebra in the source array
namespace so that GPU-backed arrays benefit from device-side computation.
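
The identity can be checked against an explicit per-channel ridge solve in
plain numpy (a verification sketch, not part of this module)::

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((1000, 5))
    C, lam = X.T @ X, 0.1
    C_inv = np.linalg.inv(C + lam * np.eye(5))
    W = -(C_inv / np.diag(C_inv))    # column c scaled by 1 / C_inv[c, c]
    np.fill_diagonal(W, 0.0)

    c = 2
    r = [i for i in range(5) if i != c]
    w_ref = np.linalg.solve(C[np.ix_(r, r)] + lam * np.eye(4), C[r, c])
    assert np.allclose(W[r, c], w_ref)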
"""

from __future__ import annotations

import os
import typing
from abc import abstractmethod
from pathlib import Path

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    processor_state,
)
from ezmsg.baseproc.protocols import SettingsType, StateType
from ezmsg.sigproc.affinetransform import (
    AffineTransformSettings,
    AffineTransformTransformer,
)
from ezmsg.sigproc.util.array import array_device, xp_create
from ezmsg.util.messages.axisarray import AxisArray

# ---------------------------------------------------------------------------
# Base: Self-supervised regression
# ---------------------------------------------------------------------------


class SelfSupervisedRegressionSettings(ez.Settings):
    """Settings common to all self-supervised regression modes."""

    weights: np.ndarray | str | Path | None = None
    """Pre-calculated weight matrix *W* or path to a CSV file
    (``np.loadtxt`` compatible). If provided, the transformer is ready
    immediately."""

    axis: str | None = None
    """Channel axis name. ``None`` defaults to the last dimension."""

    channel_clusters: list[list[int]] | None = None
    """Per-cluster regression. ``None`` treats all channels as one cluster."""

    ridge_lambda: float = 0.0
    """Ridge (L2) regularisation parameter."""

    incremental: bool = True
    """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls.
    When ``False``, each call replaces the previous statistics."""


@processor_state
class SelfSupervisedRegressionState:
    cxx: object | None = None  # Array API; namespace matches source data.
    n_samples: int = 0
    weights: object | None = None  # Array API; namespace matches cxx.


class SelfSupervisedRegressionTransformer(
    BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType],
    typing.Generic[SettingsType, StateType],
):
    """Abstract base for self-supervised regression transformers.

    Subclasses must implement:

    * :meth:`_on_weights_updated` — called whenever the weight matrix *W*
      is (re)computed, so the subclass can build whatever internal
      transform it needs (e.g. ``I - W`` for LRR).
    * :meth:`_process` — the per-message transform step.
    """

    # -- message hash / state management ------------------------------------
    def _hash_message(self, message: AxisArray) -> int:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        n_channels = message.data.shape[axis_idx]
        self._validate_clusters(n_channels)

        self._state.cxx = None
        self._state.n_samples = 0
        self._state.weights = None

        # If pre-calculated weights are provided, load and go.
        weights = self.settings.weights
        if weights is not None:
            if isinstance(weights, str):
                weights = Path(os.path.abspath(os.path.expanduser(weights)))
            if isinstance(weights, Path):
                weights = np.loadtxt(weights, delimiter=",")
            weights = np.asarray(weights, dtype=np.float64)
            self._state.weights = weights
            self._on_weights_updated()

    # -- cluster validation --------------------------------------------------
    def _validate_clusters(self, n_channels: int) -> None:
        """Raise if any cluster index is out of range."""
        clusters = self.settings.channel_clusters
        if clusters is None:
            return
        all_indices = np.concatenate([np.asarray(g) for g in clusters])
        if np.any((all_indices < 0) | (all_indices >= n_channels)):
            raise ValueError(
                "channel_clusters contains out-of-range indices "
                f"(valid range: 0..{n_channels - 1})"
            )

    # -- weight solving ------------------------------------------------------
    def _solve_weights(self, cxx):
        """Solve all per-channel ridge regressions via matrix inverse.

        Uses the block-inverse identity: for target channel *c* with
        references *r*, ``w_c = -C_inv[r, c] / C_inv[c, c]`` where
        ``C_inv = (C_cluster + λI)⁻¹``. This replaces the per-channel
        Cholesky loop with one matrix inverse per cluster.

        All computation stays in the source array namespace so that
        GPU-backed arrays benefit from device-side execution. Cluster
        results are scattered into the full matrix via a selection-matrix
        multiply (``S @ W_cluster @ S^T``) to avoid numpy fancy indexing.

        Returns weight matrix *W* in the same namespace as *cxx*, with
        ``diag(W) == 0``.
        """
        xp = get_namespace(cxx)
        dev = array_device(cxx)
        n = cxx.shape[0]
        clusters = self.settings.channel_clusters
        if clusters is None:
            clusters = [list(range(n))]

        W = xp_create(xp.zeros, (n, n), dtype=cxx.dtype, device=dev)
        eye_n = xp_create(xp.eye, n, dtype=cxx.dtype, device=dev)

        for cluster in clusters:
            k = len(cluster)
            if k <= 1:
                continue
            idx_xp = (
                xp.asarray(cluster)
                if dev is None
                else xp.asarray(cluster, device=dev)
            )
            eye_k = xp_create(xp.eye, k, dtype=cxx.dtype, device=dev)

            # Extract cluster sub-covariance (stays on device)
            sub = xp.take(xp.take(cxx, idx_xp, axis=0), idx_xp, axis=1)
            if self.settings.ridge_lambda > 0:
                sub = sub + self.settings.ridge_lambda * eye_k

            # One inverse per cluster
            try:
                sub_inv = xp.linalg.inv(sub)
            except Exception:
                sub_inv = xp.linalg.pinv(sub)

            # Diagonal via element-wise product with identity
            diag_vals = xp.sum(sub_inv * eye_k, axis=0)

            # w_c = -C_inv[:, c] / C_inv[c, c], vectorised over all c
            W_cluster = -(sub_inv / xp.reshape(diag_vals, (1, k)))

            # Zero the diagonal
            W_cluster = W_cluster * (1.0 - eye_k)

            # Scatter into full W
            if k == n:
                W = W + W_cluster
            else:
                # Selection matrix: columns of eye(n) at cluster indices
                S = xp.take(eye_n, idx_xp, axis=1)  # (n, k)
                W = W + xp.matmul(
                    S, xp.matmul(W_cluster, xp.permute_dims(S, (1, 0)))
                )

        return W

    # -- partial_fit (self-supervised, accepts AxisArray) --------------------
    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
        xp = get_namespace(message.data)
        if xp.any(xp.isnan(message.data)):
            return

        # Hash check / state reset
        msg_hash = self._hash_message(message)
        if self._hash != msg_hash:
            self._reset_state(message)
            self._hash = msg_hash

        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        data = message.data

        # Move channel axis to last, flatten to 2-D
        if axis_idx != data.ndim - 1:
            perm = list(range(data.ndim))
            perm.append(perm.pop(axis_idx))
            data = xp.permute_dims(data, perm)
        n_channels = data.shape[-1]
        X = xp.reshape(data, (-1, n_channels))

        # Covariance stays in the source namespace for accumulation.
        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)
        if self.settings.incremental and self._state.cxx is not None:
            self._state.cxx = self._state.cxx + cxx_new
        else:
            self._state.cxx = cxx_new
        self._state.n_samples += int(X.shape[0])

        self._state.weights = self._solve_weights(self._state.cxx)
        self._on_weights_updated()

    # -- convenience APIs ----------------------------------------------------
    def fit(self, X: np.ndarray) -> None:
        """Batch fit from a raw numpy array (samples x channels)."""
        n_channels = X.shape[-1]
        self._validate_clusters(n_channels)
        X = np.asarray(X, dtype=np.float64).reshape(-1, n_channels)
        self._state.cxx = X.T @ X
        self._state.n_samples = X.shape[0]
        self._state.weights = self._solve_weights(self._state.cxx)
        self._on_weights_updated()

    # -- abstract hooks for subclasses ---------------------------------------
    @abstractmethod
    def _on_weights_updated(self) -> None:
        """Called after ``self._state.weights`` has been set/updated.

        Subclasses should build or refresh whatever internal transform
        object they need for :meth:`_process`.
        """
        ...

    @abstractmethod
    def _process(self, message: AxisArray) -> AxisArray: ...


# ---------------------------------------------------------------------------
# Concrete: Linear Regression Rereferencing (LRR)
# ---------------------------------------------------------------------------


class LRRSettings(SelfSupervisedRegressionSettings):
    """Settings for :class:`LRRTransformer`."""

    min_cluster_size: int = 32
    """Passed to :class:`AffineTransformTransformer` for the block-diagonal
    merge threshold."""


@processor_state
class LRRState(SelfSupervisedRegressionState):
    affine: AffineTransformTransformer | None = None


class LRRTransformer(
    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
):
    """Adaptive LRR transformer.

    ``partial_fit`` accepts a plain :class:`AxisArray` (self-supervised),
    and the transform step is delegated to an internal
    :class:`AffineTransformTransformer`.
    """

    # -- state management (clear own state, then delegate to base) ----------
    def _reset_state(self, message: AxisArray) -> None:
        self._state.affine = None
        super()._reset_state(message)

    # -- weights → affine transform -----------------------------------------
    def _on_weights_updated(self) -> None:
        xp = get_namespace(self._state.weights)
        dev = array_device(self._state.weights)
        n = self._state.weights.shape[0]
        effective = (
            xp_create(xp.eye, n, dtype=self._state.weights.dtype, device=dev)
            - self._state.weights
        )

        # Prefer in-place weight update when the affine transformer supports
        # it (avoids a full _reset_state round-trip on every partial_fit).
        if self._state.affine is not None:
            self._state.affine.set_weights(effective)
        else:
            self._state.affine = AffineTransformTransformer(
                AffineTransformSettings(
                    weights=effective,
                    axis=self.settings.axis,
                    channel_clusters=self.settings.channel_clusters,
                    min_cluster_size=self.settings.min_cluster_size,
                )
            )

    # -- transform -----------------------------------------------------------
    def _process(self, message: AxisArray) -> AxisArray:
        if self._state.affine is None:
            raise RuntimeError(
                "LRRTransformer has not been fitted. Call partial_fit() "
                "or provide pre-calculated weights."
            )
        return self._state.affine(message)


class LRRUnit(
    BaseAdaptiveTransformerUnit[
        LRRSettings,
        AxisArray,
        AxisArray,
        LRRTransformer,
    ],
):
    """ezmsg Unit wrapping :class:`LRRTransformer`.

    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
    :class:`AxisArray` for self-supervised training via ``INPUT_SAMPLE``.
    """

    SETTINGS = LRRSettings
    INPUT_SAMPLE = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: AxisArray) -> None:
        await self.processor.apartial_fit(msg)
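

if __name__ == "__main__":
    # Illustrative smoke test only, not exercised by the library. The data
    # shapes and the bare AxisArray construction are assumptions for this
    # demo; the positional-settings and callable-transformer patterns mirror
    # how this module drives AffineTransformTransformer above.
    rng = np.random.default_rng(0)
    common = rng.standard_normal((2000, 1))            # shared component
    x = common + 0.1 * rng.standard_normal((2000, 8))  # 8 correlated channels
    msg = AxisArray(data=x, dims=["time", "ch"])
    lrr = LRRTransformer(LRRSettings(ridge_lambda=1e-6))
    lrr.partial_fit(msg)   # accumulate X^T X and solve W
    out = lrr(msg)         # apply y = X @ (I - W)
    print(f"variance in: {x.var():.4f}  out: {out.data.var():.4f}")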