# Source code for ezmsg.sigproc.rollingscaler

from collections import deque

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from ezmsg.sigproc.sampler import SampleMessage


class RollingScalerSettings(ez.Settings):
    """Settings for :obj:`RollingScalerProcessor` / :obj:`RollingScalerUnit`."""

    axis: str = "time"
    """Axis along which samples are arranged."""

    k_samples: int | None = 20
    """Rolling window size in number of samples."""

    window_size: float | None = None
    """Rolling window size in seconds. If set, overrides `k_samples`.
    `update_with_signal` likely should be True if using this option."""

    update_with_signal: bool = False
    """If True, update rolling statistics using the incoming process stream."""

    min_samples: int = 1
    """Minimum number of samples required to compute statistics.
    Used when `window_size` is not set."""

    min_seconds: float = 1.0
    """Minimum duration in seconds required to compute statistics.
    Used when `window_size` is set."""

    artifact_z_thresh: float | None = None
    """Threshold for z-score based artifact detection. If set, samples with any
    channel exceeding this z-score will be excluded from updating the rolling
    statistics."""

    clip: float | None = 10.0
    """If set, clip the output values to the range [-clip, clip]."""
@processor_state
class RollingScalerState:
    # Running per-channel mean of the samples currently in the window.
    mean: npt.NDArray | None = None
    # Total number of samples represented by the current statistics.
    N: int = 0
    # Running per-channel sum of squared deviations from the mean
    # (the "M2" accumulator of Welford-style variance algorithms).
    M2: npt.NDArray | None = None
    # Deque of per-batch (n, mean, M2) tuples; oldest batch is evicted
    # when the window is full. maxlen is set from k_samples.
    samples: deque | None = None
    # Effective window length in batches, resolved from the settings
    # (window_size / gain when window_size is set, else settings.k_samples).
    k_samples: int | None = None
    # Effective minimum sample count before scaling is applied, resolved
    # from the settings (min_seconds / gain when window_size is set).
    min_samples: int | None = None
