Real-Time Classification
This guide shows how to use ezmsg-learn for real-time classification in streaming pipelines.
Overview
ezmsg-learn provides machine learning components that integrate with ezmsg pipelines. Key features include:
- Pre-trained models: Load and apply existing classifiers
- Online learning: Update models incrementally with streaming data
- Flexible backends: Support for scikit-learn, PyTorch, and River models
Available Classifiers
ezmsg-learn includes several classifier types:

| Classifier | Description | Use Case |
|---|---|---|
| SLDA | Shrinkage Linear Discriminant Analysis | BCI, small datasets |
| SklearnModelUnit | Wrapper for any scikit-learn model | General ML tasks |
|  | Stochastic Gradient Descent | Online learning |
|  | Multi-layer Perceptron (PyTorch) | Complex patterns |
Using a Pre-Trained SLDA Classifier
The simplest approach is to use a pre-trained model:
```python
from ezmsg.learn.process.slda import SLDA, SLDASettings

classifier = SLDA(
    SLDASettings(
        settings_path="path/to/trained_model.pkl",
        axis="time",  # Axis containing samples
    )
)
```
Input format: AxisArray[time, features] where features are flattened from your pipeline.
Output format: ClassifierMessage[time, classes] with class probabilities.
Training an SLDA model (offline):
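```python
import pickle

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

# Train offline with your data. The random arrays below are placeholders;
# replace them with your own features and labels.
rng = np.random.default_rng(0)
X_train = rng.standard_normal((200, 16))  # shape: (n_samples, n_features)
y_train = rng.integers(0, 2, size=200)  # shape: (n_samples,)

lda = LDA(solver="lsqr", shrinkage="auto")
lda.fit(X_train, y_train)

# Save for use in ezmsg
with open("trained_model.pkl", "wb") as f:
    pickle.dump(lda, f)
```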
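Before pointing SLDASettings at the file, it can help to reload the pickle and check what the model expects and returns. The sketch below is plain scikit-learn and uses synthetic data as a stand-in for real recordings (an assumption). predict_proba returns one probability per class for each row, which is the kind of per-class score the classifier output is described as carrying.

```python
import pickle

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

# Synthetic stand-in for real features: 200 samples x 16 features, 2 classes
rng = np.random.default_rng(0)
y_train = rng.integers(0, 2, size=200)
X_train = rng.standard_normal((200, 16)) + y_train[:, None] * 0.8

lda = LDA(solver="lsqr", shrinkage="auto")
lda.fit(X_train, y_train)
with open("trained_model.pkl", "wb") as f:
    pickle.dump(lda, f)

# Reload the model the way a downstream consumer would
with open("trained_model.pkl", "rb") as f:
    loaded = pickle.load(f)

print(loaded.n_features_in_)  # number of features the model expects
print(loaded.classes_)  # class labels seen during training

# One streaming "sample" is one row of 16 features
proba = loaded.predict_proba(X_train[:1])
print(proba.shape)  # one row, one column per class
```

If n_features_in_ does not match the length of the flattened feature vector your pipeline produces, prediction will fail with a shape error.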
Using Scikit-Learn Models
SklearnModelUnit wraps any scikit-learn compatible model:
```python
import numpy as np
from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings

classifier = SklearnModelUnit(
    SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={
            "loss": "log_loss",  # For probability outputs
            "warm_start": True,
        },
        partial_fit_classes=np.array([0, 1]),  # Required for online learning
    )
)
```
Loading a pre-trained model:
```python
from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings

classifier = SklearnModelUnit(
    SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        checkpoint_path="path/to/saved_model.pkl",
    )
)
```
Online Learning
For models that support partial_fit, you can update them during streaming:
```python
import numpy as np
from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
from ezmsg.baseproc import SampleTriggerMessage
from ezmsg.util.messages.util import replace

# Create processor with online learning support
processor = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)

# Training with labeled samples.
# feature_array (an AxisArray of features) and label_value (its class label)
# are placeholders supplied by your pipeline.
sample_msg = replace(
    feature_array,
    attrs={"trigger": SampleTriggerMessage(value=label_value)},
)
processor.partial_fit(sample_msg)

# Prediction (after training); input_features is another feature AxisArray
prediction = processor(input_features)
```
Complete Pipeline Example
Here’s a complete BCI classification pipeline:
```python
import ezmsg.core as ez
from ezmsg.lsl.inlet import LSLInletUnit, LSLInletSettings, LSLInfo
from ezmsg.lsl.outlet import LSLOutletUnit, LSLOutletSettings
from ezmsg.sigproc.butterworthfilter import ButterworthFilter, ButterworthFilterSettings
from ezmsg.sigproc.window import Window, WindowSettings
from ezmsg.sigproc.spectrum import Spectrum, SpectrumSettings
from ezmsg.sigproc.aggregate import RangedAggregate, RangedAggregateSettings, AggregationFunction
from ezmsg.learn.process.slda import SLDA, SLDASettings

components = {
    # Data acquisition
    "LSL_IN": LSLInletUnit(
        LSLInletSettings(info=LSLInfo(name="EEG", type="EEG"))
    ),
    # Signal processing
    "FILTER": ButterworthFilter(
        ButterworthFilterSettings(order=4, cuton=8.0, cutoff=30.0)
    ),
    "WINDOW": Window(
        WindowSettings(window_dur=1.0, window_shift=0.5)
    ),
    "SPECTRUM": Spectrum(SpectrumSettings(window="hann")),
    "BANDPOWER": RangedAggregate(
        RangedAggregateSettings(
            axis="freq",
            bands=[(8.0, 12.0), (18.0, 25.0)],
            operation=AggregationFunction.MEAN,
        )
    ),
    # Classification
    "CLASSIFIER": SLDA(
        SLDASettings(settings_path="model.pkl", axis="time")
    ),
    # Output
    "LSL_OUT": LSLOutletUnit(
        LSLOutletSettings(stream_name="Predictions", stream_type="Markers")
    ),
}

connections = (
    (components["LSL_IN"].OUTPUT_SIGNAL, components["FILTER"].INPUT_SIGNAL),
    (components["FILTER"].OUTPUT_SIGNAL, components["WINDOW"].INPUT_SIGNAL),
    (components["WINDOW"].OUTPUT_SIGNAL, components["SPECTRUM"].INPUT_SIGNAL),
    (components["SPECTRUM"].OUTPUT_SIGNAL, components["BANDPOWER"].INPUT_SIGNAL),
    (components["BANDPOWER"].OUTPUT_SIGNAL, components["CLASSIFIER"].INPUT_SIGNAL),
    (components["CLASSIFIER"].OUTPUT_SIGNAL, components["LSL_OUT"].INPUT_SIGNAL),
)

if __name__ == "__main__":
    ez.run(components=components, connections=connections)
```
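To see roughly what reaches CLASSIFIER, the sketch below mimics the WINDOW, SPECTRUM, and BANDPOWER stages on a single 1 s window using NumPy. The 250 Hz sampling rate and 8 channels are assumptions, and the Hann-tapered power spectrum and inclusive band mean are simplifications, not the exact computations of the ezmsg-sigproc units. A 1.0 s window gives 1 Hz frequency bins, and with window_shift=0.5 a new feature vector is produced every 0.5 s.

```python
import numpy as np

fs = 250.0  # assumed sampling rate (Hz)
n_ch = 8  # assumed channel count
window_dur = 1.0  # matches WindowSettings(window_dur=1.0)
bands = [(8.0, 12.0), (18.0, 25.0)]  # matches RangedAggregateSettings

rng = np.random.default_rng(2)
n_samp = int(fs * window_dur)
window = rng.standard_normal((n_samp, n_ch))  # [time, ch] samples in one window

# Hann-tapered power spectrum along the time axis -> [freq, ch]
taper = np.hanning(n_samp)[:, None]
power = np.abs(np.fft.rfft(window * taper, axis=0)) ** 2
freqs = np.fft.rfftfreq(n_samp, d=1.0 / fs)
print(freqs[1] - freqs[0])  # bin spacing = 1 / window_dur = 1.0 Hz

# Mean power inside each band (edges inclusive) -> [band, ch]
bandpower = np.stack(
    [power[(freqs >= lo) & (freqs <= hi)].mean(axis=0) for lo, hi in bands]
)

# Add a length-1 time axis, then flatten band x ch into features -> [1, 16]
features = bandpower[None, ...].reshape(1, -1)
print(features.shape)
```

The resulting 16-element vector per window is what the offline LDA model must have been trained on, so n_features_in_ of model.pkl should also be 16 for this pipeline.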
Feature Preparation
Classifiers expect flattened 2D input [samples, features]. Multi-dimensional inputs are flattened automatically: every dimension other than the sample axis (for example, bands and channels) is collapsed into a single feature dimension.
For example, if your bandpower output is [time=1, band=2, ch=8]:
- The classifier receives shape [1, 16] (2 bands × 8 channels)
- Features are flattened in C-order (row-major)
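With C-order flattening the last dimension (ch) varies fastest, so in this example feature k comes from band k // 8 and channel k % 8. A quick NumPy check, where the array values are arbitrary placeholders:

```python
import numpy as np

n_band, n_ch = 2, 8
# Placeholder bandpower output shaped [time=1, band=2, ch=8]
bandpower = np.arange(n_band * n_ch, dtype=float).reshape(1, n_band, n_ch)

# Keep the time axis, collapse band x ch into one feature axis (C-order)
features = bandpower.reshape(bandpower.shape[0], -1)
print(features.shape)  # (1, 16)

# Feature k corresponds to band k // n_ch, channel k % n_ch
k = 11
band, ch = divmod(k, n_ch)
assert features[0, k] == bandpower[0, band, ch]
print(band, ch)  # 1 3
```

When building X_train offline, using the same reshape keeps column k tied to the same band and channel that the online classifier sees.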
Output Format
Classification outputs use ClassifierMessage, which extends AxisArray with:
- dims: ["time", "classes"]
- data: Probability scores for each class
- labels: List of class names/identifiers
Example output shape: [time=1, classes=2] with probabilities for each class.
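Downstream code typically turns these probabilities into a decision by taking the most probable class along the classes axis. The sketch below uses a plain NumPy array and a list as stand-ins for a ClassifierMessage's data and labels; the label names and the 0.6 confidence threshold are hypothetical.

```python
import numpy as np

# Stand-ins for ClassifierMessage data ([time=1, classes=2]) and labels
data = np.array([[0.27, 0.73]])
labels = ["left", "right"]  # hypothetical class names

best = np.argmax(data, axis=1)  # index of the most probable class per time step
confidence = data[np.arange(data.shape[0]), best]

threshold = 0.6  # hypothetical minimum confidence before acting on a prediction
for idx, conf in zip(best, confidence):
    decision = labels[idx] if conf >= threshold else "undecided"
    print(decision, conf)  # right 0.73
```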
Tips for Better Performance
- Normalize features: Use Scaler from ezmsg-sigproc before classification:

  ```python
  from ezmsg.sigproc.scaler import Scaler, ScalerSettings

  scaler = Scaler(ScalerSettings(mode="zscore"))
  ```
- Match training conditions: Ensure online features go through the same preprocessing as the offline training data
- Window size: Larger windows give more stable features but add latency
- Feature selection: Start with frequency bands known to be relevant to your application, such as the 8-12 Hz and 18-25 Hz bands used in the pipeline example
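The normalization and matching-conditions tips go together: whatever statistics standardize the features offline should also be applied online. Below is a minimal NumPy sketch of fixed z-scoring, assuming per-feature mean and standard deviation are estimated once from training data; it is a plain stand-in, not the Scaler unit's own implementation, and the data is a placeholder.

```python
import numpy as np

rng = np.random.default_rng(3)
# Placeholder offline training features: 500 samples x 16 features
X_train = rng.normal(loc=5.0, scale=2.0, size=(500, 16))

# Estimate statistics once, offline, and keep them alongside the model
mu = X_train.mean(axis=0)
sigma = X_train.std(axis=0)
sigma[sigma == 0] = 1.0  # avoid division by zero for constant features


def zscore(x: np.ndarray) -> np.ndarray:
    """Standardize features using the *training* statistics."""
    return (x - mu) / sigma


X_train_z = zscore(X_train)  # what the classifier is trained on
x_online = rng.normal(loc=5.0, scale=2.0, size=(1, 16))
x_online_z = zscore(x_online)  # identical transform at prediction time
print(X_train_z.mean(axis=0).round(6), X_train_z.std(axis=0).round(6))
```

Recomputing mu and sigma from online data alone would shift the feature distribution relative to what the model was trained on, which is one common source of degraded real-time accuracy.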
Troubleshooting
- “Model has not been fit yet”:
  The model needs training data before prediction. Either:
  - Provide a checkpoint_path with a pre-trained model
  - Call fit() or partial_fit() before processing
- Shape mismatch errors:
  - Verify that the input feature dimensions match the trained model
  - Check the n_features_in_ attribute of loaded models
- NaN in predictions:
  - Ensure input features don't contain NaN values
  - Check for numerical stability in preprocessing
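Several of these failures can be caught before data reaches the classifier. The helper below is hypothetical (it is not part of ezmsg-learn): using scikit-learn utilities, it checks that the model has been fitted, that the feature count matches n_features_in_, and that the input contains no NaN values. The random data is a placeholder.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


def check_features(model, X: np.ndarray) -> None:
    """Hypothetical pre-flight check for a [samples, features] array."""
    check_is_fitted(model)  # raises NotFittedError if the model was never fit
    if X.ndim != 2:
        raise ValueError(f"expected 2D [samples, features], got shape {X.shape}")
    if X.shape[1] != model.n_features_in_:
        raise ValueError(
            f"model expects {model.n_features_in_} features, got {X.shape[1]}"
        )
    if np.isnan(X).any():
        raise ValueError("input features contain NaN")


rng = np.random.default_rng(4)
X = rng.standard_normal((40, 16))
y = np.repeat([0, 1], 20)

model = LinearDiscriminantAnalysis()
try:
    check_features(model, X)
except NotFittedError as err:
    print("not fitted:", type(err).__name__)

model.fit(X, y)
check_features(model, X[:1])  # passes: 16 features, no NaN
try:
    check_features(model, X[:1, :8])  # wrong feature count
except ValueError as err:
    print(err)  # model expects 16 features, got 8
```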