ezmsg.learn.model.rnn#
Classes
- class RNNModel(input_size, hidden_size, num_layers=1, output_size=2, dropout=0.3, rnn_type='GRU')[source]#
Bases: Module

Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu).

- hidden_size#
Number of hidden units in the RNN cell.
- Type:
int

- output_size#
Number of output features or classes for a single-head output, or a dictionary mapping head names to output sizes for a multi-head output. Default is 2 (single head).
- Type:
int or dict, optional

- rnn_type#
Type of RNN cell to use: ‘GRU’, ‘LSTM’, ‘RNN-Tanh’, or ‘RNN-ReLU’. Default is ‘GRU’.
- Type:
str, optional

- Returns:
Dictionary of decoded predictions mapping head names to tensors of shape (batch, seq_len, output_size). For a single-head output, the key is “output”.
- Return type:
dict
- Parameters:
- input_size (int) – Number of input features.
- hidden_size (int) – Number of hidden units in the RNN cell.
- num_layers (int, optional) – Number of stacked RNN layers. Default is 1.
- output_size (int or dict, optional) – Number of output features or classes, or a dictionary mapping head names to output sizes. Default is 2.
- dropout (float, optional) – Dropout probability. Default is 0.3.
- rnn_type (str, optional) – Type of RNN cell to use: ‘GRU’, ‘LSTM’, ‘RNN-Tanh’, or ‘RNN-ReLU’. Default is ‘GRU’.
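A minimal construction sketch (not part of the package docs) may help; it assumes the import path shown on this page, and the sizes and head names (‘position’, ‘velocity’) are illustrative:

    from ezmsg.learn.model.rnn import RNNModel

    # Single-head model: output_size is an int, so predictions use the key "output".
    model = RNNModel(input_size=32, hidden_size=64, num_layers=2, rnn_type='GRU')

    # Multi-head model: output_size maps head names to per-head output sizes.
    multi = RNNModel(
        input_size=32,
        hidden_size=64,
        output_size={'position': 2, 'velocity': 2},
        rnn_type='LSTM',
    )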
- __init__(input_size, hidden_size, num_layers=1, output_size=2, dropout=0.3, rnn_type='GRU')[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- classmethod infer_config_from_state_dict(state_dict, rnn_type='GRU')[source]#
Infer the model configuration from a saved state dict. This method is specific to each processor.
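The exact return format is not documented here; the sketch below assumes the method returns a mapping of __init__ keyword arguments recovered from the tensor shapes in the state dict, and the checkpoint path is hypothetical:

    import torch
    from ezmsg.learn.model.rnn import RNNModel

    state_dict = torch.load('checkpoint.pt', map_location='cpu')  # hypothetical path
    config = RNNModel.infer_config_from_state_dict(state_dict, rnn_type='GRU')
    model = RNNModel(**config)  # assumes config matches __init__ keyword arguments
    model.load_state_dict(state_dict)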
- init_hidden(batch_size, device)[source]#
Initialize the hidden state for the RNN.
- Parameters:
- batch_size (int) – Size of the batch.
- device (torch.device) – Device to place the hidden state on (e.g., ‘cpu’ or ‘cuda’).
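A short usage sketch; the shape of the returned state is an assumption (h_0 for GRU, presumably an (h_0, c_0) pair for LSTM):

    import torch
    from ezmsg.learn.model.rnn import RNNModel

    model = RNNModel(input_size=32, hidden_size=64, rnn_type='GRU')
    hx = model.init_hidden(batch_size=8, device=torch.device('cpu'))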
- forward(x, input_lens=None, hx=None)[source]#
Forward pass through the RNN model.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (batch, seq_len, input_size).
- input_lens (Optional[torch.Tensor]) – Optional tensor of lengths for each sequence in the batch. If provided, sequences will be packed before passing through the RNN.
- hx (optional) – Initial hidden state, e.g. from init_hidden().
- Returns:
A tuple whose first element is a dictionary mapping head names to output tensors of shape (batch, seq_len, output_size); the second element is the hidden state, (h_n, c_n) if the RNN is LSTM or just h_n if GRU.
- Return type:
tuple
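An end-to-end sketch; the two-element unpacking follows the Returns description above, and all sizes are illustrative:

    import torch
    from ezmsg.learn.model.rnn import RNNModel

    model = RNNModel(input_size=32, hidden_size=64)   # default single head, output_size=2
    x = torch.randn(8, 100, 32)                       # (batch, seq_len, input_size)
    lens = torch.full((8,), 100, dtype=torch.long)    # per-sequence lengths (all full here)
    outputs, hidden = model(x, input_lens=lens)       # unpacking assumed from the docstring
    print(outputs['output'].shape)                    # expected: torch.Size([8, 100, 2])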