ezmsg.learn.process.rnn#
Classes
- class RNNProcessor(*args, **kwargs)[source]#
Bases: BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState], TorchProcessorMixin, ModelInitMixin
- partial_fit(message)[source]#
- Parameters:
message (SampleMessage)
- Return type:
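partial_fit updates the model online from labeled samples while the same object keeps transforming the data stream. A minimal sketch of this adaptive-transformer pattern in pure Python (illustrative names only; the real RNNProcessor wraps a PyTorch model and consumes SampleMessage/AxisArray objects):

```python
import numpy as np

class ToyAdaptiveTransformer:
    """Illustrative stand-in for RNNProcessor: transform a stream
    while occasionally updating parameters from labeled samples."""

    def __init__(self, learning_rate: float = 0.001):
        self.learning_rate = learning_rate
        self.weight = 1.0  # stand-in for the model's parameters

    def __call__(self, data: np.ndarray) -> np.ndarray:
        # Inference path: transform an incoming window with current weights.
        return self.weight * data

    def partial_fit(self, data: np.ndarray, target: np.ndarray) -> None:
        # One gradient step on squared error, analogous to the single
        # optimizer step taken inside partial_fit(message).
        pred = self(data)
        grad = 2.0 * np.mean((pred - target) * data)
        self.weight -= self.learning_rate * grad

proc = ToyAdaptiveTransformer(learning_rate=0.1)
x = np.array([1.0, 2.0, 3.0])
for _ in range(200):
    proc.partial_fit(x, 2.0 * x)  # teach the weight toward 2.0
print(round(proc.weight, 2))  # → 2.0
```

The design point mirrored here is that fitting and inference share one stateful object, so the transform improves as labeled samples arrive.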
- class RNNSettings(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')[source]#
Bases: TorchModelSettings
- model_class: str = 'ezmsg.learn.model.rnn.RNNModel'#
Fully qualified class path of the model to use. For this processor it should be "ezmsg.learn.model.rnn.RNNModel".
- reset_hidden_on_fit: bool = True#
Whether to reset the hidden state on each fit call. If True, the hidden state is reset to zero after each fit; if False, it is maintained across fit calls.
- preserve_state_across_windows: bool | Literal['auto'] = 'auto'#
Whether to preserve the hidden state across windows. If True, the hidden state is carried over from one window to the next; if False, it is reset at the start of each window. If "auto", the state is preserved when consecutive time windows do not overlap and reset otherwise.
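The effect of reset_hidden_on_fit can be sketched with a toy recurrent cell (pure Python, illustrative; the real processor manages a PyTorch hidden state tensor):

```python
import numpy as np

class ToyRNN:
    """Leaky accumulator standing in for an RNN cell's hidden state."""

    def __init__(self, reset_hidden_on_fit: bool = True):
        self.reset_hidden_on_fit = reset_hidden_on_fit
        self.hidden = np.zeros(4)

    def step(self, x: float) -> None:
        # One recurrent update: decay old state, add new input.
        self.hidden = 0.9 * self.hidden + x

    def fit_window(self, window) -> None:
        for x in window:
            self.step(x)
        if self.reset_hidden_on_fit:
            # Mirrors reset_hidden_on_fit=True: zero the state after fitting.
            self.hidden = np.zeros_like(self.hidden)

resetting = ToyRNN(reset_hidden_on_fit=True)
carrying = ToyRNN(reset_hidden_on_fit=False)
for rnn in (resetting, carrying):
    rnn.fit_window([1.0, 1.0])
print(resetting.hidden.max(), carrying.hidden.max())  # → 0.0 1.9
```

With resetting behavior each fit starts from a clean state; with carrying behavior the state accumulates context across fit calls.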
- __init__(model_class='ezmsg.learn.model.rnn.RNNModel', checkpoint_path=None, config_path=None, single_precision=True, device=None, model_kwargs=None, learning_rate=0.001, weight_decay=0.0001, loss_fn=None, loss_weights=None, scheduler_gamma=0.999, reset_hidden_on_fit=True, preserve_state_across_windows='auto')#
- class RNNUnit(*args, settings=None, **kwargs)[source]#
Bases: BaseAdaptiveTransformerUnit[RNNSettings, AxisArray, AxisArray, RNNProcessor]
- Parameters:
settings (Settings | None)
- SETTINGS#
alias of RNNSettings