Source code for ezmsg.learn.process.adaptive_linear_regressor
from dataclasses import field
import ezmsg.core as ez
import numpy as np
import pandas as pd
import river.linear_model
import river.optim
import sklearn.base
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray, replace
from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
[docs]
class AdaptiveLinearRegressorSettings(ez.Settings):
model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR
settings_path: str | None = None
model_kwargs: dict = field(default_factory=dict)
[docs]
@processor_state
class AdaptiveLinearRegressorState:
template: AxisArray | None = None
model: river.linear_model.base.GLM | sklearn.base.RegressorMixin | None = None
[docs]
class AdaptiveLinearRegressorTransformer(
BaseAdaptiveTransformer[
AdaptiveLinearRegressorSettings,
AxisArray,
AxisArray,
AdaptiveLinearRegressorState,
]
):
[docs]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.settings = replace(self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type))
b_river = self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]
if b_river:
self.settings.model_kwargs["l2"] = self.settings.model_kwargs.get("l2", 0.0)
if "learn_rate" in self.settings.model_kwargs:
self.settings.model_kwargs["optimizer"] = river.optim.SGD(self.settings.model_kwargs.pop("learn_rate"))
if self.settings.settings_path is not None:
# Load model from file
import pickle
with open(self.settings.settings_path, "rb") as f:
self.state.model = pickle.load(f)
if b_river:
# Override with kwargs?!
self.state.model.l2 = self.settings.model_kwargs["l2"]
if "optimizer" in self.settings.model_kwargs:
self.state.model.optimizer = self.settings.model_kwargs["optimizer"]
else:
print("TODO: Override sklearn model with kwargs")
else:
# Build model from scratch.
regressor_klass = get_regressor(RegressorType.ADAPTIVE, self.settings.model_type)
self.state.model = regressor_klass(**self.settings.model_kwargs)
def _hash_message(self, message: AxisArray) -> int:
# So far, nothing to reset so hash can be constant.
return -1
def _reset_state(self, message: AxisArray) -> None:
# So far, there is nothing to reset.
# .model is initialized in __init__
# .template is updated in partial_fit
pass
[docs]
def partial_fit(self, message: SampleMessage) -> None:
if np.any(np.isnan(message.sample.data)):
return
if self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]:
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
y = pd.Series(
data=message.trigger.value.data[:, 0],
name=message.trigger.value.axes["ch"].data[0],
)
self.state.model.learn_many(x, y)
else:
X = message.sample.data
if message.sample.get_axis_idx("time") != 0:
X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
self.state.model.partial_fit(X, message.trigger.value.data)
self.state.template = replace(
message.trigger.value,
data=np.empty_like(message.trigger.value.data),
key=message.trigger.value.key + "_pred",
)
def _process(self, message: AxisArray) -> AxisArray | None:
if self.state.template is None:
return AxisArray(np.array([]), dims=[""])
if not np.any(np.isnan(message.data)):
if self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]:
# convert msg_in.data to something appropriate for river
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
preds = self.state.model.predict_many(x).values
else:
preds = self.state.model.predict(message.data)
return replace(
self.state.template,
data=preds.reshape((len(preds), -1)),
axes={
**self.state.template.axes,
"time": replace(
message.axes["time"],
offset=message.axes["time"].offset,
),
},
)
[docs]
class AdaptiveLinearRegressorUnit(
BaseAdaptiveTransformerUnit[
AdaptiveLinearRegressorSettings,
AxisArray,
AxisArray,
AdaptiveLinearRegressorTransformer,
]
):
SETTINGS = AdaptiveLinearRegressorSettings