class RollingScalerProcessor(
    BaseAdaptiveTransformer[RollingScalerSettings, AxisArray, AxisArray, RollingScalerState]
):
    """
    Processor for rolling z-score normalization of input `AxisArray` messages.

    The processor maintains rolling statistics (mean and variance) over the last
    `k_samples` samples received via the `partial_fit()` method. When processing an
    `AxisArray` message, it normalizes the data using the current rolling statistics.

    The input `AxisArray` messages are expected to have shape `(time, ch)`, where
    `ch` is the channel axis. The processor computes the z-score for each channel
    independently.

    Note: You should consider instead using the AdaptiveStandardScalerTransformer
    which is computationally more efficient and uses less memory. This
    RollingScalerProcessor is primarily provided to reproduce processing in the
    literature.

    Settings:
    ---------
    k_samples: int
        Number of previous samples to use for rolling statistics.

    Example:
    -----------------------------
    ```python
    processor = RollingScalerProcessor(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message properties that, when changed, require a state reset:
        stream key, per-sample shape (all dims except the scaling axis), and axis gain.
        """
        axis = message.dims[0] if self.settings.axis is None else self.settings.axis
        # Axes without a gain attribute (e.g. coordinate axes) hash as gain=1.
        gain = message.axes[axis].gain if hasattr(message.axes[axis], "gain") else 1
        axis_idx = message.get_axis_idx(axis)
        samp_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :]
        return hash((message.key, samp_shape, gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Re-initialize rolling statistics for a (new) stream configuration.

        Resolves the effective window length (`k_samples`) and minimum sample
        count (`min_samples`) from the settings, converting seconds to sample
        counts via the axis gain when `window_size` is set.
        """
        # NOTE(review): channel count is taken from the LAST dim here, while
        # _hash_message excludes the configured axis wherever it is — assumes
        # the scaling axis is dim 0 and channels are dim -1; confirm upstream.
        ch = message.data.shape[-1]
        self._state.mean = np.zeros(ch)
        self._state.N = 0
        self._state.M2 = np.zeros(ch)
        self._state.k_samples = (
            int(np.ceil(self.settings.window_size / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.k_samples
        )
        if self._state.k_samples is not None and self._state.k_samples < 1:
            ez.logger.warning("window_size smaller than sample gain; setting k_samples to 1.")
            self._state.k_samples = 1
        elif self._state.k_samples is None:
            # deque(maxlen=None) below grows without bound in this case.
            ez.logger.warning("k_samples is None; z-score accumulation will be unbounded.")
        self._state.samples = deque(maxlen=self._state.k_samples)
        self._state.min_samples = (
            int(np.ceil(self.settings.min_seconds / message.axes[self.settings.axis].gain))
            if self.settings.window_size is not None
            else self.settings.min_samples
        )
        if self._state.k_samples is not None and self._state.min_samples > self._state.k_samples:
            ez.logger.warning("min_samples is greater than k_samples; adjusting min_samples to k_samples.")
            self._state.min_samples = self._state.k_samples

    def _add_batch_stats(self, x: npt.NDArray) -> None:
        """Fold a batch of samples (rows of `x`) into the rolling mean/M2.

        Uses the parallel-variance combination of Chan et al.: each batch is
        stored as an (n, mean, M2) summary; when the window is full, the oldest
        batch's contribution is removed by running the merge in reverse before
        the new batch is merged in.
        """
        x = np.asarray(x, dtype=np.float64)
        n_b = x.shape[0]
        mean_b = np.mean(x, axis=0)
        M2_b = np.sum((x - mean_b) ** 2, axis=0)
        # Evict the oldest batch when the window is at capacity.
        if self._state.k_samples is not None and len(self._state.samples) == self._state.k_samples:
            n_old, mean_old, M2_old = self._state.samples.popleft()
            N_T = self._state.N
            N_new = N_T - n_old
            if N_new <= 0:
                # Removing the old batch would empty (or over-empty) the
                # accumulator — reset to a clean zero state instead.
                self._state.N = 0
                self._state.mean = np.zeros_like(self._state.mean)
                self._state.M2 = np.zeros_like(self._state.M2)
            else:
                # Reverse of the pairwise merge: subtract the old batch's
                # mean and M2 contribution from the combined statistics.
                delta = mean_old - self._state.mean
                self._state.N = N_new
                self._state.mean = (N_T * self._state.mean - n_old * mean_old) / N_new
                self._state.M2 = self._state.M2 - M2_old - (delta * delta) * (N_T * n_old / N_new)
        # Forward pairwise merge of the new batch into the accumulator.
        N_A = self._state.N
        N = N_A + n_b
        delta = mean_b - self._state.mean
        self._state.mean = self._state.mean + delta * (n_b / N)
        self._state.M2 = self._state.M2 + M2_b + (delta * delta) * (N_A * n_b / N)
        self._state.N = N
        # Keep the batch summary so it can be evicted later.
        self._state.samples.append((n_b, mean_b, M2_b))

    def partial_fit(self, message: SampleMessage) -> None:
        """Update rolling statistics from a training `SampleMessage` without
        producing an output message."""
        x = message.sample.data
        self._add_batch_stats(x)

    def _process(self, message: AxisArray) -> AxisArray:
        """Z-score `message.data` with the current rolling statistics.

        While fewer than `min_samples` samples have been accumulated, the
        message is passed through unmodified (optionally still updating the
        statistics when `update_with_signal` is True).
        """
        if self._state.N == 0 or self._state.N < self._state.min_samples:
            # Warm-up: not enough data to scale yet.
            if self.settings.update_with_signal:
                x = message.data
                if self.settings.artifact_z_thresh is not None and self._state.N > 0:
                    varis = self._state.M2 / self._state.N
                    # Floor the std to avoid division blow-up on flat channels.
                    std = np.maximum(np.sqrt(varis), 1e-8)
                    z = np.abs((x - self._state.mean) / std)
                    # Drop whole samples (rows) where ANY channel exceeds the
                    # threshold. NOTE(review): axis=1 assumes 2-D (time, ch) data.
                    mask = np.any(z > self.settings.artifact_z_thresh, axis=1)
                    x = x[~mask]
                if x.size > 0:
                    self._add_batch_stats(x)
            return message
        varis = self._state.M2 / self._state.N
        std = np.maximum(np.sqrt(varis), 1e-8)
        with np.errstate(divide="ignore", invalid="ignore"):
            result = (message.data - self._state.mean) / std
        # Any NaN/inf from degenerate statistics becomes 0.
        result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
        if self.settings.clip is not None:
            result = np.clip(result, -self.settings.clip, self.settings.clip)
        if self.settings.update_with_signal:
            x = message.data
            if self.settings.artifact_z_thresh is not None:
                # Exclude artifact samples from the statistics update; the
                # scaled output above still includes them (clipped).
                z_scores = np.abs((x - self._state.mean) / std)
                mask = np.any(z_scores > self.settings.artifact_z_thresh, axis=1)
                x = x[~mask]
            if x.size > 0:
                self._add_batch_stats(x)
        return replace(message, data=result)
class RollingScalerUnit(
    BaseAdaptiveTransformerUnit[
        RollingScalerSettings,
        AxisArray,
        AxisArray,
        RollingScalerProcessor,
    ]
):
    """
    Unit wrapper for :obj:`RollingScalerProcessor`.

    This unit performs rolling z-score normalization on incoming `AxisArray`
    messages. The unit maintains rolling statistics (mean and variance) over the
    last `k_samples` samples received. When processing an `AxisArray` message, it
    normalizes the data using the current rolling statistics.

    Example:
    -----------------------------
    ```python
    unit = RollingScalerUnit(
        settings=RollingScalerSettings(
            k_samples=20  # Number of previous samples to use for rolling statistics
        )
    )
    ```
    """

    SETTINGS = RollingScalerSettings