# Source code for ezmsg.sigproc.scaler

"""Adaptive standard scaling using exponentially weighted moving statistics."""

import typing

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace
from ezmsg.baseproc import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

# Imports for backwards compatibility with previous module location
from .ewma import EWMA_Deprecated as EWMA_Deprecated
from .ewma import EWMASettings, EWMATransformer, _alpha_from_tau
from .ewma import _tau_from_alpha as _tau_from_alpha
from .ewma import ewma_step as ewma_step


class RiverAdaptiveStandardScalerSettings(ez.Settings):
    """Settings for :class:`RiverAdaptiveStandardScalerTransformer`."""

    # Converted to River's ``fading_factor`` via ``_alpha_from_tau`` when the
    # transformer state is (re)built.
    time_constant: float = 1.0
    """Decay constant ``tau`` in seconds."""

    # ``None`` makes the transformer fall back to the first dim of the message.
    axis: str | None = None
    """The name of the axis to accumulate statistics over."""
@processor_state
class RiverAdaptiveStandardScalerState:
    """State held by :class:`RiverAdaptiveStandardScalerTransformer`."""

    # River ``AdaptiveStandardScaler`` instance; assigned in ``_reset_state``.
    scaler: typing.Any = None

    # Name of the axis being iterated over (resolved from settings or message).
    axis: str | None = None

    # Positional index of ``axis`` within the message data.
    axis_idx: int = 0
class RiverAdaptiveStandardScalerTransformer(
    BaseStatefulTransformer[
        RiverAdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        RiverAdaptiveStandardScalerState,
    ]
):
    """
    Apply the adaptive standard scaler from `river <https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/>`_.

    This processes data sample-by-sample using River's online learning implementation.
    For a vectorized EWMA-based alternative, see :class:`AdaptiveStandardScalerTransformer`.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """
        Resolve the scaling axis and build a fresh River scaler.

        Args:
            message: Message whose dims/axes are used to resolve the axis name,
                its index, and its ``gain`` (passed to ``_alpha_from_tau``).
        """
        # Function-scope import: river is only imported when state is built.
        from river import preprocessing

        axis = self.settings.axis
        if axis is None:
            # No axis configured: fall back to the first dimension.
            axis = message.dims[0]
            self._state.axis_idx = 0
        else:
            self._state.axis_idx = message.get_axis_idx(axis)
        self._state.axis = axis

        alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
        self._state.scaler = preprocessing.AdaptiveStandardScaler(fading_factor=alpha)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Scale ``message`` one sample at a time along the configured axis.

        Each sample is flattened into a ``{feature_index: value}`` dict, used to
        update the River scaler (``learn_one``) and then transformed
        (``transform_one``); results are reassembled into the input layout.

        Args:
            message: Input message.

        Returns:
            A copy of ``message`` with the scaled data.
        """
        data = message.data
        axis_idx = self._state.axis_idx
        if axis_idx != 0:
            # Bring the scaling axis to the front so it can be iterated directly.
            data = np.moveaxis(data, axis_idx, 0)

        if data.shape[0] == 0:
            # No samples along the scaling axis: ``np.stack`` would raise on an
            # empty list, so return an empty float array of the same shape and
            # leave the scaler untouched.
            return replace(message, data=np.zeros(message.data.shape, dtype=float))

        scaler = self._state.scaler
        scaled = []
        for sample in data:
            features = dict(enumerate(sample.flatten().tolist()))
            # Update the running statistics first, then scale with them.
            scaler.learn_one(features)
            y = scaler.transform_one(features)
            # Restore feature order by key before reshaping back.
            ordered = [y[key] for key in sorted(y)]
            scaled.append(np.array(ordered).reshape(sample.shape))

        result = np.moveaxis(np.stack(scaled), 0, axis_idx)
        return replace(message, data=result)
class AdaptiveStandardScalerSettings(EWMASettings):
    """Settings for :class:`AdaptiveStandardScalerTransformer`; adds no fields to :class:`EWMASettings`."""
@processor_state
class AdaptiveStandardScalerState:
    """State held by :class:`AdaptiveStandardScalerTransformer`."""

    # EWMA of the raw samples (running mean).
    samps_ewma: EWMATransformer | None = None

    # EWMA of the squared samples (running second moment).
    vars_sq_ewma: EWMATransformer | None = None

    # NOTE(review): not assigned anywhere in the visible code — confirm whether
    # it is still used elsewhere.
    alpha: float | None = None
