Source code for ezmsg.sigproc.ewma

"""Exponentially weighted moving average (EWMA) utilities and parameter conversion."""

import functools
from dataclasses import field

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import scipy.signal as sps
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace


def _ewma_mlx_metal_xp(data, axis_idx: int, zi, alpha: float):
    """Apply the MLX Metal EWMA helper along ``axis_idx``, keeping the scipy zi layout.

    The filtered axis of both ``data`` and ``zi`` is moved to the last position
    before ``ewma_mlx_metal`` is called, and moved back afterwards, so the
    returned arrays have the same axis order as the inputs.

    Args:
        data: Input array (handed to ``mx.moveaxis``, so presumably an MLX array).
        axis_idx: Index of the axis to filter along.
        zi: Initial filter state; converted to an MLX array with ``data``'s dtype.
        alpha: Fading factor forwarded to the helper.

    Returns:
        Tuple ``(y, zf)``: the helper's output and final state, axes restored.
    """
    import mlx.core as mx

    from .util.ewma_mlx_metal import MAX_CHUNK_SIZE, ewma_mlx_metal

    state = mx.asarray(zi, dtype=data.dtype)
    data_tail = data.ndim - 1
    state_tail = state.ndim - 1

    # Only pay for a moveaxis when the filtered axis is not already last.
    samples = data if axis_idx == data_tail else mx.moveaxis(data, axis_idx, data_tail)
    state_in = state if axis_idx == state_tail else mx.moveaxis(state, axis_idx, state_tail)

    filtered, state_out = ewma_mlx_metal(
        samples,
        alpha,
        state_in,
        chunk_size=min(samples.shape[-1], MAX_CHUNK_SIZE),
    )

    if axis_idx != data_tail:
        filtered = mx.moveaxis(filtered, data_tail, axis_idx)
    if axis_idx != state_tail:
        state_out = mx.moveaxis(state_out, state_tail, axis_idx)
    return filtered, state_out


def _tau_from_alpha(alpha: float, dt: float) -> float:
    """
    Inverse of _alpha_from_tau: convert a fading factor into a time constant.

    :param alpha: the "fading factor" in exponential smoothing.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: tau satisfying ``alpha == 1 - exp(-dt / tau)``, in the same units as dt.
    """
    # log1p(-alpha) is log(1 - alpha) without the subtraction that rounds
    # 1 - alpha to exactly 1.0 for very small alpha (which made tau infinite).
    return -dt / np.log1p(-alpha)


def _alpha_from_tau(tau: float, dt: float) -> float:
    """
    Convert a time constant into the fading factor of exponential smoothing.

    See https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant

    :param tau: The amount of time for the smoothed response of a unit step function to reach
        1 - 1/e approx-eq 63.2%.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: alpha, the "fading factor" in exponential smoothing.
    """
    # -expm1(x) is 1 - exp(x) without the cancellation that rounds the result
    # to exactly 0.0 when dt / tau is very small.
    return -np.expm1(-dt / tau)


[docs] def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None): """ Do an exponentially weighted moving average step. Args: sample: The new sample. zi: The output of the previous step. alpha: Fading factor. beta: Persisting factor. If None, it is calculated as 1-alpha. Returns: alpha * sample + beta * zi """ # Potential micro-optimization: # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition # return zi + alpha * (new_sample - zi) beta = beta or (1 - alpha) return alpha * sample + beta * zi
[docs] class EWMA_Deprecated: """ Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic, but they ended up being slower than the scipy.signal.lfilter method. Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases and beta**n approaches zero. """
[docs] def __init__(self, alpha: float, max_len: int): self.alpha = alpha self.beta = 1 - alpha self.prev: npt.NDArray | None = None self.weights = np.empty((max_len + 1,), float) self._precalc_weights(max_len) self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
def _precalc_weights(self, n: int): # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n np.power(self.beta, np.arange(n + 1), out=self.weights)
[docs] def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray: if out is None: out = np.empty(arr.shape, arr.dtype) n = arr.shape[0] weights = self.weights[:n] weights = np.expand_dims(weights, list(range(1, arr.ndim))) # α*P0, α*P1, α*P2, ..., α*Pn np.multiply(self.alpha, arr, out) # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n np.divide(out, weights, out) # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ... np.cumsum(out, axis=0, out=out) # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ... np.multiply(out, weights, out) # Add the previous output if self.prev is None: self.prev = arr[:1] out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim))) self.prev = out[-1:] return out
[docs] def compute2(self, arr: npt.NDArray) -> npt.NDArray: """ Compute the Exponentially Weighted Moving Average (EWMA) of the input array. Args: arr: The input array to be smoothed. Returns: The smoothed array. """ n = arr.shape[0] if n > len(self.weights): self._precalc_weights(n) weights = self.weights[:n][::-1] weights = np.expand_dims(weights, list(range(1, arr.ndim))) result = np.cumsum(self.alpha * weights * arr, axis=0) result = result / weights # Handle the first call when prev is unset if self.prev is None: self.prev = arr[:1] result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim))) # Store the result back into prev self.prev = result[-1] return result
[docs] def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray: if self.prev is None: self.prev = new_sample self.prev = self._step_func(new_sample, self.prev) return self.prev
class EWMASettings(ez.Settings):
    """Settings consumed by EWMATransformer and EWMAUnit."""

    time_constant: float = 1.0
    """The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""

    # Name of the axis to filter along; when None the transformer falls back
    # to the first dimension of the incoming message.
    axis: str | None = None

    accumulate: bool = True
    """If True, update the EWMA state with each sample. If False, only apply the current EWMA estimate without updating state (useful for inference periods where you don't want to adapt statistics)."""
@processor_state
class EWMAState:
    """State that EWMATransformer keeps between messages."""

    # Fading factor; overwritten in `_reset_state` from the configured time
    # constant and the gain of the filtered axis.
    alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0))

    # Filter state passed as `zi` to the filter call and replaced by its final
    # state when accumulating; None until the first reset.
    zi: npt.NDArray | None = None
