ezmsg.learn.process.sklearn

Classes

class SklearnModelProcessor(*args, **kwargs)

Bases: BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]

Processor that wraps a scikit-learn, River, or hmmlearn model for use in the ezmsg framework.

This processor supports:

  • fit, partial_fit, or River’s learn_many/learn_one for training.

  • predict, or River’s predict_many/predict_one for inference.

  • Optional model checkpoint loading and saving.

The processor expects and outputs AxisArray messages with a “ch” (channel) axis.
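scikit-learn estimators take 2-D input shaped (n_samples, n_features). One plausible way to map an AxisArray's data onto that layout is to treat the “ch” axis as the feature axis and flatten every other axis into samples. This is an assumption for illustration, not the processor's documented behavior. The helper `to_samples_by_features` is hypothetical, and it operates on a plain NumPy array plus a list of dimension names rather than on an AxisArray.

```python
from __future__ import annotations

import numpy as np


def to_samples_by_features(data: np.ndarray, dims: list[str], ch_dim: str = "ch") -> np.ndarray:
    """Hypothetical: reshape an N-D array into the (n_samples, n_features) layout.

    The channel axis becomes the trailing feature axis; all remaining axes are
    flattened, in order, into the sample axis.
    """
    ch_idx = dims.index(ch_dim)  # raises ValueError if there is no channel axis
    moved = np.moveaxis(data, ch_idx, -1)
    return moved.reshape(-1, moved.shape[-1])


# 3 channels x 5 time points, with the channel axis first.
data = np.arange(15).reshape(3, 5)
X = to_samples_by_features(data, ["ch", "time"])
print(X.shape)        # (5, 3)
print(X[0].tolist())  # [0, 5, 10]
```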

Settings:

model_class : str

Full path to the sklearn or River model class to use. Example: “sklearn.linear_model.SGDClassifier” or “river.linear_model.LogisticRegression”

model_kwargs : dict[str, typing.Any], optional

Additional keyword arguments passed to the model constructor.

checkpoint_path : str, optional

Path to a pickle file to load a previously saved model. If provided, the model will be restored from this path at startup.

partial_fit_classes : np.ndarray, optional

The complete set of class labels, for classifiers that require all labels to be declared when calling partial_fit.

Example:

```python
import numpy as np

from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings

processor = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)
```

save_checkpoint(path)
Return type:

None

Parameters:

path (str)

load_checkpoint(path)
Return type:

None

Parameters:

path (str)
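The checkpoint_path setting describes checkpoints as pickle files. Assuming save_checkpoint and load_checkpoint use the standard pickle module, a save/load round trip looks roughly like the sketch below. `save_model` and `load_model` are hypothetical stand-ins, not the processor's methods, and a plain dict substitutes for a fitted model.

```python
import os
import pickle
import tempfile
from typing import Any


def save_model(model: Any, path: str) -> None:
    """Hypothetical stand-in for save_checkpoint: serialize the model with pickle."""
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path: str) -> Any:
    """Hypothetical stand-in for load_checkpoint: deserialize a pickled model.

    Only load files you trust: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# A plain dict stands in for a fitted estimator.
fitted = {"coef": [0.5, -1.25], "intercept": 0.1}
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "checkpoint.pkl")
    save_model(fitted, ckpt)
    restored = load_model(ckpt)
print(restored == fitted)  # True
```

Because pickle can run arbitrary code during loading, checkpoint_path should only point at files produced by a trusted source.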

partial_fit(message)
Return type:

None

Parameters:

message (SampleMessage)

fit(X, y)
Return type:

None

Parameters:
  • X

  • y

class SklearnModelSettings(model_class, model_kwargs=None, checkpoint_path=None, partial_fit_classes=None)

Bases: Settings

Parameters:
  • model_class (str)

  • model_kwargs (dict[str, Any])

  • checkpoint_path (str | None)

  • partial_fit_classes (ndarray | None)

model_class: str

Full import path to the model class. Example: 'sklearn.linear_model.LinearRegression'
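A full dotted path such as 'sklearn.linear_model.LinearRegression' splits into a module to import and a class name to look up. The sketch below shows that resolution step with importlib; `resolve_class` is a hypothetical helper, and whether SklearnModelProcessor resolves model_class exactly this way is an assumption.

```python
import collections
import importlib


def resolve_class(path: str) -> type:
    """Hypothetical: import and return the class named by a full dotted path."""
    module_name, _, class_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'package.module.ClassName', got {path!r}")
    module = importlib.import_module(module_name)  # e.g. sklearn.linear_model
    return getattr(module, class_name)  # raises AttributeError if the class is missing


# A standard-library class keeps the example free of third-party dependencies.
print(resolve_class("collections.OrderedDict") is collections.OrderedDict)  # True
```

With scikit-learn installed, resolve_class("sklearn.linear_model.LinearRegression") would return the estimator class in the same way.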

model_kwargs: dict[str, Any] = None

Additional keyword arguments to pass to the model constructor. Example: {'fit_intercept': True}
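Because model_kwargs defaults to None rather than an empty dict, code that builds the model has to treat None as "no extra arguments". The sketch below shows that pattern; `build_model` and `DummyRegressor` are hypothetical names, and the stub stands in for a real estimator class.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


def build_model(model_cls: type, model_kwargs: dict[str, Any] | None = None) -> Any:
    """Hypothetical: instantiate model_cls, treating model_kwargs=None as no extra arguments."""
    return model_cls(**(model_kwargs or {}))


@dataclass
class DummyRegressor:
    # Stand-in for an estimator class such as sklearn.linear_model.LinearRegression.
    fit_intercept: bool = True
    alpha: float = 1.0


print(build_model(DummyRegressor))                  # DummyRegressor(fit_intercept=True, alpha=1.0)
print(build_model(DummyRegressor, {"alpha": 0.5}))  # DummyRegressor(fit_intercept=True, alpha=0.5)
```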

checkpoint_path: str | None = None

Path to a checkpoint file to load the model from. If provided, the model will be initialized from this checkpoint. Example: 'path/to/checkpoint.pkl'

partial_fit_classes: ndarray | None = None

The full set of class labels to pass to partial_fit. It must be provided to use partial_fit with classifiers, because a single batch may not contain every class.
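scikit-learn's incremental classifiers, such as SGDClassifier, require the classes argument on the first call to partial_fit and allow it to be omitted on later calls. The sketch below shows a wrapper that forwards the label set only on that first call. `IncrementalTrainer` and the `RecordingModel` stub are hypothetical, and whether the processor forwards partial_fit_classes exactly this way is an assumption.

```python
from __future__ import annotations

from typing import Any, Sequence


class IncrementalTrainer:
    """Hypothetical wrapper: forward `classes` only on the first partial_fit call."""

    def __init__(self, model: Any, classes: Sequence[Any] | None = None) -> None:
        self.model = model
        self.classes = classes
        self._first_call = True

    def update(self, X: Any, y: Any) -> None:
        if self._first_call and self.classes is not None:
            # The first batch may contain only some labels, so declare them all now.
            self.model.partial_fit(X, y, classes=self.classes)
        else:
            self.model.partial_fit(X, y)
        self._first_call = False


class RecordingModel:
    """Stub that records the keyword arguments of every partial_fit call."""

    def __init__(self) -> None:
        self.calls: list[dict] = []

    def partial_fit(self, X: Any, y: Any, **kwargs: Any) -> None:
        self.calls.append(kwargs)


trainer = IncrementalTrainer(RecordingModel(), classes=[0, 1])
trainer.update([[0.2]], [0])  # first batch only contains label 0
trainer.update([[0.9]], [1])
print(trainer.model.calls)  # [{'classes': [0, 1]}, {}]
```

With a real classifier, RecordingModel would be replaced by an estimator such as SGDClassifier(loss="log_loss").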

__init__(model_class, model_kwargs=None, checkpoint_path=None, partial_fit_classes=None)
Parameters:
  • model_class (str)

  • model_kwargs (dict[str, Any])

  • checkpoint_path (str | None)

  • partial_fit_classes (ndarray | None)

Return type:

None

class SklearnModelState

Bases: object

model: Any = None

chan_ax: CoordinateAxis | None = None

class SklearnModelUnit(*args, settings=None, **kwargs)

Bases: BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]

Unit wrapper for the SklearnModelProcessor.

This unit provides a plug-and-play interface for using a scikit-learn or River model in an ezmsg graph-based system. It takes in AxisArray inputs and outputs predictions in the same format, optionally performing training via partial_fit or fit.

Example:

```python
import numpy as np

from ezmsg.learn.process.sklearn import SklearnModelSettings, SklearnModelUnit

unit = SklearnModelUnit(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)
```

SETTINGS

alias of SklearnModelSettings

Parameters:

settings (Settings | None)
