"""Exponentially weighted moving average (EWMA) utilities and parameter conversion."""
import functools
from dataclasses import field
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import scipy.signal as sps
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace
def _ewma_mlx_metal_xp(data, axis_idx: int, zi, alpha: float):
    """
    Apply the EWMA recurrence with the MLX Metal kernel, keeping SciPy's ``zi`` layout.

    The kernel is handed arrays whose filtered axis is the last one, so ``data`` and
    ``zi`` have ``axis_idx`` moved to the end, are processed, and the outputs have that
    axis moved back to ``axis_idx``.

    Args:
        data: MLX array to smooth (assumed to be an ``mlx.core`` array — TODO confirm callers).
        axis_idx: Index of the axis to smooth along; the same index is used for ``zi``.
        zi: Initial filter state in scipy ``lfilter``-style layout; cast to ``data.dtype``.
        alpha: Fading factor.

    Returns:
        Tuple ``(y, zf)``: the smoothed data and the final state, both with the
        filtered axis back at ``axis_idx``.
    """
    import mlx.core as mx

    from .util.ewma_mlx_metal import MAX_CHUNK_SIZE, ewma_mlx_metal

    def _to_last(arr):
        # Move ``axis_idx`` to the final position (no-op when it is already last).
        final = arr.ndim - 1
        return arr if axis_idx == final else mx.moveaxis(arr, axis_idx, final)

    def _from_last(arr, ndim):
        # Undo ``_to_last`` for an array whose original rank was ``ndim``.
        final = ndim - 1
        return arr if axis_idx == final else mx.moveaxis(arr, final, axis_idx)

    state = mx.asarray(zi, dtype=data.dtype)
    samples_last = _to_last(data)
    state_last = _to_last(state)
    # Chunk length is the filtered-axis length, capped by the kernel's limit.
    chunk_len = min(samples_last.shape[-1], MAX_CHUNK_SIZE)
    y_last, zf_last = ewma_mlx_metal(samples_last, alpha, state_last, chunk_size=chunk_len)
    return _from_last(y_last, data.ndim), _from_last(zf_last, state.ndim)
def _tau_from_alpha(alpha: float, dt: float) -> float:
"""
Inverse of _alpha_from_tau. See that function for explanation.
"""
return -dt / np.log(1 - alpha)
def _alpha_from_tau(tau: float, dt: float) -> float:
"""
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
:param tau: The amount of time for the smoothed response of a unit step function to reach
1 - 1/e approx-eq 63.2%.
:param dt: sampling period, or 1 / sampling_rate.
:return: alpha, the "fading factor" in exponential smoothing.
"""
return 1 - np.exp(-dt / tau)
[docs]
def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
"""
Do an exponentially weighted moving average step.
Args:
sample: The new sample.
zi: The output of the previous step.
alpha: Fading factor.
beta: Persisting factor. If None, it is calculated as 1-alpha.
Returns:
alpha * sample + beta * zi
"""
# Potential micro-optimization:
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
# return zi + alpha * (new_sample - zi)
beta = beta or (1 - alpha)
return alpha * sample + beta * zi
[docs]
class EWMA_Deprecated:
"""
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
but they ended up being slower than the scipy.signal.lfilter method.
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
and beta**n approaches zero.
"""
[docs]
def __init__(self, alpha: float, max_len: int):
self.alpha = alpha
self.beta = 1 - alpha
self.prev: npt.NDArray | None = None
self.weights = np.empty((max_len + 1,), float)
self._precalc_weights(max_len)
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
def _precalc_weights(self, n: int):
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
np.power(self.beta, np.arange(n + 1), out=self.weights)
[docs]
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
if out is None:
out = np.empty(arr.shape, arr.dtype)
n = arr.shape[0]
weights = self.weights[:n]
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
# α*P0, α*P1, α*P2, ..., α*Pn
np.multiply(self.alpha, arr, out)
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
np.divide(out, weights, out)
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
np.cumsum(out, axis=0, out=out)
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
np.multiply(out, weights, out)
# Add the previous output
if self.prev is None:
self.prev = arr[:1]
out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
self.prev = out[-1:]
return out
[docs]
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
"""
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
Args:
arr: The input array to be smoothed.
Returns:
The smoothed array.
"""
n = arr.shape[0]
if n > len(self.weights):
self._precalc_weights(n)
weights = self.weights[:n][::-1]
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
result = np.cumsum(self.alpha * weights * arr, axis=0)
result = result / weights
# Handle the first call when prev is unset
if self.prev is None:
self.prev = arr[:1]
result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
# Store the result back into prev
self.prev = result[-1]
return result
[docs]
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
if self.prev is None:
self.prev = new_sample
self.prev = self._step_func(new_sample, self.prev)
return self.prev
[docs]
class EWMASettings(ez.Settings):
    """
    Settings for exponentially weighted moving average smoothing (used as ``EWMAUnit.SETTINGS``).
    """

    # NOTE(review): presumably expressed in the same units as the sampling period of the
    # smoothed axis and converted to a fading factor via _alpha_from_tau — confirm.
    time_constant: float = 1.0
    """The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""
    # Name of the axis to smooth along. NOTE(review): how None is resolved (default axis)
    # is decided by the transformer, which is not visible here — confirm.
    axis: str | None = None
    accumulate: bool = True
    """If True, update the EWMA state with each sample. If False, only apply
    the current EWMA estimate without updating state (useful for inference
    periods where you don't want to adapt statistics)."""
[docs]
@processor_state
class EWMAState:
    """State container for the EWMA processor (declared via ``processor_state``)."""

    # Fading factor. NOTE(review): the default calls _alpha_from_tau(tau=1.0, dt=1000.0);
    # with dt=1000 that is 1 - exp(-1000), which is exactly 1.0 in float64 (no smoothing).
    # If "tau = 1 s at 1 kHz" was intended, dt should be 1 / 1000. Presumably the
    # transformer overwrites alpha from settings when it resets state — confirm.
    alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0))
    # Filter state carried between calls; None until initialized. Presumably in
    # scipy lfilter ``zi`` layout (cf. _ewma_mlx_metal_xp) — confirm.
    zi: npt.NDArray | None = None
[docs]
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
SETTINGS = EWMASettings