Source code for ezmsg.learn.process.sgd
import typing
import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import SGDClassifier
from ..util import ClassifierMessage
[docs]
class SGDDecoderSettings(ez.Settings):
    """Configuration for the SGD-based online decoder.

    When ``settings_path`` is set, the classifier is unpickled from that file.
    In that case only ``eta0`` and (if given) ``label_weights`` are applied to it.
    """

    # Regularisation constant passed to a newly built SGDClassifier (``alpha``).
    alpha: float = 1.0e-5
    # Initial learning rate; applied to new and loaded models alike.
    eta0: float = 3.0e-4
    # Loss name for a newly built SGDClassifier.
    loss: str = "hinge"
    # Per-label class weights; the keys also declare the label set on first training.
    label_weights: typing.Optional[dict[str, float]] = None
    # Optional path to a pickled classifier to load instead of building one.
    settings_path: typing.Optional[str] = None
[docs]
class SGDDecoderTransformer(
    BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]
):
    """
    SGD-based online classifier.

    Wraps :class:`sklearn.linear_model.SGDClassifier` (elastic-net penalty,
    adaptive learning rate). The model is trained incrementally through
    :meth:`partial_fit` and queried for class probabilities on each processed
    message.

    Related reading -- Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
    """

    def _refreshed_model(self):
        """Build the classifier: unpickle it from ``settings_path`` or create a new one.

        For a loaded model, only ``eta0`` and (if given) ``label_weights`` are
        applied from the settings. ``loss`` and ``alpha`` come from the pickled
        model itself.
        """
        if self.settings.settings_path is not None:
            import pickle

            # NOTE(security): pickle.load can execute arbitrary code; only point
            # settings_path at model files from a trusted source.
            with open(self.settings.settings_path, "rb") as f:
                model = pickle.load(f)
            if self.settings.label_weights is not None:
                model.class_weight = self.settings.label_weights
            model.eta0 = self.settings.eta0
        else:
            model = SGDClassifier(
                loss=self.settings.loss,
                alpha=self.settings.alpha,
                penalty="elasticnet",
                learning_rate="adaptive",
                eta0=self.settings.eta0,
                early_stopping=False,
                class_weight=self.settings.label_weights,
            )
        return model

    def _reset_state(self, message: AxisArray) -> None:
        """Replace the current model with a freshly built one."""
        self._state.model = self._refreshed_model()
        # The replacement model has not been trained by this transformer, so the
        # next partial_fit must declare the label set again via ``classes``.
        self._state.b_first_train = True

    def _process(self, message: AxisArray) -> ClassifierMessage | None:
        """Return per-class probabilities for ``message``.

        The first dimension is treated as the sample axis, and all remaining
        dimensions are flattened into features.

        Returns None when no model exists, the data are empty or contain NaN,
        or the model has not been fitted yet.
        """
        model = self._state.model
        if model is None or not message.data.size:
            return None
        if np.any(np.isnan(message.data)):
            return None
        X = message.data.reshape((message.data.shape[0], -1))
        try:
            result = model._predict_proba_lr(X)
        except NotFittedError:
            return None
        out_axes = {}
        if message.dims[0] in message.axes:
            out_axes[message.dims[0]] = replace(
                message.axes[message.dims[0]],
                offset=message.axes[message.dims[0]].offset,
            )
        return ClassifierMessage(
            data=result,
            dims=message.dims[:1] + ["labels"],
            axes=out_axes,
            # Probability columns follow ``model.classes_`` (sorted by sklearn),
            # not the insertion order of ``label_weights``; class_weight may also
            # be None, which made the old ``.keys()`` lookup crash.
            labels=model.classes_.tolist(),
            key=message.key,
        )

    def _training_classes(self) -> list:
        """Return the complete label set to declare on the first ``partial_fit`` call.

        Raises:
            ValueError: if ``label_weights`` is None and the model has no fitted
                ``classes_`` to fall back on.
        """
        if self.settings.label_weights is not None:
            return list(self.settings.label_weights.keys())
        fitted_classes = getattr(self._state.model, "classes_", None)
        if fitted_classes is not None:
            # e.g. an already-fitted model loaded from settings_path
            return list(fitted_classes)
        raise ValueError(
            "SGDDecoderSettings.label_weights must list the class labels "
            "before training an unfitted model."
        )

    def partial_fit(self, message: AxisArray) -> None:
        """Update the classifier with one labelled sample.

        ``message.data`` is flattened into a single sample. Its label is read
        from ``message.attrs["trigger"].value``. Samples containing NaN are
        skipped.

        Raises:
            ValueError: on the first training call when no label set can be
                determined (see ``_training_classes``).
        """
        if self._hash != 0:
            # NOTE(review): presumably means the base class has not initialised
            # state yet; 0 mirrors its steady-state message hash -- TODO confirm
            # against BaseAdaptiveTransformer.
            self._reset_state(message)
            self._hash = 0
        if np.any(np.isnan(message.data)):
            return
        train_sample = message.data.reshape(1, -1)
        label = [message.attrs["trigger"].value]
        if self._state.b_first_train:
            # sklearn requires the full class list on the first partial_fit call.
            self._state.model.partial_fit(
                train_sample,
                label,
                classes=self._training_classes(),
            )
            self._state.b_first_train = False
        else:
            self._state.model.partial_fit(train_sample, label)
[docs]
class SGDDecoder(
    BaseAdaptiveTransformerUnit[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderTransformer]
):
    """ezmsg Unit that hosts an :class:`SGDDecoderTransformer` configured by :class:`SGDDecoderSettings`."""

    SETTINGS = SGDDecoderSettings