Source code for ezmsg.sigproc.ewma

import functools
from dataclasses import field

import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import scipy.signal as sps
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace


def _tau_from_alpha(alpha: float, dt: float) -> float:
    """
    Inverse of _alpha_from_tau. See that function for explanation.
    """
    return -dt / np.log(1 - alpha)


def _alpha_from_tau(tau: float, dt: float) -> float:
    """
    # https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
    :param tau: The amount of time for the smoothed response of a unit step function to reach
        1 - 1/e approx-eq 63.2%.
    :param dt: sampling period, or 1 / sampling_rate.
    :return: alpha, the "fading factor" in exponential smoothing.
    """
    return 1 - np.exp(-dt / tau)


[docs] def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None): """ Do an exponentially weighted moving average step. Args: sample: The new sample. zi: The output of the previous step. alpha: Fading factor. beta: Persisting factor. If None, it is calculated as 1-alpha. Returns: alpha * sample + beta * zi """ # Potential micro-optimization: # Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition # Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition # return zi + alpha * (new_sample - zi) beta = beta or (1 - alpha) return alpha * sample + beta * zi
[docs] class EWMA_Deprecated: """ Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic, but they ended up being slower than the scipy.signal.lfilter method. Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases and beta**n approaches zero. """
[docs] def __init__(self, alpha: float, max_len: int): self.alpha = alpha self.beta = 1 - alpha self.prev: npt.NDArray | None = None self.weights = np.empty((max_len + 1,), float) self._precalc_weights(max_len) self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
def _precalc_weights(self, n: int): # (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n np.power(self.beta, np.arange(n + 1), out=self.weights)
[docs] def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray: if out is None: out = np.empty(arr.shape, arr.dtype) n = arr.shape[0] weights = self.weights[:n] weights = np.expand_dims(weights, list(range(1, arr.ndim))) # α*P0, α*P1, α*P2, ..., α*Pn np.multiply(self.alpha, arr, out) # α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n np.divide(out, weights, out) # α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ... np.cumsum(out, axis=0, out=out) # (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ... np.multiply(out, weights, out) # Add the previous output if self.prev is None: self.prev = arr[:1] out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim))) self.prev = out[-1:] return out
[docs] def compute2(self, arr: npt.NDArray) -> npt.NDArray: """ Compute the Exponentially Weighted Moving Average (EWMA) of the input array. Args: arr: The input array to be smoothed. Returns: The smoothed array. """ n = arr.shape[0] if n > len(self.weights): self._precalc_weights(n) weights = self.weights[:n][::-1] weights = np.expand_dims(weights, list(range(1, arr.ndim))) result = np.cumsum(self.alpha * weights * arr, axis=0) result = result / weights # Handle the first call when prev is unset if self.prev is None: self.prev = arr[:1] result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim))) # Store the result back into prev self.prev = result[-1] return result
[docs] def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray: if self.prev is None: self.prev = new_sample self.prev = self._step_func(new_sample, self.prev) return self.prev
[docs] class EWMASettings(ez.Settings): time_constant: float = 1.0 """The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%.""" axis: str | None = None accumulate: bool = True """If True, update the EWMA state with each sample. If False, only apply the current EWMA estimate without updating state (useful for inference periods where you don't want to adapt statistics)."""
[docs] @processor_state class EWMAState: alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0)) zi: npt.NDArray | None = None
[docs] class EWMATransformer(BaseStatefulTransformer[EWMASettings, AxisArray, AxisArray, EWMAState]): def _hash_message(self, message: AxisArray) -> int: axis = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis) sample_shape = message.data.shape[:axis_idx] + message.data.shape[axis_idx + 1 :] return hash((sample_shape, message.axes[axis].gain, message.key)) def _reset_state(self, message: AxisArray) -> None: axis = self.settings.axis or message.dims[0] self._state.alpha = _alpha_from_tau(self.settings.time_constant, message.axes[axis].gain) sub_dat = slice_along_axis(message.data, slice(None, 1, None), axis=message.get_axis_idx(axis)) self._state.zi = (1 - self._state.alpha) * sub_dat def _process(self, message: AxisArray) -> AxisArray: if np.prod(message.data.shape) == 0: return message axis = self.settings.axis or message.dims[0] axis_idx = message.get_axis_idx(axis) if self.settings.accumulate: # Normal behavior: update state with new samples expected, self._state.zi = sps.lfilter( [self._state.alpha], [1.0, self._state.alpha - 1.0], message.data, axis=axis_idx, zi=self._state.zi, ) else: # Process-only: compute output without updating state expected, _ = sps.lfilter( [self._state.alpha], [1.0, self._state.alpha - 1.0], message.data, axis=axis_idx, zi=self._state.zi, ) return replace(message, data=expected)
[docs] class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]): SETTINGS = EWMASettings
[docs] @ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS) async def on_settings(self, msg: EWMASettings) -> None: """ Handle settings updates with smart reset behavior. Only resets state if `axis` changes (structural change). Changes to `time_constant` or `accumulate` are applied without resetting accumulated state. """ old_axis = self.SETTINGS.axis self.apply_settings(msg) if msg.axis != old_axis: # Axis changed - need full reset self.create_processor() else: # Only accumulate or time_constant changed - keep state self.processor.settings = msg