ezmsg.learn.process.transformer
Classes
- class TransformerProcessor(*args, **kwargs)
  Bases: BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState], TorchProcessorMixin, ModelInitMixin

  - partial_fit(message)
    - Parameters:
      message (SampleMessage)
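A minimal sketch of one incremental training step follows. It assumes the processor accepts a settings object via its constructor and that SampleMessage follows the ezmsg-sigproc sampler convention (the sample field carries an AxisArray, the trigger's value carries the training target); neither assumption is confirmed by this page.

```python
# Hypothetical sketch: one incremental training step on a labeled sample.
# Assumptions (not confirmed by this page): the processor accepts a settings
# object via its constructor, and SampleMessage follows the ezmsg-sigproc
# sampler convention (msg.sample is an AxisArray; msg.trigger.value holds
# the training target).
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
from ezmsg.learn.process.transformer import TransformerProcessor, TransformerSettings

proc = TransformerProcessor(settings=TransformerSettings(device="cpu"))

sample = AxisArray(
    data=np.random.randn(1, 128, 32).astype(np.float32),  # (batch, time, channel)
    dims=["batch", "time", "ch"],
)
msg = SampleMessage(
    trigger=SampleTriggerMessage(value=np.zeros((1, 128, 2), dtype=np.float32)),
    sample=sample,
)
proc.partial_fit(msg)  # one adaptive update; repeated calls continue training
```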
- class TransformerSettings(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)
  Bases: TorchModelSettings

  - model_class: str = 'ezmsg.learn.model.transformer.TransformerModel'
    Fully qualified class path of the model to use. For this processor it should remain "ezmsg.learn.model.transformer.TransformerModel".
  - autoregressive_head: str | None = None
    Name of the output head used for autoregressive decoding. It must match one of the keys in the model's output dictionary; if None, the first output head is used.
  - max_cache_len: int | None = 128
    Maximum length of the target-sequence cache for autoregressive decoding. This bounds the context length during decoding to prevent excessive memory use. If None, the cache grows indefinitely.
  - __init__(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)
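For illustration, a configured settings object might look like the sketch below. All values are illustrative: the "decode" head name is a placeholder rather than a documented default, and the accepted model_kwargs keys depend on TransformerModel's constructor.

```python
# Hypothetical sketch: a configured settings object. Values are illustrative;
# the "decode" head name is a placeholder, not a documented default.
from ezmsg.learn.process.transformer import TransformerSettings

settings = TransformerSettings(
    checkpoint_path=None,          # optionally resume from a saved checkpoint
    single_precision=True,         # keep model and inputs in float32
    device="cuda",                 # e.g. "cpu" or "cuda"
    learning_rate=1e-3,
    weight_decay=1e-4,
    scheduler_gamma=0.999,         # multiplicative LR decay factor
    autoregressive_head="decode",  # placeholder output-head name
    max_cache_len=128,             # bound the autoregressive context cache
)
```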
- class TransformerState
  Bases: TorchModelState
- class TransformerUnit(*args, settings=None, **kwargs)
  Bases: BaseAdaptiveTransformerUnit[TransformerSettings, AxisArray, AxisArray, TransformerProcessor]

  - Parameters:
    settings (Settings | None)
  - SETTINGS
    alias of TransformerSettings
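A minimal sketch of placing the unit in an ezmsg collection follows. The OUTPUT_SIGNAL stream name follows the ezmsg-sigproc base-unit convention and is an assumption here; the upstream source feeding the unit is omitted, and DecodePipeline is a hypothetical name.

```python
# Hypothetical sketch: place TransformerUnit in an ezmsg pipeline. The
# OUTPUT_SIGNAL stream name follows the ezmsg-sigproc base-unit convention
# and is an assumption here; DebugLog is a real ezmsg utility unit.
import ezmsg.core as ez
from ezmsg.util.debuglog import DebugLog
from ezmsg.learn.process.transformer import TransformerUnit, TransformerSettings

class DecodePipeline(ez.Collection):
    DECODER = TransformerUnit(settings=TransformerSettings(device="cpu"))
    LOG = DebugLog()  # prints every message it receives

    def network(self) -> ez.NetworkDefinition:
        # Decoded AxisArray outputs flow to the logger; an upstream source
        # feeding DECODER's input stream is omitted from this sketch.
        return ((self.DECODER.OUTPUT_SIGNAL, self.LOG.INPUT),)
```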