class AdaptiveStandardScalerTransformer(
    BaseStatefulTransformer[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerState,
    ]
):
    """
    Standardize data using exponentially weighted running statistics.

    Two child :class:`EWMATransformer` instances track the running mean of the
    samples and of the squared samples; the output is
    ``(x - mean) / sqrt(E[x^2] - mean^2)``, with ``0`` wherever the standard
    deviation is not strictly positive.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """
        Create fresh child EWMA transformers from the current settings.

        Args:
            message: The message that triggered the reset (not inspected here).
        """
        self._state.samps_ewma = EWMATransformer(
            time_constant=self.settings.time_constant,
            axis=self.settings.axis,
            accumulate=self.settings.accumulate,
        )
        self._state.vars_sq_ewma = EWMATransformer(
            time_constant=self.settings.time_constant,
            axis=self.settings.axis,
            accumulate=self.settings.accumulate,
        )

    @property
    def accumulate(self) -> bool:
        """Whether to accumulate statistics from incoming samples."""
        return self.settings.accumulate

    @accumulate.setter
    def accumulate(self, value: bool) -> None:
        """
        Set the accumulate mode and propagate to child EWMA transformers.

        Args:
            value: If True, update statistics with each sample.
                If False, only apply current statistics without updating.
        """
        # Keep our own settings in sync: the getter reads them, and
        # ``_reset_state`` builds the children from them. Without this the new
        # value would be invisible to the getter, dropped if set before the
        # state exists, and reverted on the next reset.
        self.settings = replace(self.settings, accumulate=value)

        for child in (self._state.samps_ewma, self._state.vars_sq_ewma):
            if child is not None:
                child.settings = replace(child.settings, accumulate=value)

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Update the running statistics and return the standardized data.

        Args:
            message: Input message.

        Returns:
            A copy of ``message`` whose data is standardized; entries whose
            running standard deviation is not strictly positive are ``0``.
        """
        xp = get_namespace(message.data)

        # Update step (respects accumulate setting via child EWMAs)
        mean_message = self._state.samps_ewma(message)
        var_sq_message = self._state.vars_sq_ewma(replace(message, data=message.data**2))

        # Var[x] = E[x^2] - E[x]^2. Round-off can push this slightly below zero;
        # clamp before the root so no NaN / "invalid value" warning is produced.
        # Clamped entries get std == 0 and are zeroed by the mask below, which
        # is the same output the NaN path produced.
        varis = var_sq_message.data - mean_message.data**2
        varis = xp.where(varis > 0, varis, xp.asarray(0.0, dtype=varis.dtype))
        std = varis**0.5

        # Safe division: substitute 1 for non-positive std, then zero those entries.
        mask = std > 0
        safe_std = xp.where(mask, std, xp.asarray(1.0, dtype=std.dtype))
        result = xp.where(
            mask,
            (message.data - mean_message.data) / safe_std,
            xp.asarray(0.0, dtype=message.data.dtype),
        )
        return replace(message, data=result)
class AdaptiveStandardScaler(
    BaseTransformerUnit[
        AdaptiveStandardScalerSettings,
        AxisArray,
        AxisArray,
        AdaptiveStandardScalerTransformer,
    ]
):
    """ezmsg unit wrapping :class:`AdaptiveStandardScalerTransformer`."""

    SETTINGS = AdaptiveStandardScalerSettings

    # Boolean stream used to toggle statistic accumulation at runtime.
    INPUT_ACCUMULATE = ez.InputStream(bool)

    @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
    async def on_settings(self, msg: AdaptiveStandardScalerSettings) -> None:
        """
        Apply new settings, rebuilding the processor only when ``axis`` changes.

        When ``axis`` is unchanged the existing processor is kept: its
        ``accumulate`` flag is updated (which propagates to the child EWMAs)
        and its ``settings`` reference is replaced with ``msg``.

        Args:
            msg: The new settings.
        """
        # NOTE(review): a changed ``time_constant`` is only stored on the
        # processor's settings here; whether the child EWMAs pick it up is not
        # visible from this block — confirm.
        previous_axis = self.SETTINGS.axis
        self.apply_settings(msg)

        if msg.axis != previous_axis:
            # Structural change: start over with a new processor.
            self.create_processor()
            return

        # Same axis: update the live processor in place.
        self.processor.accumulate = msg.accumulate
        self.processor.settings = msg

    @ez.subscriber(INPUT_ACCUMULATE)
    async def on_accumulate(self, accumulate: bool) -> None:
        """Forward an accumulate toggle to the processor."""
        self.processor.accumulate = accumulate
# Convenience functions to support deprecated generator API
def scaler(time_constant: float = 1.0, axis: str | None = None) -> RiverAdaptiveStandardScalerTransformer:
    """Create a :class:`RiverAdaptiveStandardScalerTransformer` with the given parameters."""
    settings = RiverAdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
    return RiverAdaptiveStandardScalerTransformer(settings=settings)
def scaler_np(time_constant: float = 1.0, axis: str | None = None) -> AdaptiveStandardScalerTransformer:
    """Create an :class:`AdaptiveStandardScalerTransformer` with the given parameters."""
    settings = AdaptiveStandardScalerSettings(time_constant=time_constant, axis=axis)
    return AdaptiveStandardScalerTransformer(settings=settings)