Source code for ezmsg.learn.process.adaptive_linear_regressor
"""Adaptive linear regressor processor.
.. note::
This module supports the Array API standard via
``array_api_compat.get_namespace()``. NaN checks and axis permutations
use Array API operations; a NumPy boundary is applied before sklearn
``partial_fit``/``predict`` and before river ``learn_many``/``predict_many``.
"""
import copy
import typing
from dataclasses import field
import ezmsg.core as ez
import numpy as np
import pandas as pd
import river.linear_model
import river.optim
import sklearn.base
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, replace
from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
[docs]
class AdaptiveLinearRegressorSettings(ez.Settings):
    """Settings for the adaptive linear regressor transformer/unit."""

    # Which adaptive regressor to use. LINEAR and LOGISTIC are handled by the
    # river code path; other values are trained via sklearn partial_fit/predict.
    model_type: AdaptiveLinearRegressor = AdaptiveLinearRegressor.LINEAR
    # Optional path to a pickled, pre-trained model loaded at construction time.
    # NOTE(review): the file is read with pickle.load, so only trusted paths are safe.
    settings_path: str | None = None
    # Keyword arguments forwarded to the regressor constructor. For river models
    # the transformer defaults "l2" to 0.0 and converts "learn_rate" into an
    # SGD "optimizer".
    model_kwargs: dict = field(default_factory=dict)
[docs]
@processor_state
class AdaptiveLinearRegressorState:
    """Mutable state of the adaptive linear regressor transformer."""

    # Template for prediction outputs. It is set in partial_fit from the target
    # (trigger) message, or built from the input signal on the first prediction
    # if no fit has happened yet.
    template: AxisArray | None = None
    # The model: a single river GLM, a dict of river GLMs keyed by target label
    # (multi-target case), or an sklearn regressor. None until built or loaded.
    model: (
        river.linear_model.base.GLM
        | dict[typing.Hashable, river.linear_model.base.GLM]
        | sklearn.base.RegressorMixin
        | None
    ) = None
def _normalize_axis_label(label):
    """Convert a coordinate-axis entry into a plain, hashable label.

    - Structured (record) values: the ``"label"`` field as ``str`` if present,
      otherwise a tuple of ``(field_name, normalized_value)`` pairs.
    - NumPy scalars: the equivalent Python scalar via ``.item()``.
    - Hashable values: returned unchanged.
    - Unhashable values: their ``str()`` representation.
    """
    field_names = getattr(getattr(label, "dtype", None), "names", None)

    if field_names is None:
        if isinstance(label, np.generic):
            return label.item()
        try:
            hash(label)
        except TypeError:
            return str(label)
        return label

    if "label" in field_names:
        return str(label["label"])
    # No dedicated "label" field: describe the record by all of its fields.
    return tuple((fname, _normalize_axis_label(label[fname])) for fname in field_names)
def _axis_labels(axis_data) -> list:
    """Return every entry of *axis_data* normalized by ``_normalize_axis_label``."""
    return list(map(_normalize_axis_label, axis_data))
def _prediction_template(message: AxisArray) -> AxisArray:
    """Derive a prediction-output template from *message* (typically the targets).

    The data buffer is swapped for an uninitialized array with the same shape
    and dtype, and ``"_pred"`` is appended to the key. All other fields are
    carried over unchanged by ``replace``.
    """
    placeholder = np.empty_like(message.data)
    pred_key = message.key + "_pred"
    return replace(message, data=placeholder, key=pred_key)
def _prediction_template_from_signal(message: AxisArray, output_labels: list[typing.Hashable]) -> AxisArray:
    """Build a ``(time, ch)`` prediction template from an input signal.

    Used when no target message has been seen yet. The time axis is copied
    from *message*, and the channel axis is a coordinate axis of
    *output_labels*. The data buffer is uninitialized and only sets the shape.
    """
    n_samples = message.data.shape[message.get_axis_idx("time")]
    src_time = message.axes["time"]
    buffer = np.empty((n_samples, len(output_labels)))
    time_axis = replace(src_time, offset=src_time.offset)
    ch_axis = AxisArray.CoordinateAxis(data=np.asarray(output_labels), dims=["ch"])
    return AxisArray(
        data=buffer,
        dims=["time", "ch"],
        axes={"time": time_axis, "ch": ch_axis},
        key=message.key + "_pred",
    )
def _output_labels(message: AxisArray) -> list[typing.Hashable]:
    """Return one label per output channel of *message*.

    If the message has a ``"ch"`` axis, its normalized coordinate labels are
    returned. Otherwise synthetic labels ``ch0, ch1, ...`` are generated, one
    per entry of the last data dimension (a single label for 0-D/1-D data).
    """
    if "ch" in message.axes:
        return _axis_labels(message.axes["ch"].data)
    arr = np.asarray(message.data)
    n_channels = 1 if arr.ndim <= 1 else arr.shape[-1]
    return [f"ch{i}" for i in range(n_channels)]
