ezmsg.learn.process.sklearn
Classes
- class SklearnModelProcessor(*args, **kwargs)
Bases: BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]
Processor that wraps a scikit-learn, River, or HMMLearn model for use in the ezmsg framework.
This processor supports:
- fit, partial_fit, or River’s learn_many/learn_one for training.
- predict, or River’s predict_many/predict_one for inference.
- Optional model checkpoint loading and saving.
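The training and inference options above can be pictured as duck-typed dispatch on whichever methods the wrapped model exposes. The sketch below is an illustration under stated assumptions, not the processor's actual implementation: `train_step`, `infer`, and the toy `RunningMean` model are hypothetical names; River's `learn_many`/`predict_many` (which take pandas objects) are left out to keep the sketch NumPy-only; and the batch fallback assumes a supervised `fit(X, y)` signature, which HMMLearn's unsupervised models do not share.

```python
import numpy as np


def train_step(model, X, y, classes=None):
    """Update `model` with one batch via whichever training method it exposes.

    Hypothetical helper: checks for sklearn-style partial_fit first, then
    River-style learn_one, and falls back to a supervised batch fit(X, y).
    """
    if hasattr(model, "partial_fit"):
        # Incremental sklearn estimators; classifiers need `classes` on the first call
        if classes is not None:
            model.partial_fit(X, y, classes=classes)
        else:
            model.partial_fit(X, y)
    elif hasattr(model, "learn_one"):
        # River models learn from one dict-shaped sample at a time
        for row, target in zip(X, y):
            model.learn_one(dict(enumerate(row)), target)
    else:
        # Batch fallback: refit on the current batch only
        model.fit(X, y)
    return model


def infer(model, X):
    """Predict with sklearn-style predict, else River-style predict_one per row."""
    if hasattr(model, "predict"):
        return np.asarray(model.predict(X))
    return np.asarray([model.predict_one(dict(enumerate(row))) for row in X])


class RunningMean:
    """Toy River-style model: predicts the mean of all targets seen so far."""

    def __init__(self):
        self.n, self.total = 0, 0.0

    def learn_one(self, x, y):
        self.n += 1
        self.total += y

    def predict_one(self, x):
        return self.total / self.n if self.n else 0.0


X = np.array([[0.0, 1.0], [2.0, 3.0]])
y = np.array([1.0, 3.0])
river_like = train_step(RunningMean(), X, y)
print(infer(river_like, X))  # [2. 2.]
```

Checking for `partial_fit` before `fit` means an estimator that offers both is updated incrementally rather than refit from scratch on each batch.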
The processor expects and outputs AxisArray messages with a “ch” (channel) axis.
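scikit-learn estimators expect a 2-D `(n_samples, n_features)` array, so a wrapper has to map the channel axis onto features. A minimal sketch, assuming the message's payload is available as a NumPy array plus a list of axis names like an AxisArray's `dims`; `to_feature_matrix` is a hypothetical helper, and flattening every non-channel axis into the sample axis is an assumption, not necessarily what the processor does.

```python
import numpy as np


def to_feature_matrix(data, dims, ch_dim="ch"):
    """Hypothetical helper: flatten an N-D array into (n_samples, n_channels),
    the 2-D layout scikit-learn estimators expect, with channels as features."""
    ch_axis = dims.index(ch_dim)
    moved = np.moveaxis(data, ch_axis, -1)  # put the channel axis last
    return moved.reshape(-1, moved.shape[-1])  # all other axes become samples


# Example: 3 time samples x 4 channels, stored with dims ("time", "ch")
data = np.arange(12.0).reshape(3, 4)
X = to_feature_matrix(data, ["time", "ch"])
print(X.shape)  # (3, 4)

# The same data stored channels-first still yields one row per time sample
X_t = to_feature_matrix(data.T, ["ch", "time"])
print(np.array_equal(X, X_t))  # True
```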
Settings:
- model_class : str
Full path to the sklearn or River model class to use. Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression".
- model_kwargs : dict[str, typing.Any], optional
Additional keyword arguments passed to the model constructor.
- checkpoint_path : str, optional
Path to a pickle file containing a previously saved model. If provided, the model is restored from this path at startup.
- partial_fit_classes : np.ndarray, optional
The complete set of class labels, for classifiers whose partial_fit must be told every class up front (scikit-learn classifiers require this on the first partial_fit call).
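A dotted `model_class` string such as "sklearn.linear_model.SGDClassifier" can be resolved with the standard `importlib` machinery and then called with `model_kwargs`. This is a hedged sketch of that idea, not the processor's own code: `resolve_class` and `build_model` are hypothetical names, and a standard-library class stands in for a model so the example runs without scikit-learn installed.

```python
import importlib


def resolve_class(path):
    """Hypothetical helper: turn "package.module.ClassName" into the class object."""
    module_name, _, class_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def build_model(model_class, model_kwargs=None):
    """Hypothetical helper: instantiate the named class with optional keyword arguments."""
    cls = resolve_class(model_class)
    return cls(**(model_kwargs or {}))


# A standard-library class keeps the sketch dependency-free;
# "sklearn.linear_model.SGDClassifier" would resolve the same way once installed.
delta = build_model("datetime.timedelta", {"days": 2})
print(delta)  # 2 days, 0:00:00
```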
Example:
```python
import numpy as np
from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings

processor = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)
```
- partial_fit(message)
- Return type:
- Parameters:
message (SampleMessage)
- class SklearnModelSettings(model_class, model_kwargs=None, checkpoint_path=None, partial_fit_classes=None)
Bases: Settings
- Parameters:
- model_class: str
Full path to the scikit-learn or River model class. Example: 'sklearn.linear_model.LinearRegression'
- model_kwargs: dict[str, Any] = None
Additional keyword arguments to pass to the model constructor. Example: {'fit_intercept': True}
- checkpoint_path: str | None = None
Path to a checkpoint file to load the model from. If provided, the model is initialized from this checkpoint. Example: 'path/to/checkpoint.pkl'
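The processor settings describe the checkpoint as a pickle file, so a save/load round trip would look roughly like the sketch below. The helper names are hypothetical and the library's own saving mechanism may differ; a plain dict stands in for a fitted estimator.

```python
import os
import pickle
import tempfile


def save_checkpoint(model, path):
    """Hypothetical helper: serialize a fitted model to a pickle file."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_checkpoint(path):
    """Hypothetical helper: restore a model saved by save_checkpoint.

    Only unpickle files you trust; pickle can execute arbitrary code on load.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Round trip with a plain dict standing in for a fitted estimator
state = {"coef": [0.5, -1.25], "intercept": 0.1}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
save_checkpoint(state, path)
restored = load_checkpoint(path)
print(restored == state)  # True
```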
- class SklearnModelUnit(*args, settings=None, **kwargs)
Bases: BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]
Unit wrapper for the SklearnModelProcessor.
This unit provides a plug-and-play interface for using a scikit-learn or River model in an ezmsg graph. It accepts AxisArray inputs, emits predictions as AxisArray messages, and can optionally train the model via partial_fit or fit.
Example:
```python
import numpy as np
from ezmsg.learn.process.sklearn import SklearnModelSettings, SklearnModelUnit

unit = SklearnModelUnit(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)
```
- SETTINGS
alias of SklearnModelSettings
- Parameters:
settings (Settings | None)