ezmsg.learn.process.rnn

Classes

class RNNProcessor(*args, **kwargs)[source]

Bases: BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState], TorchProcessorMixin, ModelInitMixin

reset_hidden(batch_size)[source]#
Return type:

None

Parameters:

batch_size (int)

partial_fit(message)[source]#
Return type:

None

Parameters:

message (SampleMessage)

class RNNSettings(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')[source]

Bases: TorchModelSettings

Parameters:
model_class: str = 'ezmsg.learn.model.rnn.RNNModel'

Fully qualified class path of the model to use. For this processor it should be “ezmsg.learn.model.rnn.RNNModel”.

reset_hidden_on_fit: bool = True

Whether to reset the hidden state on each fit call. If True, the hidden state is reset to zero after each fit; if False, it is maintained across fit calls.

preserve_state_across_windows: bool | Literal['auto'] = 'auto'

Whether to preserve the hidden state across windows. If True, the hidden state is preserved across windows. If False, it is reset at the start of each window. If “auto”, it is preserved when the time windows do not overlap and reset when they do.

__init__(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')#
Parameters:
Return type:

None

class RNNState[source]

Bases: TorchModelState

hx: Tensor | None = None

class RNNUnit(*args, settings=None, **kwargs)[source]

Bases: BaseAdaptiveTransformerUnit[RNNSettings, AxisArray, AxisArray, RNNProcessor]

Parameters:

settings (Settings | None)

SETTINGS#

alias of RNNSettings

async on_signal(message)[source]#
Return type:

AsyncGenerator

Parameters:

message (AxisArray)

class RNNSettings(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')[source]

Bases: TorchModelSettings

Parameters:
model_class: str = 'ezmsg.learn.model.rnn.RNNModel'

Fully qualified class path of the model to use. For this processor it should be “ezmsg.learn.model.rnn.RNNModel”.

reset_hidden_on_fit: bool = True

Whether to reset the hidden state on each fit call. If True, the hidden state is reset to zero after each fit; if False, it is maintained across fit calls.

preserve_state_across_windows: bool | Literal['auto'] = 'auto'

Whether to preserve the hidden state across windows. If True, the hidden state is preserved across windows. If False, it is reset at the start of each window. If “auto”, it is preserved when the time windows do not overlap and reset when they do.

__init__(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')#
Parameters:
Return type:

None

class RNNState[source]

Bases: TorchModelState

hx: Tensor | None = None

class RNNProcessor(*args, **kwargs)[source]

Bases: BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState], TorchProcessorMixin, ModelInitMixin

reset_hidden(batch_size)[source]#
Return type:

None

Parameters:

batch_size (int)

partial_fit(message)[source]#
Return type:

None

Parameters:

message (SampleMessage)

class RNNUnit(*args, settings=None, **kwargs)[source]

Bases: BaseAdaptiveTransformerUnit[RNNSettings, AxisArray, AxisArray, RNNProcessor]

Parameters:

settings (Settings | None)

SETTINGS#

alias of RNNSettings

async on_signal(message)[source]#
Return type:

AsyncGenerator

Parameters:

message (AxisArray)