class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]):
    """Apply an exponentially weighted moving average along one axis of an AxisArray.

    Empty messages are passed through untouched. For MLX arrays the work is
    delegated to `_ewma_mlx_metal_xp`; everything else goes through
    `scipy.signal.lfilter`.
    """

    # `accumulate` is read live in `_process` to gate state updates; other
    # fields are cached into state (alpha, zi) during `_reset_state`.
    NONRESET_SETTINGS_FIELDS = frozenset({"accumulate"})

    def __call__(self, message: AxisArray) -> AxisArray:
        """Process ``message`` unless it holds no elements, in which case return it as-is."""
        has_no_elements = np.prod(message.data.shape) == 0
        return message if has_no_elements else super().__call__(message)

    async def __acall__(self, message: AxisArray) -> AxisArray:
        """Async counterpart of ``__call__`` with the same empty-message shortcut."""
        has_no_elements = np.prod(message.data.shape) == 0
        return message if has_no_elements else await super().__acall__(message)

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the per-sample shape, the filtered axis' gain and the message key."""
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        full_shape = message.data.shape
        # Drop the filtered axis so chunk length does not affect the hash.
        sample_shape = full_shape[:axis_idx] + full_shape[axis_idx + 1 :]
        return hash((sample_shape, message.axes[axis].gain, message.key))

    def _reset_state(self, message: AxisArray) -> None:
        """Derive alpha from the time constant and seed zi from the first sample."""
        axis = self.settings.axis or message.dims[0]
        state = self._state
        state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain)
        first_sample = slice_along_axis(
            message.data,
            slice(None, 1, None),
            axis=message.get_axis_idx(axis),
        )
        # Seed with (1 - alpha) * x[0] rather than zeros.
        state.zi = (1 - state.alpha) * first_sample

    def _process(self, message: AxisArray) -> AxisArray:
        """Filter ``message`` and, when ``accumulate`` is set, keep the final filter state."""
        axis = self.settings.axis or message.dims[0]
        axis_idx = message.get_axis_idx(axis)
        xp = np if is_numpy_array(message.data) else get_namespace(message.data)

        if xp is not np and xp.__name__ == "mlx.core":
            smoothed, zf = _ewma_mlx_metal_xp(message.data, axis_idx, self._state.zi, self._state.alpha)
            if self.settings.accumulate:
                self._state.zi = zf
            return replace(message, data=smoothed)

        # accumulate=True: normal behavior, state follows the new samples.
        # accumulate=False: process-only, output is computed but state is kept.
        update_state = self.settings.accumulate

        # lfilter needs a numpy zi; convert once and keep the converted copy in state.
        zi = self._state.zi
        if zi is not None and not is_numpy_array(zi):
            self._state.zi = np.asarray(zi)

        alpha = self._state.alpha
        smoothed, zf = sps.lfilter(
            [alpha],
            [1.0, alpha - 1.0],
            message.data,
            axis=axis_idx,
            zi=self._state.zi,
        )
        if update_state:
            self._state.zi = zf
        return replace(message, data=smoothed)
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
    """Transformer unit built on EWMATransformer and configured through EWMASettings."""

    SETTINGS = EWMASettings