ezmsg.learn.process.transformer

Classes

class TransformerProcessor(*args, **kwargs)

Bases: BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState], TorchProcessorMixin, ModelInitMixin

property has_decoder: bool

reset_cache()
Return type:

None

partial_fit(message)
Return type:

None

Parameters:

message (SampleMessage)

class TransformerSettings(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)

Bases: TorchModelSettings

Parameters:
model_class: str = 'ezmsg.learn.model.transformer.TransformerModel'

Fully qualified class path of the model to use. For this processor, it should be “ezmsg.learn.model.transformer.TransformerModel”.

autoregressive_head: str | None = None

The name of the output head used for autoregressive decoding. This should match one of the keys in the model’s output dictionary. If None, the first output head will be used.

max_cache_len: int | None = 128

Maximum length of the target-sequence cache used for autoregressive decoding. This limits the decoding context length to prevent excessive memory usage. If None, the cache grows without bound.

__init__(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)
Parameters:
Return type:

None

class TransformerState

Bases: TorchModelState

ar_head: str | None = None

The name of the autoregressive head used for decoding. This is set based on the autoregressive_head setting. If None, the first output head will be used.

tgt_cache: Tensor | None = None

Cache for the target sequence used in autoregressive decoding. This is updated with each processed message to maintain context.

class TransformerUnit(*args, settings=None, **kwargs)

Bases: BaseAdaptiveTransformerUnit[TransformerSettings, AxisArray, AxisArray, TransformerProcessor]

Parameters:

settings (Settings | None)

SETTINGS

alias of TransformerSettings

async on_signal(message)
Return type:

AsyncGenerator

Parameters:

message (AxisArray)

class TransformerSettings(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)

Bases: TorchModelSettings

Parameters:
model_class: str = 'ezmsg.learn.model.transformer.TransformerModel'

Fully qualified class path of the model to use. For this processor, it should be “ezmsg.learn.model.transformer.TransformerModel”.

autoregressive_head: str | None = None

The name of the output head used for autoregressive decoding. This should match one of the keys in the model’s output dictionary. If None, the first output head will be used.

max_cache_len: int | None = 128

Maximum length of the target-sequence cache used for autoregressive decoding. This limits the decoding context length to prevent excessive memory usage. If None, the cache grows without bound.

__init__(model_class='ezmsg.learn.model.transformer.TransformerModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, autoregressive_head=None, max_cache_len=128)
Parameters:
Return type:

None

class TransformerState

Bases: TorchModelState

ar_head: str | None = None

The name of the autoregressive head used for decoding. This is set based on the autoregressive_head setting. If None, the first output head will be used.

tgt_cache: Tensor | None = None

Cache for the target sequence used in autoregressive decoding. This is updated with each processed message to maintain context.

class TransformerProcessor(*args, **kwargs)

Bases: BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState], TorchProcessorMixin, ModelInitMixin

property has_decoder: bool

reset_cache()
Return type:

None

partial_fit(message)
Return type:

None

Parameters:

message (SampleMessage)

class TransformerUnit(*args, settings=None, **kwargs)

Bases: BaseAdaptiveTransformerUnit[TransformerSettings, AxisArray, AxisArray, TransformerProcessor]

Parameters:

settings (Settings | None)

SETTINGS

alias of TransformerSettings

async on_signal(message)
Return type:

AsyncGenerator

Parameters:

message (AxisArray)