import functools
from dataclasses import field
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt
import scipy.signal as sps
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis
from ezmsg.util.messages.util import replace
def _tau_from_alpha(alpha: float, dt: float) -> float:
"""
Inverse of _alpha_from_tau. See that function for explanation.
"""
return -dt / np.log(1 - alpha)
def _alpha_from_tau(tau: float, dt: float) -> float:
"""
# https://en.wikipedia.org/wiki/Exponential_smoothing#Time_constant
:param tau: The amount of time for the smoothed response of a unit step function to reach
1 - 1/e approx-eq 63.2%.
:param dt: sampling period, or 1 / sampling_rate.
:return: alpha, the "fading factor" in exponential smoothing.
"""
return 1 - np.exp(-dt / tau)
[docs]
def ewma_step(sample: npt.NDArray, zi: npt.NDArray, alpha: float, beta: float | None = None):
"""
Do an exponentially weighted moving average step.
Args:
sample: The new sample.
zi: The output of the previous step.
alpha: Fading factor.
beta: Persisting factor. If None, it is calculated as 1-alpha.
Returns:
alpha * sample + beta * zi
"""
# Potential micro-optimization:
# Current: scalar-arr multiplication, scalar-arr multiplication, arr-arr addition
# Alternative: arr-arr subtraction, arr-arr multiplication, arr-arr addition
# return zi + alpha * (new_sample - zi)
beta = beta or (1 - alpha)
return alpha * sample + beta * zi
[docs]
class EWMA_Deprecated:
"""
Grabbed these methods from https://stackoverflow.com/a/70998068 and other answers in that topic,
but they ended up being slower than the scipy.signal.lfilter method.
Additionally, `compute` and `compute2` suffer from potential errors as the vector length increases
and beta**n approaches zero.
"""
[docs]
def __init__(self, alpha: float, max_len: int):
self.alpha = alpha
self.beta = 1 - alpha
self.prev: npt.NDArray | None = None
self.weights = np.empty((max_len + 1,), float)
self._precalc_weights(max_len)
self._step_func = functools.partial(ewma_step, alpha=self.alpha, beta=self.beta)
def _precalc_weights(self, n: int):
# (1-α)^0, (1-α)^1, (1-α)^2, ..., (1-α)^n
np.power(self.beta, np.arange(n + 1), out=self.weights)
[docs]
def compute(self, arr: npt.NDArray, out: npt.NDArray | None = None) -> npt.NDArray:
if out is None:
out = np.empty(arr.shape, arr.dtype)
n = arr.shape[0]
weights = self.weights[:n]
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
# α*P0, α*P1, α*P2, ..., α*Pn
np.multiply(self.alpha, arr, out)
# α*P0/(1-α)^0, α*P1/(1-α)^1, α*P2/(1-α)^2, ..., α*Pn/(1-α)^n
np.divide(out, weights, out)
# α*P0/(1-α)^0, α*P0/(1-α)^0 + α*P1/(1-α)^1, ...
np.cumsum(out, axis=0, out=out)
# (α*P0/(1-α)^0)*(1-α)^0, (α*P0/(1-α)^0 + α*P1/(1-α)^1)*(1-α)^1, ...
np.multiply(out, weights, out)
# Add the previous output
if self.prev is None:
self.prev = arr[:1]
out += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
self.prev = out[-1:]
return out
[docs]
def compute2(self, arr: npt.NDArray) -> npt.NDArray:
"""
Compute the Exponentially Weighted Moving Average (EWMA) of the input array.
Args:
arr: The input array to be smoothed.
Returns:
The smoothed array.
"""
n = arr.shape[0]
if n > len(self.weights):
self._precalc_weights(n)
weights = self.weights[:n][::-1]
weights = np.expand_dims(weights, list(range(1, arr.ndim)))
result = np.cumsum(self.alpha * weights * arr, axis=0)
result = result / weights
# Handle the first call when prev is unset
if self.prev is None:
self.prev = arr[:1]
result += self.prev * np.expand_dims(self.weights[1 : n + 1], list(range(1, arr.ndim)))
# Store the result back into prev
self.prev = result[-1]
return result
[docs]
def compute_sample(self, new_sample: npt.NDArray) -> npt.NDArray:
if self.prev is None:
self.prev = new_sample
self.prev = self._step_func(new_sample, self.prev)
return self.prev
[docs]
class EWMASettings(ez.Settings):
time_constant: float = 1.0
"""The amount of time for the smoothed response of a unit step function to reach 1 - 1/e approx-eq 63.2%."""
axis: str | None = None
accumulate: bool = True
"""If True, update the EWMA state with each sample. If False, only apply
the current EWMA estimate without updating state (useful for inference
periods where you don't want to adapt statistics)."""
[docs]
@processor_state
class EWMAState:
alpha: float = field(default_factory=lambda: _alpha_from_tau(1.0, 1000.0))
zi: npt.NDArray | None = None
[docs]
class EWMAUnit(BaseTransformerUnit[EWMASettings, AxisArray, AxisArray, EWMATransformer]):
SETTINGS = EWMASettings
[docs]
@ez.subscriber(BaseTransformerUnit.INPUT_SETTINGS)
async def on_settings(self, msg: EWMASettings) -> None:
"""
Handle settings updates with smart reset behavior.
Only resets state if `axis` changes (structural change).
Changes to `time_constant` or `accumulate` are applied without
resetting accumulated state.
"""
old_axis = self.SETTINGS.axis
self.apply_settings(msg)
if msg.axis != old_axis:
# Axis changed - need full reset
self.create_processor()
else:
# Only accumulate or time_constant changed - keep state
self.processor.settings = msg