ezmsg.learn.process.torch
Classes
- class TorchModelProcessor(*args, **kwargs)
  Bases: BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState], TorchProcessorMixin, ModelInitMixin

  - partial_fit(message)
    - Parameters:
      message (SampleMessage)
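`partial_fit` takes a single labeled example and performs one adaptation step on the wrapped model. A minimal sketch of an adapt-then-infer loop follows; note the assumptions: the processor is constructed from its settings and is callable on `AxisArray` messages (the usual ezmsg transformer convention), the `SampleMessage` import path is taken from ezmsg-sigproc, and `my_module.MyModelClass` is a hypothetical model.

```python
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage  # assumed import path

from ezmsg.learn.process.torch import TorchModelProcessor, TorchModelSettings

proc = TorchModelProcessor(
    settings=TorchModelSettings(
        model_class="my_module.MyModelClass",  # hypothetical model class
        learning_rate=1e-3,
    )
)

# One labeled training example: a (time, channel) sample plus its target,
# with the target carried in the trigger's value field (assumed convention).
sample = AxisArray(data=np.random.randn(100, 8).astype(np.float32), dims=["time", "ch"])
msg = SampleMessage(
    trigger=SampleTriggerMessage(timestamp=0.0, period=None, value=1),
    sample=sample,
)

proc.partial_fit(msg)  # one adaptation step on the labeled sample
out = proc(sample)     # inference on a plain AxisArray (callable-transformer convention)
```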
- class TorchModelSettings(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999)
  Bases: TorchSimpleSettings

  - loss_fn: Module | dict[str, Module] | None = None
    Loss function(s) for the decoder. If the model has multiple heads, this should be a dictionary mapping head names to loss functions; the keys must match the output head names.
  - loss_weights: dict[str, float] | None = None
    Weights for each loss function when using multiple heads. The keys must match the output head names. If None, or if keys are missing or mismatched, losses are left unweighted.
  - __init__(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999)
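For a multi-head model, `loss_fn` maps head names to loss modules and `loss_weights` scales each term. A hedged sketch using only the fields documented above; the head names and the model class are invented for illustration:

```python
import torch.nn as nn

from ezmsg.learn.process.torch import TorchModelSettings

# Head names "position" and "category" are hypothetical; they must match
# the output head names of the model referenced by model_class.
settings = TorchModelSettings(
    model_class="my_module.MultiHeadModel",  # hypothetical multi-head model
    loss_fn={
        "position": nn.MSELoss(),           # regression head
        "category": nn.CrossEntropyLoss(),  # classification head
    },
    loss_weights={"position": 1.0, "category": 0.5},  # omitted/mismatched keys -> unweighted
    learning_rate=1e-3,
    weight_decay=1e-4,
    scheduler_gamma=0.999,
)
```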
- class TorchModelState
  Bases: TorchSimpleState

  - scheduler: LRScheduler | None = None
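The state extends TorchSimpleState with a learning-rate scheduler slot. Given the `scheduler_gamma` setting, this is presumably a per-step exponential decay; a bare-PyTorch sketch of that arrangement (the AdamW choice and the exact wiring are assumptions, not confirmed by this page):

```python
import torch
from torch.optim.lr_scheduler import ExponentialLR

# Stand-ins: the actual model and optimizer are built by the processor.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Per-step exponential decay: lr_t = lr_0 * gamma ** t, matching scheduler_gamma.
scheduler = ExponentialLR(optimizer, gamma=0.999)

# One adaptation step: backprop a loss, step the optimizer, then decay the lr.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
scheduler.step()
```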
- class TorchModelUnit(*args, settings=None, **kwargs)
  Bases: BaseAdaptiveTransformerUnit[TorchModelSettings, AxisArray, AxisArray, TorchModelProcessor]

  - Parameters:
    settings (Settings | None)
  - SETTINGS
    alias of TorchModelSettings
- class TorchProcessorMixin
  Bases: object

  Mixin with shared functionality for torch processors.
- class TorchSimpleProcessor(*args, **kwargs)
  Bases: BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState], TorchProcessorMixin, ModelInitMixin
- class TorchSimpleSettings(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None)
  Bases: Settings

  - model_class: str
    Fully qualified class path of the model to use, e.g. "my_module.MyModelClass". The class must inherit from torch.nn.Module (a minimal compatible model is sketched after this entry).
  - checkpoint_path: str | None = None
    Path to a checkpoint file containing model weights. If None, the model is initialized with random weights. If parameters inferred from the weight sizes conflict with the settings or config, the inferred parameters take priority and a warning is logged.
  - config_path: str | None = None
    Path to a config file containing model parameters. Parameters loaded from the config file take priority over settings; if the two differ, a warning is logged. If checkpoint_path is also provided, parameters inferred from the weights take priority over the config parameters.
  - device: str | None = None
    Device to use for the model. If None, the device is determined automatically, preferring cuda over mps over cpu.
  - model_kwargs: dict[str, Any] | None = None
    Additional keyword arguments to pass to the model constructor, e.g. input_size or output_size. If a config file is provided, these are updated with the config values; if a checkpoint file is provided, they are updated with values inferred from the model weights.
  - __init__(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None)
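`model_class` is resolved by importing the dotted path and instantiating the class with `model_kwargs`. A minimal sketch of a compatible model follows; the module and class names and the input_size/output_size kwargs are illustrative, not part of this API:

```python
# my_module.py -- a minimal torch.nn.Module compatible with model_class.
import torch
import torch.nn as nn

class MyModelClass(nn.Module):
    def __init__(self, input_size: int = 8, output_size: int = 2):
        super().__init__()
        self.net = nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

And the settings that would load it (values inferred from a config or checkpoint would override these kwargs, as documented above):

```python
from ezmsg.learn.process.torch import TorchSimpleSettings

settings = TorchSimpleSettings(
    model_class="my_module.MyModelClass",
    checkpoint_path=None,  # None -> random init; a path here loads weights instead
    device=None,           # None -> auto-select: cuda, else mps, else cpu
    model_kwargs={"input_size": 8, "output_size": 2},
)
```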
- class TorchSimpleUnit(*args, settings=None, **kwargs)
  Bases: BaseTransformerUnit[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleProcessor]

  - Parameters:
    settings (Settings | None)
  - SETTINGS
    alias of TorchSimpleSettings
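A hedged sketch of the unit inside an ezmsg pipeline, assuming the INPUT_SIGNAL/OUTPUT_SIGNAL stream names conventional for ezmsg transformer units; the source unit is a stand-in that emits random AxisArrays, and DebugLog simply prints what it receives. TorchModelUnit is wired the same way, with partial_fit driven by SampleMessages on an additional input stream (stream name not shown on this page).

```python
import asyncio

import numpy as np
import ezmsg.core as ez
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.debuglog import DebugLog

from ezmsg.learn.process.torch import TorchSimpleSettings, TorchSimpleUnit

class RandomSource(ez.Unit):
    """Stand-in source: emits a (time, channel) AxisArray once per second."""
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    @ez.publisher(OUTPUT_SIGNAL)
    async def emit(self):
        while True:
            yield self.OUTPUT_SIGNAL, AxisArray(
                data=np.random.randn(100, 8).astype(np.float32),
                dims=["time", "ch"],
            )
            await asyncio.sleep(1.0)

class DecodePipeline(ez.Collection):
    SOURCE = RandomSource()
    DECODE = TorchSimpleUnit(
        TorchSimpleSettings(model_class="my_module.MyModelClass")  # hypothetical model
    )
    LOG = DebugLog()

    def network(self) -> ez.NetworkDefinition:
        return (
            (self.SOURCE.OUTPUT_SIGNAL, self.DECODE.INPUT_SIGNAL),
            (self.DECODE.OUTPUT_SIGNAL, self.LOG.INPUT),
        )

# ez.run(PIPELINE=DecodePipeline())
```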