ezmsg.learn.model.rnn

Classes

class RNNModel(input_size, hidden_size, num_layers=1, output_size=2, dropout=0.3, rnn_type='GRU')

Bases: Module

Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu).

input_size

Number of input features per time step.

Type:

int

hidden_size

Number of hidden units in the RNN cell.

Type:

int

num_layers

Number of RNN layers. Default is 1.

Type:

int, optional

output_size

Number of output features or classes for a single-head model, or a dictionary mapping head names to output sizes for a multi-head model. Default is 2 (single head).

Type:

int | dict[str, int], optional

dropout

Dropout rate applied after input and RNN output. Default is 0.3.

Type:

float, optional

rnn_type

Type of RNN cell to use: ‘GRU’, ‘LSTM’, ‘RNN-Tanh’, ‘RNN-ReLU’. Default is ‘GRU’.

Type:

str, optional

Returns:

Dictionary of decoded predictions mapping head names to tensors of shape (batch, seq_len, output_size). For a single-head model, the key is “output”.

Return type:

dict[str, Tensor]
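
The relationship between output_size and the keys of the returned dictionary can be sketched in plain Python. The helper name normalize_heads is hypothetical and not part of ezmsg.learn; it only illustrates the documented convention that a single-head model uses the key "output", and the head names "position" and "velocity" are made up for the example.

```python
from __future__ import annotations


def normalize_heads(output_size: int | dict[str, int] = 2) -> dict[str, int]:
    """Map an output_size argument to a {head_name: size} dictionary.

    Hypothetical helper for illustration only.
    """
    if isinstance(output_size, int):
        # Single head: the documented key is "output".
        return {"output": output_size}
    # Multi-head: keep the caller's head names and sizes.
    return dict(output_size)


single = normalize_heads(2)
multi = normalize_heads({"position": 2, "velocity": 3})
```

Each key in the resulting dictionary would correspond to one entry in the dictionary of predictions described above.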

__init__(input_size, hidden_size, num_layers=1, output_size=2, dropout=0.3, rnn_type='GRU')

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod infer_config_from_state_dict(state_dict, rnn_type='GRU')

Infer model configuration parameters from a state dict. This method is specific to each processor.

Parameters:
  • state_dict (dict) – The state dict of the model.

  • rnn_type (str) – The type of RNN used in the model (e.g., ‘GRU’, ‘LSTM’, ‘RNN-Tanh’, ‘RNN-ReLU’).

Return type:

dict[str, int | float]

Returns:

A dictionary of model parameters obtained from the state dict.
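
As a rough illustration of how such inference can work, the sketch below recovers input_size, hidden_size and num_layers from parameter shapes alone. It is not the library's implementation. It relies on the parameter naming and weight layout of torch.nn.GRU, torch.nn.LSTM and torch.nn.RNN (weight_ih_l{k}, weight_hh_l{k}); the "rnn." key prefix and the function name infer_rnn_config are assumptions made for this example, and the real method may return additional fields.

```python
from __future__ import annotations

import re

# Number of stacked gate blocks per cell type in PyTorch's weight layout.
GATES = {"GRU": 3, "LSTM": 4, "RNN-Tanh": 1, "RNN-ReLU": 1}


def infer_rnn_config(shapes: dict[str, tuple[int, ...]], rnn_type: str = "GRU") -> dict[str, int]:
    """Infer input_size, hidden_size and num_layers from parameter shapes.

    `shapes` maps parameter names to shapes; with a real state dict this
    would be {k: tuple(v.shape) for k, v in state_dict.items()}.
    The "rnn." prefix is an assumption about how the submodule is named.
    """
    # One weight_ih_l{k} entry exists per layer k.
    layers = {
        int(m.group(1))
        for k in shapes
        if (m := re.fullmatch(r"rnn\.weight_ih_l(\d+)", k))
    }
    # weight_hh_l0 has shape (gates * hidden_size, hidden_size).
    rows, hidden_size = shapes["rnn.weight_hh_l0"]
    if rows != GATES[rnn_type] * hidden_size:
        raise ValueError(f"shapes do not match rnn_type={rnn_type!r}")
    return {
        # weight_ih_l0 has shape (gates * hidden_size, input_size).
        "input_size": shapes["rnn.weight_ih_l0"][1],
        "hidden_size": hidden_size,
        "num_layers": len(layers),
    }


# Shapes of a 2-layer GRU with 8 input features and 16 hidden units.
shapes = {
    "rnn.weight_ih_l0": (48, 8),
    "rnn.weight_hh_l0": (48, 16),
    "rnn.weight_ih_l1": (48, 16),
    "rnn.weight_hh_l1": (48, 16),
}
config = infer_rnn_config(shapes, "GRU")
```

The rnn_type argument matters because the gate count cannot always be told apart from shapes alone, which is presumably why the method takes it as a parameter.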

init_hidden(batch_size, device)

Initialize the hidden state for the RNN.

Parameters:
  • batch_size (int) – Size of the batch.

  • device (torch.device) – Device to place the hidden state on (e.g., ‘cpu’ or ‘cuda’).

Returns:

Initial hidden state for the RNN.

For LSTM, returns a tuple of (h_n, c_n) where h_n is the hidden state and c_n is the cell state. For GRU or vanilla RNN, returns just h_n.

Return type:

Tensor | tuple[Tensor, Tensor]
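
The return convention above can be sketched with plain PyTorch. This is a minimal sketch, not the library's code: the function name init_hidden_sketch is hypothetical, and it assumes zero initialization and a unidirectional network, neither of which is stated on this page. The state shape (num_layers, batch_size, hidden_size) is the one torch.nn recurrent layers expect.

```python
import torch


def init_hidden_sketch(num_layers, batch_size, hidden_size, rnn_type="GRU",
                       device=torch.device("cpu")):
    """Return an initial state shaped for torch.nn recurrent layers.

    Assumes zero initialization and a unidirectional network.
    """
    h0 = torch.zeros(num_layers, batch_size, hidden_size, device=device)
    if rnn_type == "LSTM":
        # LSTM carries a hidden state and a cell state of the same shape.
        return h0, torch.zeros_like(h0)
    # GRU and vanilla RNN carry only the hidden state.
    return h0


gru_state = init_hidden_sketch(2, 4, 16, rnn_type="GRU")
lstm_state = init_hidden_sketch(2, 4, 16, rnn_type="LSTM")
```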

forward(x, input_lens=None, hx=None)

Forward pass through the RNN model.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch, seq_len, input_size).

  • input_lens (Optional[Tensor]) – Optional tensor of lengths for each sequence in the batch. If provided, sequences will be packed before passing through the RNN.

  • hx – Optional initial hidden state. Default is None.

Returns:

A tuple whose first element is a dictionary mapping head names to output tensors of shape (batch, seq_len, output_size), and whose second element is the hidden state: (h_n, c_n) for LSTM, or just h_n otherwise (e.g., GRU).

Return type:

tuple[dict[str, Tensor], Tensor | tuple]
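
The packing step and the shape of the returned tuple can be illustrated with plain PyTorch. This sketch does not call RNNModel and is not its implementation: it wires a torch.nn.GRU to a single linear head named "output" by hand, and all sizes and variable names are chosen only for the example.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch, seq_len, input_size, hidden_size = 3, 5, 4, 8

rnn = nn.GRU(input_size, hidden_size, batch_first=True)
heads = nn.ModuleDict({"output": nn.Linear(hidden_size, 2)})

x = torch.randn(batch, seq_len, input_size)
input_lens = torch.tensor([5, 3, 2])  # valid time steps per sequence (CPU tensor)

# Pack so the RNN skips padded time steps; lengths need not be sorted.
packed = pack_padded_sequence(x, input_lens, batch_first=True, enforce_sorted=False)
packed_out, h_n = rnn(packed)

# Unpack back to (batch, seq_len, hidden_size), padding with zeros.
out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)

# Apply each head per time step and return (predictions, hidden state).
preds = {name: head(out) for name, head in heads.items()}
result = (preds, h_n)
```

With an LSTM in place of the GRU, the second element would be the pair (h_n, c_n) instead of a single tensor.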
