Source code for ezmsg.learn.dim_reduce.incremental_decomp

import typing

import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray, replace
from ezmsg.sigproc.base import (
    CompositeProcessor,
    BaseStatefulProcessor,
    BaseTransformerUnit,
)
from ezmsg.sigproc.window import WindowTransformer

from .adaptive_decomp import (
    IncrementalPCASettings,
    IncrementalPCATransformer,
    MiniBatchNMFSettings,
    MiniBatchNMFTransformer,
)


class IncrementalDecompSettings(ez.Settings):
    axis: str = "!time"
    n_components: int = 2
    update_interval: float = 0.0
    method: str = "pca"
    batch_size: typing.Optional[int] = None
    # PCA-specific settings
    whiten: bool = False
    # NMF-specific settings
    init: str = "random"
    beta_loss: str = "frobenius"
    tol: float = 1e-3
    alpha_W: float = 0.0
    alpha_H: typing.Union[float, str] = "same"
    l1_ratio: float = 0.0
    forget_factor: float = 0.7
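
For orientation, here is a minimal sketch of how these settings might be constructed for each decomposition method; the component counts, update interval, and batch size are illustrative example values, not additional defaults.

pca_settings = IncrementalDecompSettings(
    axis="!time",          # the default declared above
    n_components=4,
    method="pca",
    update_interval=5.0,   # re-fit on 5-second windows of incoming data
    whiten=False,
)

nmf_settings = IncrementalDecompSettings(
    axis="!time",
    n_components=8,
    method="nmf",          # any value other than "pca" selects MiniBatchNMF
    batch_size=256,
    beta_loss="frobenius",
    forget_factor=0.7,
)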

class IncrementalDecompTransformer(
    CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]
):
    """
    Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a
    WindowTransformer to extract training samples, then calling partial_fit on the
    decomposition transformer.
    """

    @staticmethod
    def _initialize_processors(
        settings: IncrementalDecompSettings,
    ) -> dict[str, BaseStatefulProcessor]:
        # Create the appropriate decomposition transformer
        if settings.method == "pca":
            decomp_settings = IncrementalPCASettings(
                axis=settings.axis,
                n_components=settings.n_components,
                batch_size=settings.batch_size,
                whiten=settings.whiten,
            )
            decomp = IncrementalPCATransformer(settings=decomp_settings)
        else:  # nmf
            decomp_settings = MiniBatchNMFSettings(
                axis=settings.axis,
                n_components=settings.n_components,
                batch_size=settings.batch_size if settings.batch_size else 1024,
                init=settings.init,
                beta_loss=settings.beta_loss,
                tol=settings.tol,
                alpha_W=settings.alpha_W,
                alpha_H=settings.alpha_H,
                l1_ratio=settings.l1_ratio,
                forget_factor=settings.forget_factor,
            )
            decomp = MiniBatchNMFTransformer(settings=decomp_settings)

        # Create a windowing processor if update_interval is specified
        if settings.update_interval > 0:
            # TODO: This `iter_axis` is likely incorrect.
            iter_axis = settings.axis[1:] if settings.axis.startswith("!") else "time"
            windowing = WindowTransformer(
                axis=iter_axis,
                window_dur=settings.update_interval,
                window_shift=settings.update_interval,
                zero_pad_until="none",
            )
            return {
                "decomp": decomp,
                "windowing": windowing,
            }
        return {"decomp": decomp}

    def _partial_fit_windowed(self, train_msg: AxisArray) -> None:
        """Run partial_fit on each window of the windowed training message."""
        if np.prod(train_msg.data.shape) > 0:
            # Windowing created a new "win" axis, but we don't actually want to use that
            # in the message we send to the decomp processor.
            axis_idx = train_msg.get_axis_idx("win")
            win_axis = train_msg.axes["win"]
            offsets = win_axis.value(np.asarray(range(train_msg.data.shape[axis_idx])))
            for ix, _msg in enumerate(train_msg.iter_over_axis("win")):
                _msg = replace(
                    _msg,
                    axes={
                        **_msg.axes,
                        "time": replace(
                            _msg.axes["time"],
                            offset=_msg.axes["time"].offset + offsets[ix],
                        ),
                    },
                )
                self._procs["decomp"].partial_fit(_msg)

    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
        message: AxisArray,
    ) -> tuple[dict[str, tuple[typing.Any, int]], AxisArray]:
        state = state or {}
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            state["windowing"], train_msg = self._procs["windowing"].stateful_op(
                state.get("windowing", None), message
            )
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        state["decomp"], result = self._procs["decomp"].stateful_op(
            state.get("decomp", None), message
        )
        return state, result

    async def _aprocess(self, message: AxisArray) -> AxisArray:
        """
        Asynchronously process the incoming message. This is nearly identical to the
        _process method, but the processors are called asynchronously.
        """
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = await self._procs["windowing"].__acall__(message)
            self._partial_fit_windowed(train_msg)  # Not async

        # Process the incoming message
        decomp_result = await self._procs["decomp"].__acall__(message)
        return decomp_result

    def _process(self, message: AxisArray) -> AxisArray:
        estim = self._procs["decomp"]._state.estimator
        if not hasattr(estim, "components_") or estim.components_ is None:
            # If the estimator has not been trained once, train it with the first message
            self._procs["decomp"].partial_fit(message)
        elif "windowing" in self._procs:
            # If windowing is enabled, extract training samples and perform partial_fit
            train_msg = self._procs["windowing"](message)
            self._partial_fit_windowed(train_msg)

        # Process the incoming message
        decomp_result = self._procs["decomp"](message)
        return decomp_result
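
Below is a minimal offline sketch of driving the transformer directly with synthetic data. It assumes a time-by-channel AxisArray and that AxisArray.TimeAxis is available as a convenience constructor for the time axis in your ezmsg-util version; the sampling rate, chunk size, and channel count are arbitrary example values.

settings = IncrementalDecompSettings(n_components=4, method="pca", update_interval=2.0)
xf = IncrementalDecompTransformer(settings=settings)

fs = 100.0  # assumed sampling rate for this example
for ix in range(10):
    chunk = AxisArray(
        data=np.random.randn(50, 16),  # 50 samples x 16 channels
        dims=["time", "ch"],
        # AxisArray.TimeAxis(...) is assumed here; build the time axis per your ezmsg-util version
        axes={"time": AxisArray.TimeAxis(fs=fs, offset=ix * 50 / fs)},
    )
    out = xf(chunk)  # processors are callable; this routes through _process above
    # out carries n_components features along the reduced (non-time) dimension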

class IncrementalDecompUnit(
    BaseTransformerUnit[
        IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer
    ]
):
    SETTINGS = IncrementalDecompSettings
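
In a live pipeline, the unit would sit between a signal source and any downstream consumer. The wiring sketch below assumes the INPUT_SIGNAL/OUTPUT_SIGNAL streams provided by BaseTransformerUnit and uses a hypothetical MySourceUnit standing in for whatever publishes AxisArray messages; the launch call also varies across ezmsg versions.

class DecompPipeline(ez.Collection):
    SOURCE = MySourceUnit()  # hypothetical unit publishing AxisArray messages
    DECOMP = IncrementalDecompUnit(
        IncrementalDecompSettings(n_components=4, method="pca", update_interval=5.0)
    )

    def network(self) -> ez.NetworkDefinition:
        # Stream names assume BaseTransformerUnit's INPUT_SIGNAL/OUTPUT_SIGNAL
        return (
            (self.SOURCE.OUTPUT_SIGNAL, self.DECOMP.INPUT_SIGNAL),
        )

# ez.run(PIPELINE=DecompPipeline())  # launch; exact signature depends on ezmsg version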