ezmsg.learn.process.ssr

Self-supervised regression framework and LRR implementation.

This module provides a general framework for self-supervised channel regression via SelfSupervisedRegressionTransformer, and a concrete implementation — Linear Regression Rereferencing (LRR) — via LRRTransformer.

Framework. The base class accumulates the channel covariance C = X^T X and solves per-cluster ridge regressions to obtain a weight matrix W. Subclasses define what to do with W by implementing _on_weights_updated() and _process().

LRR. For each channel c, predict it from the other channels in its cluster via ridge regression, then subtract the prediction:

y = X - X @ W = X @ (I - W)
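The following NumPy check confirms the identity above. It assumes only that W has a zero diagonal. A random W stands in for fitted weights, and the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_channels = 1000, 4
X = rng.standard_normal((n_samples, n_channels))

# Stand-in for fitted weights: column c holds the coefficients that
# predict channel c from the other channels, so W[c, c] must be 0.
W = rng.standard_normal((n_channels, n_channels))
np.fill_diagonal(W, 0.0)

prediction = X @ W                        # prediction[:, c] ignores X[:, c]
y_subtract = X - prediction               # "predict, then subtract"
y_affine = X @ (np.eye(n_channels) - W)   # same result as a single linear map
```

Because the diagonal is zero, changing a channel's own values never changes its own prediction. That is what makes the output a true "re-referenced" residual.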

The effective weight matrix I - W is passed to AffineTransformTransformer, which automatically exploits block-diagonal structure when channel_clusters are provided.
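Block-diagonal structure helps because, when W is zero outside the clusters, I - W is too. The product can then be computed cluster by cluster, at a cost that scales with the sum of squared cluster sizes rather than with n_channels squared. This is a sketch of the idea, not AffineTransformTransformer's actual code. `block_apply` is a hypothetical helper that assumes the clusters are disjoint and together cover every channel.

```python
import numpy as np


def block_apply(X, M, clusters):
    """Compute X @ M one cluster at a time.

    Assumes M is block-diagonal with respect to `clusters`, which must be
    disjoint lists of channel indices covering every channel.
    """
    y = np.empty_like(X)
    for idx in clusters:
        idx = np.asarray(idx)
        # For these output columns only the (idx, idx) block of M is non-zero.
        y[:, idx] = X[:, idx] @ M[np.ix_(idx, idx)]
    return y


rng = np.random.default_rng(1)
clusters = [[0, 1, 2], [3, 4]]
n_channels = 5

# Build a cluster-restricted W with a zero diagonal, then M = I - W.
W = np.zeros((n_channels, n_channels))
for idx in clusters:
    block = rng.standard_normal((len(idx), len(idx)))
    np.fill_diagonal(block, 0.0)
    W[np.ix_(idx, idx)] = block
M = np.eye(n_channels) - W

X = rng.standard_normal((200, n_channels))
y_blocks = block_apply(X, M, clusters)
y_full = X @ M
```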

Fitting. Given data matrix X of shape (samples, channels), the sufficient statistic is the channel covariance C = X^T X. When incremental=True (default), C is accumulated across partial_fit() calls.
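Here is a minimal sketch of the accumulation rule. `CovAccumulator` is a hypothetical class whose `cxx` and `n_samples` fields mirror the names on SelfSupervisedRegressionState. X^T X is a sum of per-sample outer products, so summing the chunk-wise statistics reproduces the statistic of the concatenated data.

```python
import numpy as np


class CovAccumulator:
    """Hypothetical accumulator for the sufficient statistic C = X^T X."""

    def __init__(self, incremental=True):
        self.incremental = incremental
        self.cxx = None
        self.n_samples = 0

    def partial_fit(self, X):
        cxx = X.T @ X
        if self.incremental and self.cxx is not None:
            # incremental=True: add this chunk to the running sum.
            self.cxx = self.cxx + cxx
            self.n_samples += X.shape[0]
        else:
            # First call, or incremental=False: replace the statistics.
            self.cxx = cxx
            self.n_samples = X.shape[0]


rng = np.random.default_rng(2)
chunks = [rng.standard_normal((n, 3)) for n in (50, 120, 80)]

acc = CovAccumulator(incremental=True)
for chunk in chunks:
    acc.partial_fit(chunk)

replace = CovAccumulator(incremental=False)
for chunk in chunks:
    replace.partial_fit(chunk)

X_all = np.concatenate(chunks, axis=0)
```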

Solving. Within each cluster the weight matrix W is obtained from the inverse of the (ridge-regularised) cluster covariance C_inv = (C_cluster + lambda * I)^{-1} using the block-inverse identity:

W[:, c] = -C_inv[:, c] / C_inv[c, c],    diag(W) = 0

This replaces the naive per-channel Cholesky loop with a single matrix inverse per cluster, keeping the linear algebra in the source array namespace so that GPU-backed arrays benefit from device-side computation.
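The block-inverse shortcut can be checked against the per-channel ridge regressions it replaces. For column c, regressing channel c on the other channels solves (C[-c,-c] + lambda * I) w = C[-c,c]. The ridge term only touches the diagonal, so the off-diagonal entries of C + lambda * I equal those of C. That is why one inverse of the whole cluster matrix yields every column at once. The sketch below uses plain NumPy rather than a generic array namespace, and it uses `np.linalg.solve` in place of a Cholesky factorisation for the reference loop. The helper names are hypothetical.

```python
import numpy as np


def weights_by_inverse(C, ridge_lambda):
    """All columns of W for one cluster from a single matrix inverse."""
    n = C.shape[0]
    C_inv = np.linalg.inv(C + ridge_lambda * np.eye(n))
    # Broadcasting divides column c by C_inv[c, c].
    W = -C_inv / np.diag(C_inv)
    # The identity yields -1 on the diagonal; a channel never predicts itself.
    np.fill_diagonal(W, 0.0)
    return W


def weights_by_loop(C, ridge_lambda):
    """Reference: one ridge solve per channel."""
    n = C.shape[0]
    W = np.zeros((n, n))
    for c in range(n):
        others = [j for j in range(n) if j != c]
        A = C[np.ix_(others, others)] + ridge_lambda * np.eye(n - 1)
        W[others, c] = np.linalg.solve(A, C[others, c])
    return W


def solve_weights(C, clusters, ridge_lambda):
    """Assemble W cluster by cluster; cross-cluster entries stay zero."""
    W = np.zeros_like(C)
    for idx in clusters:
        block = np.ix_(idx, idx)
        W[block] = weights_by_inverse(C[block], ridge_lambda)
    return W


rng = np.random.default_rng(3)
mixing = rng.standard_normal((6, 6))
X = rng.standard_normal((500, 6)) @ mixing  # correlated channels
C = X.T @ X
clusters = [[0, 1, 2, 3], [4, 5]]

W_fast = weights_by_inverse(C, 10.0)
W_slow = weights_by_loop(C, 10.0)
W_clustered = solve_weights(C, clusters, 10.0)
```

Raising the ridge strength shrinks every column of W toward zero. With a very large lambda the Frobenius norm of W is smaller than with lambda = 0.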

Classes

class LRRSettings(weights=None, axis=None, channel_clusters=None, ridge_lambda=0.0, incremental=True, min_cluster_size=32)

Bases: SelfSupervisedRegressionSettings

Settings for LRRTransformer.

Parameters:
min_cluster_size: int = 32

Passed to AffineTransformTransformer for the block-diagonal merge threshold.

__init__(weights=None, axis=None, channel_clusters=None, ridge_lambda=0.0, incremental=True, min_cluster_size=32)
Return type:

None

class LRRState

Bases: SelfSupervisedRegressionState

affine: AffineTransformTransformer | None = None

class LRRTransformer(*args, **kwargs)

Bases: SelfSupervisedRegressionTransformer[LRRSettings, LRRState]

Adaptive LRR transformer.

partial_fit accepts a plain AxisArray with no separate target, because the regression is self-supervised. The transform step is delegated to an internal AffineTransformTransformer.

class LRRUnit(*args, settings=None, **kwargs)

Bases: BaseAdaptiveTransformerUnit[LRRSettings, AxisArray, AxisArray, LRRTransformer]

ezmsg Unit wrapping LRRTransformer.

Follows the BaseAdaptiveDecompUnit pattern: it accepts AxisArray messages on INPUT_SAMPLE for self-supervised training.

Parameters:

settings (Settings | None)

SETTINGS

alias of LRRSettings

INPUT_SAMPLE = InputStream:unlocated[AxisArray]()

async on_sample(msg)
Return type:

None

Parameters:

msg (AxisArray)

class SelfSupervisedRegressionSettings(weights=None, axis=None, channel_clusters=None, ridge_lambda=0.0, incremental=True)

Bases: Settings

Settings common to all self-supervised regression modes.

Parameters:
weights: ndarray | str | Path | None = None

Pre-calculated weight matrix W or path to a CSV file (np.loadtxt compatible). If provided, the transformer is ready immediately.

axis: str | None = None

Channel axis name. None defaults to the last dimension.

channel_clusters: list[list[int]] | None = None

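Groups of channel indices. A separate regression is solved within each cluster, and weights between channels in different clusters are zero. None treats all channels as one cluster.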

ridge_lambda: float = 0.0

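Ridge (L2) regularisation strength lambda. It is added to the diagonal of each cluster covariance before inversion: (C_cluster + lambda * I).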

incremental: bool = True

When True, accumulate X^T X across partial_fit() calls. When False, each call replaces the previous statistics.
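When `weights` is a path, the file must be something `np.loadtxt` can parse. The sketch below writes a small W with `np.savetxt` and reads it back. It assumes the loader is called with `delimiter=","`. A file written with NumPy's default whitespace delimiter would instead need the default `np.loadtxt(path)` call, so check which form the library uses.

```python
import os
import tempfile

import numpy as np

# Example W for 3 channels, with diag(W) = 0 as in the solver formula.
W = np.array([
    [0.0, 0.4, 0.1],
    [0.3, 0.0, 0.2],
    [0.5, 0.6, 0.0],
])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "lrr_weights.csv")
    # Assumption: comma-delimited file, read with np.loadtxt(path, delimiter=",").
    np.savetxt(path, W, delimiter=",")
    W_loaded = np.loadtxt(path, delimiter=",")
```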

__init__(weights=None, axis=None, channel_clusters=None, ridge_lambda=0.0, incremental=True)
Return type:

None

class SelfSupervisedRegressionState

Bases: object

cxx: object | None = None
n_samples: int = 0
weights: object | None = None

class SelfSupervisedRegressionTransformer(*args, **kwargs)

Bases: BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType], Generic[SettingsType, StateType]

Abstract base for self-supervised regression transformers.

Subclasses must implement:

  • _on_weights_updated() — called whenever the weight matrix W is (re)computed, so the subclass can build whatever internal transform it needs (e.g. I - W for LRR).

  • _process() — the per-message transform step.
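The following is a toy version of this contract. It assumes a single cluster and plain NumPy arrays. `SSRSketch` and `LRRSketch` are hypothetical stand-ins, not the library's classes, and they skip AxisArray handling entirely. The base class owns fitting and calls the hook; the subclass only decides what W turns into.

```python
from abc import ABC, abstractmethod

import numpy as np


class SSRSketch(ABC):
    """Hypothetical base: accumulates C = X^T X, solves W, notifies subclass."""

    def __init__(self, ridge_lambda=0.0):
        self.ridge_lambda = ridge_lambda
        self.cxx = None
        self.weights = None

    def partial_fit(self, X):
        cxx = X.T @ X
        self.cxx = cxx if self.cxx is None else self.cxx + cxx
        n = self.cxx.shape[0]
        C_inv = np.linalg.inv(self.cxx + self.ridge_lambda * np.eye(n))
        W = -C_inv / np.diag(C_inv)  # column c divided by C_inv[c, c]
        np.fill_diagonal(W, 0.0)
        self.weights = W
        self._on_weights_updated()  # subclass hook

    def __call__(self, X):
        return self._process(X)

    @abstractmethod
    def _on_weights_updated(self):
        """Build whatever the subclass needs from self.weights."""

    @abstractmethod
    def _process(self, X):
        """Per-message transform."""


class LRRSketch(SSRSketch):
    def _on_weights_updated(self):
        # LRR precomputes the effective matrix I - W once per fit.
        self._effective = np.eye(self.weights.shape[0]) - self.weights

    def _process(self, X):
        return X @ self._effective


rng = np.random.default_rng(4)
X = rng.standard_normal((500, 4))
X[:, 3] += X[:, 0]  # introduce some cross-channel correlation
model = LRRSketch(ridge_lambda=0.0)
model.partial_fit(X)
y = model(X)
```

With ridge_lambda = 0 and fitting on the same data, each output channel is an ordinary least-squares residual. It is therefore orthogonal to every other input channel, so X^T y has a zero off-diagonal.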

partial_fit(message)
Return type:

None

Parameters:

message (AxisArray)

fit(X)

Batch fit from a raw NumPy array of shape (samples, channels).

Return type:

None

Parameters:

X (ndarray)