Source code for ezmsg.learn.util

from enum import Enum
from dataclasses import dataclass, field
import typing

from ezmsg.util.messages.axisarray import AxisArray
import sklearn.linear_model
import river.linear_model
# from sklearn.neural_network import MLPClassifier


class RegressorType(str, Enum):
    """Regressor category understood by ``get_regressor``.

    Members subclass ``str``, so each one compares equal to its string value
    and can be built from it, e.g. ``RegressorType("static")``.
    """

    ADAPTIVE = "adaptive"  # resolved via ADAPTIVE_REGRESSORS
    STATIC = "static"  # resolved via STATIC_REGRESSORS
class AdaptiveLinearRegressor(str, Enum):
    """Names of the adaptive regressors registered in ``ADAPTIVE_REGRESSORS``.

    Used by ``get_regressor`` when the category is ``RegressorType.ADAPTIVE``;
    a plain string is converted to a member by value.
    """

    LINEAR = "linear"  # river.linear_model.LinearRegression
    LOGISTIC = "logistic"  # river.linear_model.LogisticRegression
    SGD = "sgd"  # sklearn.linear_model.SGDRegressor
    PAR = "par"  # passive-aggressive: sklearn.linear_model.PassiveAggressiveRegressor
    # MLP = "mlp"  # disabled, together with the MLPClassifier import and registry entry
class StaticLinearRegressor(str, Enum):
    """Names of the static regressors registered in ``STATIC_REGRESSORS``.

    Used by ``get_regressor`` when the category is ``RegressorType.STATIC``;
    a plain string is converted to a member by value.
    """

    LINEAR = "linear"  # sklearn.linear_model.LinearRegression
    RIDGE = "ridge"  # sklearn.linear_model.Ridge
# Registry of adaptive regressor *classes* (not instances), keyed by
# AdaptiveLinearRegressor member. Mixes river and scikit-learn estimators.
ADAPTIVE_REGRESSORS = {
    AdaptiveLinearRegressor.LINEAR: river.linear_model.LinearRegression,
    AdaptiveLinearRegressor.LOGISTIC: river.linear_model.LogisticRegression,
    AdaptiveLinearRegressor.SGD: sklearn.linear_model.SGDRegressor,
    AdaptiveLinearRegressor.PAR: sklearn.linear_model.PassiveAggressiveRegressor,
    # AdaptiveLinearRegressor.MLP: MLPClassifier,  # disabled along with the enum member
}

# Function to get a regressor by type and name
def get_regressor(
    regressor_type: typing.Union[RegressorType, str],
    regressor_name: typing.Union[AdaptiveLinearRegressor, StaticLinearRegressor, str],
):
    """Look up a regressor class by category and name.

    Args:
        regressor_type: A ``RegressorType`` member, or its string value
            (``"adaptive"`` or ``"static"``).
        regressor_name: A member of ``AdaptiveLinearRegressor`` (adaptive
            category) or ``StaticLinearRegressor`` (static category), or that
            member's string value.

    Returns:
        The estimator class stored under that name in ``ADAPTIVE_REGRESSORS``
        or ``STATIC_REGRESSORS``. The class itself is returned; instantiating
        it is up to the caller.

    Raises:
        ValueError: If ``regressor_type`` is a string that is not a valid
            ``RegressorType`` value, if ``regressor_name`` is not a valid value
            for the selected category, or if ``regressor_type`` is some other
            object equal to neither category.
    """
    # Normalise the category; RegressorType(...) raises ValueError for unknown strings.
    if isinstance(regressor_type, str):
        regressor_type = RegressorType(regressor_type)

    # Choose the name enum and registry for this category. Both registries are
    # module globals resolved when the function runs, so STATIC_REGRESSORS may
    # be defined after this function in the module.
    if regressor_type == RegressorType.ADAPTIVE:
        name_enum, registry = AdaptiveLinearRegressor, ADAPTIVE_REGRESSORS
    elif regressor_type == RegressorType.STATIC:
        name_enum, registry = StaticLinearRegressor, STATIC_REGRESSORS
    else:
        raise ValueError(f"Unknown regressor type: {regressor_type}")

    # str-mixin enum members are str instances too, so enum members are also
    # passed through name_enum(...), which looks them up by value.
    key = name_enum(regressor_name) if isinstance(regressor_name, str) else regressor_name
    return registry[key]
# Registry of static (scikit-learn) regressor *classes*, keyed by
# StaticLinearRegressor member.
STATIC_REGRESSORS = {
    StaticLinearRegressor.LINEAR: sklearn.linear_model.LinearRegression,
    StaticLinearRegressor.RIDGE: sklearn.linear_model.Ridge,
}
@dataclass
class ClassifierMessage(AxisArray):
    """An ``AxisArray`` that additionally carries a list of label strings.

    All array-related fields are inherited unchanged from ``AxisArray``; this
    subclass only adds ``labels``.
    """

    # Label strings, presumably one per classifier output/class
    # (NOTE(review): confirm against the code that builds these messages).
    # default_factory gives every instance its own fresh empty list.
    labels: list[str] = field(default_factory=list)