# Source code for ezmsg.learn.process.linear_regressor

from dataclasses import field

import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray, replace
from sklearn.linear_model._base import LinearModel

from ..util import RegressorType, StaticLinearRegressor, get_regressor


class LinearRegressorSettings(ez.Settings):
    """
    Configuration for ``LinearRegressorTransformer``.

    In this module these values are read when the transformer is constructed: either a
    pickled model is loaded from ``settings_path``, or a new estimator selected by
    ``model_type`` is instantiated with ``model_kwargs``.
    """

    # Static regressor variant to build (resolved through get_regressor) when no
    # settings_path is provided.
    model_type: StaticLinearRegressor = StaticLinearRegressor.LINEAR

    # Optional path to a pickled model; when set, it is used instead of model_type.
    settings_path: str | None = None

    # Keyword arguments passed to the regressor class's constructor.
    model_kwargs: dict = field(default_factory=dict)
@processor_state
class LinearRegressorState:
    """Mutable state carried by ``LinearRegressorTransformer``."""

    # Output template derived from the most recent training target (its key gets a
    # "_pred" suffix); stays None until partial_fit has succeeded at least once.
    template: AxisArray | None = None

    # The regression estimator, assigned in LinearRegressorTransformer.__init__ and
    # replaced by the fitted estimator in partial_fit.
    model: LinearModel | None = None
class LinearRegressorTransformer(
    BaseAdaptiveTransformer[LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState]
):
    """
    Linear regressor.

    Note: `partial_fit` is not 'partial'. It fully resets the model using the entirety of the
    SampleMessage provided. If you require adaptive fitting, try using the
    adaptive_linear_regressor module.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the transformer and its underlying regression model.

        If ``settings.settings_path`` is set, the model is unpickled from that file.
        Otherwise a new estimator is built from ``settings.model_type`` (via
        ``get_regressor``) and ``settings.model_kwargs``.
        """
        super().__init__(*args, **kwargs)
        if self.settings.settings_path is not None:
            # Load model from file.
            # SECURITY: pickle.load can execute arbitrary code embedded in the file;
            # settings_path must only ever point at trusted files.
            import pickle

            with open(self.settings.settings_path, "rb") as f:
                self.state.model = pickle.load(f)
            # NOTE(review): state.template is still None here, so _process emits an empty
            # placeholder until partial_fit runs, even if the loaded model is already
            # fitted — confirm this is intended.
        else:
            regressor_klass = get_regressor(RegressorType.STATIC, self.settings.model_type)
            self.state.model = regressor_klass(**self.settings.model_kwargs)

    def _reset_state(self, message: AxisArray) -> None:
        # So far, there is nothing to reset.
        # .model is initialized in __init__; .template is set by partial_fit.
        pass

    @staticmethod
    def _contains_nan(arr) -> bool:
        """Return True if ``arr`` is a floating/complex array holding at least one NaN."""
        arr = np.asarray(arr)
        # Integer/bool arrays cannot represent NaN, so only inexact dtypes are scanned.
        return bool(np.issubdtype(arr.dtype, np.inexact) and np.isnan(arr).any())

    def partial_fit(self, message: SampleMessage) -> None:
        """
        Refit the model from scratch on one labelled sample.

        Features come from ``message.sample.data`` and targets from
        ``message.trigger.value.data``. Both are truncated to their common leading length.
        The sample is skipped, leaving the current model and template unchanged, when that
        length is zero or when either truncated array contains NaN.

        On success, the output template is taken from the trigger's target AxisArray, with
        ``"_pred"`` appended to its key.
        """
        X = message.sample.data
        y = message.trigger.value.data
        # TODO: Resample should provide identical durations.
        # Until it does, fit only on the rows both arrays share.
        n = min(X.shape[0], y.shape[0])
        X, y = X[:n], y[:n]
        if n == 0 or self._contains_nan(X) or self._contains_nan(y):
            # Skip unusable samples instead of letting the estimator's fit() raise.
            return
        self.state.model = self.state.model.fit(X, y)
        self.state.template = replace(
            message.trigger.value,
            data=np.array([[]]),
            key=message.trigger.value.key + "_pred",
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """
        Predict targets for ``message.data`` with the fitted model.

        Returns an empty ("time", "ch") placeholder until ``partial_fit`` has succeeded.
        Otherwise it returns the stored template filled with the predictions, with its
        "time" axis taken from the incoming message.
        """
        if self.state.template is None:
            return AxisArray(np.array([[]]), dims=["time", "ch"])

        model = self.state.model
        coef = getattr(model, "coef_", None)
        if message.data.shape[0] == 0 and coef is not None:
            # sklearn estimators reject 0-sample input in predict(). For linear models,
            # predict() returns X @ coef_.T (+ intercept), so its trailing shape is
            # coef_.shape[:-1]; build an empty result with that shape.
            preds = np.empty((0, *np.shape(coef)[:-1]), dtype=float)
        else:
            preds = model.predict(message.data)

        return replace(
            self.state.template,
            data=preds,
            axes={
                **self.state.template.axes,
                # Predictions are aligned with the input samples, so reuse their time axis.
                "time": message.axes["time"],
            },
        )
class AdaptiveLinearRegressorUnit(
    BaseAdaptiveTransformerUnit[
        LinearRegressorSettings,
        AxisArray,
        AxisArray,
        LinearRegressorTransformer,
    ]
):
    """
    ezmsg Unit that hosts a ``LinearRegressorTransformer`` configured by
    ``LinearRegressorSettings``.

    NOTE(review): despite the "Adaptive" in the name, the wrapped transformer refits from
    scratch on every training sample. The name presumably reflects the
    BaseAdaptiveTransformerUnit base class — confirm.
    """

    SETTINGS = LinearRegressorSettings