[docs]
class AdaptiveLinearRegressorTransformer(
    BaseAdaptiveTransformer[
        AdaptiveLinearRegressorSettings,
        AxisArray,
        AxisArray,
        AdaptiveLinearRegressorState,
    ]
):
    """Adaptive (online) linear regression on AxisArray streams.

    ``LINEAR``/``LOGISTIC`` model types use river models. These are created
    lazily on the first :meth:`partial_fit`, with one model per target channel
    when there are several targets. Every other model type is an sklearn
    estimator that is built in ``__init__`` and trained with its
    ``partial_fit``.

    Both backends receive the input with the ``time`` axis moved to the
    front. This keeps training (:meth:`partial_fit`) and prediction
    (:meth:`_process`) on the same sample layout.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        model_kwargs = dict(self.settings.model_kwargs)
        model_type = AdaptiveLinearRegressor(self.settings.model_type)
        # Classify the backend with the *normalized* enum member. The raw
        # setting may be a plain value that does not compare equal to the enum
        # members that partial_fit/_process test against later.
        b_river = model_type in (
            AdaptiveLinearRegressor.LINEAR,
            AdaptiveLinearRegressor.LOGISTIC,
        )
        if b_river:
            model_kwargs.setdefault("l2", 0.0)
            if "learn_rate" in model_kwargs:
                # Translate a bare learning rate into a river SGD optimizer.
                model_kwargs["optimizer"] = river.optim.SGD(model_kwargs.pop("learn_rate"))

        self.settings = replace(
            self.settings,
            model_type=model_type,
            model_kwargs=model_kwargs,
        )
        self._regressor_klass = get_regressor(RegressorType.ADAPTIVE, self.settings.model_type)

        if self.settings.settings_path is not None:
            self.state.model = self._load_pickled_model(self.settings.settings_path, b_river)
        elif not b_river:
            # sklearn estimators are built eagerly. River models are built in
            # partial_fit, once the number of targets is known.
            self.state.model = self._regressor_klass(**self.settings.model_kwargs)

    def _load_pickled_model(self, path: str, b_river: bool):
        """Unpickle a pre-trained model from *path* and apply current river settings."""
        # SECURITY: pickle.load can execute arbitrary code. Only point
        # settings_path at trusted files.
        import pickle

        with open(path, "rb") as f:
            model = pickle.load(f)
        if b_river:
            models = model.values() if isinstance(model, dict) else [model]
            for river_model in models:
                river_model.l2 = self.settings.model_kwargs["l2"]
                if "optimizer" in self.settings.model_kwargs:
                    # Give each model its own optimizer object instead of sharing one.
                    river_model.optimizer = copy.deepcopy(self.settings.model_kwargs["optimizer"])
        else:
            ez.logger.warning(
                f"Loaded sklearn model from {path}; model_kwargs are not applied to loaded sklearn models."
            )
        return model

    def _hash_message(self, message: AxisArray) -> int:
        # So far, nothing to reset, so the hash can be constant.
        return -1

    def _reset_state(self, message: AxisArray) -> None:
        # So far, there is nothing to reset:
        # .model is initialized in __init__ (or lazily in partial_fit for river),
        # .template is updated in partial_fit.
        pass

    def _is_river_model_type(self) -> bool:
        """Return True when the configured model type uses the river backend."""
        return self.settings.model_type in (
            AdaptiveLinearRegressor.LINEAR,
            AdaptiveLinearRegressor.LOGISTIC,
        )

    @staticmethod
    def _as_time_major_numpy(message: AxisArray) -> np.ndarray:
        """Return ``message.data`` as a NumPy array with the ``time`` axis first.

        The permutation uses the Array API namespace of the input. Conversion
        to NumPy happens afterwards, because both river and sklearn need NumPy.
        """
        data = message.data
        if "time" in message.dims:
            ax_idx = message.get_axis_idx("time")
            if ax_idx != 0:
                xp = get_namespace(data)
                perm = (ax_idx,) + tuple(i for i in range(data.ndim) if i != ax_idx)
                data = xp.permute_dims(data, perm)
        return data if is_numpy_array(data) else np.asarray(data)

    def _make_river_model(self):
        """Build a fresh river model. Kwargs are deep-copied so models share no optimizer."""
        return self._regressor_klass(**copy.deepcopy(self.settings.model_kwargs))

    def _prediction_labels(self, n_outputs: int) -> list[typing.Hashable]:
        """Return output channel labels: template labels, then model keys, then ``ch{i}``."""
        if self.state.template is not None:
            return _output_labels(self.state.template)
        if isinstance(self.state.model, dict):
            return list(self.state.model.keys())
        return [f"ch{idx}" for idx in range(n_outputs)]

    def partial_fit(self, message: AxisArray) -> None:
        """Update the model with one labelled chunk of data.

        Args:
            message: Feature chunk. Its ``attrs["trigger"].value`` holds the
                target AxisArray for the same samples.

        Chunks with NaN features, or NaN in floating-point targets, are
        skipped, because a NaN could otherwise corrupt an online model.

        Raises:
            ValueError: (river backend) if the target labels do not match the
                labels of the existing model(s).
        """
        xp = get_namespace(message.data)
        if xp.any(xp.isnan(message.data)):
            return

        targets = message.attrs["trigger"].value
        target_np = np.asarray(targets.data)
        if target_np.dtype.kind in "fc" and np.isnan(target_np).any():
            return

        x_np = self._as_time_major_numpy(message)
        if self._is_river_model_type():
            self._fit_river(x_np, targets, target_np)
        else:
            self.state.model.partial_fit(x_np, target_np)
        self.state.template = _prediction_template(targets)

    def _fit_river(self, x_np: np.ndarray, targets: AxisArray, target_np: np.ndarray) -> None:
        """Run ``learn_many`` on each river model (one per target label)."""
        x = pd.DataFrame(x_np, columns=[f"f{i}" for i in range(x_np.shape[1])])
        if target_np.ndim == 1:
            target_np = target_np[:, None]
        target_labels = _output_labels(targets)

        if self.state.model is None:
            if len(target_labels) == 1:
                self.state.model = self._make_river_model()
            else:
                self.state.model = {label: self._make_river_model() for label in target_labels}

        models = self.state.model
        if not isinstance(models, dict):
            # A single (unkeyed) model can only serve exactly one target.
            if len(target_labels) != 1:
                ez.logger.error(
                    f"Got {len(target_labels)} targets ({target_labels}) but the model is a single regressor."
                )
                raise ValueError("Target labels do not match model labels.")
            models = {target_labels[0]: models}
        if set(target_labels) != set(models.keys()):
            ez.logger.error(f"Target labels ({target_labels}) does not match model labels ({list(models.keys())}).")
            raise ValueError("Target labels do not match model labels.")

        for idx, label in enumerate(target_labels):
            y = pd.Series(data=target_np[:, idx], name=label)
            models[label].learn_many(x, y)

    def _predict_river(self, x_np: np.ndarray) -> np.ndarray:
        """Predict with the river model(s), returning shape ``(n_time, n_outputs)``.

        Output labels without a model are filled with zeros.
        """
        x = pd.DataFrame(x_np, columns=[f"f{i}" for i in range(x_np.shape[1])])
        model = self.state.model
        n_outputs = len(model) if isinstance(model, dict) else 1
        out_labels = self._prediction_labels(n_outputs)

        if isinstance(model, dict):
            pred_cols = []
            for label in out_labels:
                label_model = model.get(label)
                if label_model is None:
                    pred_cols.append(np.zeros(len(x), dtype=float))
                else:
                    pred_cols.append(label_model.predict_many(x).to_numpy())
            return np.column_stack(pred_cols)

        first_col = model.predict_many(x).to_numpy()
        if len(out_labels) == 1:
            return first_col[:, None]
        zeros = np.zeros((len(x), len(out_labels) - 1), dtype=float)
        return np.column_stack([first_col, zeros])

    def _process(self, message: AxisArray) -> AxisArray | None:
        """Predict targets for *message*.

        Returns:
            An empty AxisArray if no model exists yet, ``None`` if the input
            contains NaN, and otherwise a ``(time, ch)``-shaped prediction
            built on the stored template with the input's time axis.
        """
        if self.state.model is None:
            return AxisArray(np.array([]), dims=[""])

        xp = get_namespace(message.data)
        if xp.any(xp.isnan(message.data)):
            return None

        # Same time-major layout that partial_fit trained on.
        x_np = self._as_time_major_numpy(message)
        if self._is_river_model_type():
            preds = self._predict_river(x_np)
        else:
            preds = np.asarray(self.state.model.predict(x_np))
        preds = preds.reshape((len(preds), -1))

        template = self.state.template
        if template is None:
            template = _prediction_template_from_signal(message, self._prediction_labels(preds.shape[1]))
            self.state.template = template
        return replace(
            template,
            data=preds,
            axes={
                **template.axes,
                "time": message.axes["time"],
            },
        )
[docs]
class AdaptiveLinearRegressorUnit(
    BaseAdaptiveTransformerUnit[
        AdaptiveLinearRegressorSettings,
        AxisArray,
        AxisArray,
        AdaptiveLinearRegressorTransformer,
    ]
):
    """ezmsg Unit wrapping the adaptive linear regressor transformer."""

    # Settings class used to configure this unit.
    SETTINGS = AdaptiveLinearRegressorSettings