Source code for ezmsg.sigproc.ewmfilter

"""Exponentially weighted moving average filter for streaming normalization."""

import asyncio
import typing

import ezmsg.core as ez
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .window import Window, WindowSettings


[docs] class EWMSettings(ez.Settings): axis: str | None = None """Name of the axis to accumulate.""" zero_offset: bool = True """If true, we assume zero DC offset for input data."""
[docs] class EWMState(ez.State): buffer_queue: "asyncio.Queue[AxisArray]" signal_queue: "asyncio.Queue[AxisArray]"
[docs] class EWM(ez.Unit): """ Exponentially Weighted Moving Average Standardization. This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead. References https://stackoverflow.com/a/42926270 """ SETTINGS = EWMSettings STATE = EWMState INPUT_SIGNAL = ez.InputStream(AxisArray) INPUT_BUFFER = ez.InputStream(AxisArray) OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
[docs] async def initialize(self) -> None: ez.logger.warning( "EWM/EWMFilter is deprecated and will be removed in a future version. Use AdaptiveStandardScaler instead." ) self.STATE.signal_queue = asyncio.Queue() self.STATE.buffer_queue = asyncio.Queue()
[docs] @ez.subscriber(INPUT_SIGNAL) async def on_signal(self, message: AxisArray) -> None: self.STATE.signal_queue.put_nowait(message)
[docs] @ez.subscriber(INPUT_BUFFER) async def on_buffer(self, message: AxisArray) -> None: self.STATE.buffer_queue.put_nowait(message)
[docs] @ez.publisher(OUTPUT_SIGNAL) async def sync_output(self) -> typing.AsyncGenerator: while True: signal = await self.STATE.signal_queue.get() buffer = await self.STATE.buffer_queue.get() # includes signal axis_name = self.SETTINGS.axis if axis_name is None: axis_name = signal.dims[0] axis_idx = signal.get_axis_idx(axis_name) buffer_len = buffer.shape[axis_idx] block_len = signal.shape[axis_idx] window = buffer_len - block_len alpha = 2 / (window + 1.0) alpha_rev = 1 - alpha pows = alpha_rev ** (np.arange(buffer_len + 1)) scale_arr = 1 / pows[:-1] pw0 = alpha * alpha_rev ** (buffer_len - 1) buffer_data = buffer.data buffer_data = np.moveaxis(buffer_data, axis_idx, 0) while scale_arr.ndim < buffer_data.ndim: scale_arr = scale_arr[..., None] def ewma(data: np.ndarray) -> np.ndarray: mult = scale_arr * data * pw0 out = scale_arr[::-1] * mult.cumsum(axis=0) if not self.SETTINGS.zero_offset: out = (data[0, :, np.newaxis] * pows[1:]).T + out return out mean = ewma(buffer_data) std = ewma((buffer_data - mean) ** 2.0) standardized = (buffer_data - mean) / np.sqrt(std).clip(1e-4) standardized = standardized[-signal.shape[axis_idx] :, ...] standardized = np.moveaxis(standardized, axis_idx, 0) yield self.OUTPUT_SIGNAL, replace(signal, data=standardized)
[docs] class EWMFilterSettings(ez.Settings): history_dur: float """Previous data to accumulate for standardization.""" axis: str | None = None """Name of the axis to accumulate.""" zero_offset: bool = True """If true, we assume zero DC offset for input data."""
[docs] class EWMFilter(ez.Collection): """ A :obj:`Collection` that splits the input into a branch that leads to :obj:`Window` which then feeds into :obj:`EWM` 's INPUT_BUFFER and another branch that feeds directly into :obj:`EWM` 's INPUT_SIGNAL. This is deprecated. Please use :obj:`ezmsg.sigproc.scaler.AdaptiveStandardScaler` instead. """ SETTINGS = EWMFilterSettings INPUT_SIGNAL = ez.InputStream(AxisArray) OUTPUT_SIGNAL = ez.OutputStream(AxisArray) WINDOW = Window() EWM = EWM()
[docs] def configure(self) -> None: self.EWM.apply_settings( EWMSettings( axis=self.SETTINGS.axis, zero_offset=self.SETTINGS.zero_offset, ) ) self.WINDOW.apply_settings( WindowSettings( axis=self.SETTINGS.axis, window_dur=self.SETTINGS.history_dur, window_shift=None, # 1:1 mode ) )
[docs] def network(self) -> ez.NetworkDefinition: return ( (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL), (self.WINDOW.OUTPUT_SIGNAL, self.EWM.INPUT_BUFFER), (self.INPUT_SIGNAL, self.EWM.INPUT_SIGNAL), (self.EWM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL), )