ezmsg.learn.process.torch#

Classes

class TorchModelProcessor(*args, **kwargs)[source]#

Bases: BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState], TorchProcessorMixin, ModelInitMixin

partial_fit(message)[source]#
Return type:

None

Parameters:

message (SampleMessage)
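A minimal usage sketch of the adaptive processor. The import path shown, the callable inference interface, and the constructor arguments are assumptions made for illustration and may differ from the actual API:

    from ezmsg.util.messages.axisarray import AxisArray
    from ezmsg.learn.process.torch import TorchModelProcessor, TorchModelSettings

    proc = TorchModelProcessor(
        settings=TorchModelSettings(
            model_class="my_module.MyModelClass",  # hypothetical torch.nn.Module subclass
            learning_rate=1e-3,
        )
    )

    # Inference: transform an incoming AxisArray (assumes the processor is callable).
    # out: AxisArray = proc(signal_message)

    # Online adaptation: update the model weights from a labelled SampleMessage.
    # proc.partial_fit(sample_message)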

class TorchModelSettings(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999)[source]#

Bases: TorchSimpleSettings

Parameters:
  • model_class (str)

  • checkpoint_path (str | None)

  • config_path (str | None)

  • single_precision (bool)

  • device (str | None)

  • model_kwargs (dict[str, Any] | None)

  • learning_rate (float)

  • weight_decay (float)

  • loss_fn (Module | dict[str, Module] | None)

  • loss_weights (dict[str, float] | None)

  • scheduler_gamma (float)

learning_rate: float = 0.001#

Learning rate for the optimizer.

weight_decay: float = 0.0001#

Weight decay for the optimizer.

loss_fn: Module | dict[str, Module] | None = None#

Loss function(s) for the decoder. If using multiple heads, this should be a dictionary mapping head names to loss functions. The keys must match the output head names.

loss_weights: dict[str, float] | None = None#

Weights for each loss function if using multiple heads. The keys must match the output head names. If None or missing/mismatched keys, losses will be unweighted.
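For example, a decoder with two output heads might pair per-head loss functions with per-head weights as sketched below; the head names "position" and "velocity" and the model path are purely illustrative:

    import torch.nn as nn
    from ezmsg.learn.process.torch import TorchModelSettings

    settings = TorchModelSettings(
        model_class="my_module.MultiHeadDecoder",  # hypothetical multi-head model
        loss_fn={
            "position": nn.MSELoss(),       # loss for the "position" head
            "velocity": nn.SmoothL1Loss(),  # loss for the "velocity" head
        },
        loss_weights={"position": 1.0, "velocity": 0.5},
    )

With weights supplied, the per-head losses are presumably combined as a weighted sum; without them, each head contributes equally.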

scheduler_gamma: float = 0.999#

Learning-rate scheduler decay rate. Set to 0.0 to disable the scheduler.
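Taken together, learning_rate, weight_decay, and scheduler_gamma map onto a standard torch optimizer plus an exponential learning-rate schedule. The sketch below only illustrates that relationship; the optimizer class the library actually instantiates is not specified on this page, so Adam is an assumption:

    import torch

    model = torch.nn.Linear(8, 2)  # stand-in for the configured model

    # learning_rate / weight_decay parameterize the optimizer...
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # ...and scheduler_gamma is the per-step decay factor of an exponential schedule.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)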

__init__(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999)#
Parameters:
  • model_class (str)

  • checkpoint_path (str | None)

  • config_path (str | None)

  • single_precision (bool)

  • device (str | None)

  • model_kwargs (dict[str, Any] | None)

  • learning_rate (float)

  • weight_decay (float)

  • loss_fn (Module | dict[str, Module] | None)

  • loss_weights (dict[str, float] | None)

  • scheduler_gamma (float)

Return type:

None

class TorchModelState[source]#

Bases: TorchSimpleState

optimizer: Optimizer | None = None#

scheduler: LRScheduler | None = None#

class TorchModelUnit(*args, settings=None, **kwargs)[source]#

Bases: BaseAdaptiveTransformerUnit[TorchModelSettings, AxisArray, AxisArray, TorchModelProcessor]

Parameters:

settings (Settings | None)

SETTINGS#

alias of TorchModelSettings

async on_signal(message)[source]#
Return type:

AsyncGenerator

Parameters:

message (AxisArray)
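A rough sketch of wiring the unit into an ezmsg pipeline. The stream names (INPUT_SIGNAL, INPUT_SAMPLE, OUTPUT_SIGNAL) and the surrounding source/sampler/sink units are assumptions for illustration only:

    import ezmsg.core as ez
    from ezmsg.learn.process.torch import TorchModelUnit, TorchModelSettings

    decoder = TorchModelUnit(
        settings=TorchModelSettings(
            model_class="my_module.MyModelClass",  # hypothetical model
            checkpoint_path="checkpoints/decoder.pt",
        )
    )

    # Inside a Collection's network(), connections would look roughly like:
    #   (source.OUTPUT_SIGNAL, decoder.INPUT_SIGNAL),   # continuous signal -> inference
    #   (sampler.OUTPUT_SAMPLE, decoder.INPUT_SAMPLE),  # labelled samples -> partial_fit
    #   (decoder.OUTPUT_SIGNAL, sink.INPUT_SIGNAL),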

class TorchProcessorMixin[source]#

Bases: object

Mixin with shared functionality for torch processors.

save_checkpoint(path)[source]#

Save the current model state to a checkpoint file.

Return type:

None

Parameters:
  • self (P)

  • path (str)
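Usage is a single call on any processor that includes this mixin; the destination path is arbitrary, and the resulting file can later be supplied back through checkpoint_path:

    # proc is e.g. a TorchModelProcessor or TorchSimpleProcessor instance.
    proc.save_checkpoint("checkpoints/decoder.pt")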

class TorchSimpleProcessor(*args, **kwargs)[source]#

Bases: BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState], TorchProcessorMixin, ModelInitMixin

class TorchSimpleSettings(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None)[source]#

Bases: Settings

Parameters:
  • model_class (str)

  • checkpoint_path (str | None)

  • config_path (str | None)

  • single_precision (bool)

  • device (str | None)

  • model_kwargs (dict[str, Any] | None)

model_class: str#

Fully qualified class path of the model to be used, e.g. “my_module.MyModelClass”. This class should inherit from torch.nn.Module.
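Such a path is typically resolved by splitting off the class name and importing the module dynamically. The helper below is only an illustration of that pattern, not the library's actual loader:

    import importlib

    def resolve_model_class(model_class: str) -> type:
        """Resolve a dotted path like 'my_module.MyModelClass' to the class object."""
        module_path, _, class_name = model_class.rpartition(".")
        module = importlib.import_module(module_path)
        return getattr(module, class_name)  # expected to be a torch.nn.Module subclass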

checkpoint_path: str | None = None#

Path to a checkpoint file containing model weights. If None, the model will be initialized with random weights. If parameters inferred from the weight sizes conflict with the settings / config, then the inferred parameters will take priority and a warning will be logged.

config_path: str | None = None#

Path to a config file containing model parameters. Parameters loaded from the config file will take priority over settings. If settings differ from config parameters then a warning will be logged. If checkpoint_path is provided then any parameters inferred from the weights will take priority over the config parameters.

single_precision: bool = True#

Use single precision (float32) instead of double precision (float64).

device: str | None = None#

Device to use for the model. If None, the device will be determined automatically, with preference for cuda > mps > cpu.
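The stated preference order (cuda > mps > cpu) corresponds to the usual torch availability checks. The snippet below illustrates that order; it is not necessarily the library's exact selection code:

    import torch

    def pick_device() -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")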

model_kwargs: dict[str, Any] | None = None#

Additional keyword arguments to pass to the model constructor. This can include parameters like input_size, output_size, etc. If a config file is provided, these parameters will be updated with the config values. If a checkpoint file is provided, these parameters will be updated with the inferred values from the model weights.

__init__(model_class, checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None)#
Parameters:
  • model_class (str)

  • checkpoint_path (str | None)

  • config_path (str | None)

  • single_precision (bool)

  • device (str | None)

  • model_kwargs (dict[str, Any] | None)

Return type:

None
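A minimal construction sketch. The model path and the keys passed through model_kwargs are hypothetical and must match whatever the named model's constructor actually accepts:

    from ezmsg.learn.process.torch import TorchSimpleSettings

    settings = TorchSimpleSettings(
        model_class="my_module.MyModelClass",      # hypothetical torch.nn.Module subclass
        checkpoint_path="checkpoints/decoder.pt",  # or None for random initialization
        single_precision=True,                     # run the model in float32
        model_kwargs={"input_size": 64, "output_size": 2},  # illustrative constructor kwargs
    )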

class TorchSimpleState[source]#

Bases: object

model: Module | None = None#

device: device | None = None#

chan_ax: dict[str, CoordinateAxis] | None = None#

class TorchSimpleUnit(*args, settings=None, **kwargs)[source]#

Bases: BaseTransformerUnit[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleProcessor]

Parameters:

settings (Settings | None)

SETTINGS#

alias of TorchSimpleSettings

async on_signal(message)[source]#
Return type:

AsyncGenerator

Parameters:

message (AxisArray)
