Source code for ezmsg.learn.process.sgd
import typing
import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import GenAxisArray
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.generator import consumer
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import SGDClassifier
from ..util import ClassifierMessage
[docs]
@consumer
def sgd_decoder(
alpha: float = 1.5e-5,
eta0: float = 1e-7, # Lower than what you'd use for offline training.
loss: str = "squared_hinge",
label_weights: dict[str, float] | None = None,
settings_path: str | None = None,
) -> typing.Generator[AxisArray | SampleMessage, ClassifierMessage | None, None]:
"""
Passive Aggressive Classifier
Online Passive-Aggressive Algorithms <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
Args:
alpha: Maximum step size (regularization)
eta0: The initial learning rate for the 'adaptive’ schedules.
loss: The loss function to be used:
hinge: equivalent to PA-I in the reference paper.
squared_hinge: equivalent to PA-II in the reference paper.
label_weights: An optional dictionary of label names and their relative weight.
e.g., {'Go': 31.0, 'Stop': 0.5}
If this is None then settings_path must be provided and the pre-trained model
settings_path: Path to the stored sklearn model pkl file.
Returns:
Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
or `AxisArray` for inference (`predict`) and yields a `ClassifierMessage`.
"""
# pre-init inputs and outputs
msg_out = ClassifierMessage(data=np.array([]), dims=[""])
# State variables:
if settings_path is not None:
import pickle
with open(settings_path, "rb") as f:
model = pickle.load(f)
if label_weights is not None:
model.class_weight = label_weights
# Overwrite eta0, probably with a value lower than what was used online.
model.eta0 = eta0
else:
model = SGDClassifier(
loss=loss,
alpha=alpha,
penalty="elasticnet",
learning_rate="adaptive",
eta0=eta0,
early_stopping=False,
class_weight=label_weights,
)
b_first_train = True
# TODO: template_out
while True:
msg_in: AxisArray | SampleMessage = yield msg_out
msg_out = None
if type(msg_in) is SampleMessage:
# SampleMessage used for training.
if not np.any(np.isnan(msg_in.sample.data)):
train_sample = msg_in.sample.data.reshape(1, -1)
if b_first_train:
model.partial_fit(
train_sample,
[msg_in.trigger.value],
classes=list(label_weights.keys()),
)
b_first_train = False
else:
model.partial_fit(train_sample, [msg_in.trigger.value])
elif msg_in.data.size:
# AxisArray used for inference
if not np.any(np.isnan(msg_in.data)):
try:
X = msg_in.data.reshape((msg_in.data.shape[0], -1))
result = model._predict_proba_lr(X)
except NotFittedError:
result = None
if result is not None:
out_axes = {}
if msg_in.dims[0] in msg_in.axes:
out_axes[msg_in.dims[0]] = replace(
msg_in.axes[msg_in.dims[0]],
offset=msg_in.axes[msg_in.dims[0]].offset,
)
msg_out = ClassifierMessage(
data=result,
dims=msg_in.dims[:1] + ["labels"],
axes=out_axes,
labels=list(model.class_weight.keys()),
key=msg_in.key,
)
[docs]
class SGDDecoderSettings(ez.Settings):
alpha: float = 1e-5
eta0: float = 3e-4
loss: str = "hinge"
label_weights: dict[str, float] | None = None
settings_path: str | None = None
[docs]
class SGDDecoder(GenAxisArray):
SETTINGS = SGDDecoderSettings
INPUT_SAMPLE = ez.InputStream(SampleMessage)
# Method to be implemented by subclasses to construct the specific generator
[docs]
def construct_generator(self):
self.STATE.gen = sgd_decoder(**self.SETTINGS.__dict__)
[docs]
@ez.subscriber(INPUT_SAMPLE)
async def on_sample(self, msg: SampleMessage) -> None:
_ = self.STATE.gen.send(msg)