Source code for ezmsg.learn.model.rnn

from typing import Optional

import torch


class RNNModel(torch.nn.Module):
    """
    Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu).

    Attributes:
        input_size (int): Number of input features per time step.
        hidden_size (int): Number of hidden units in the RNN cell.
        num_layers (int, optional): Number of RNN layers. Default is 1.
        output_size (int | dict[str, int], optional): Number of output features or
            classes for a single-head output, or a dictionary mapping head names
            to output sizes for a multi-head output. Default is 2 (single head).
        dropout (float, optional): Dropout rate applied after the input projection
            and after the RNN output. Default is 0.3.
        rnn_type (str, optional): Type of RNN cell to use: 'GRU', 'LSTM',
            'RNN-Tanh', or 'RNN-ReLU'. Default is 'GRU'.

    Returns:
        dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head
            names to tensors of shape (batch, seq_len, output_size). For a
            single-head output, the key is "output".
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        output_size: int | dict[str, int] = 2,
        dropout: float = 0.3,
        rnn_type: str = "GRU",
    ):
        super().__init__()
        self.linear_embeddings = torch.nn.Linear(input_size, input_size)
        self.dropout_input = torch.nn.Dropout(dropout)
        rnn_klass_str = rnn_type.upper().split("-")[0]
        if rnn_klass_str not in ["GRU", "LSTM", "RNN"]:
            raise ValueError(f"Unrecognized rnn_type: {rnn_type}")
        rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[
            rnn_klass_str
        ]
        rnn_kwargs = {}
        if rnn_klass_str == "RNN":
            # 'RNN-Tanh' / 'RNN-ReLU' select the nonlinearity of the vanilla RNN cell.
            rnn_kwargs["nonlinearity"] = rnn_type.lower().split("-")[-1]
        self.rnn = rnn_klass(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            # Inter-layer dropout only applies to stacked RNNs.
            dropout=dropout if num_layers > 1 else 0.0,
            **rnn_kwargs,
        )
        self.rnn_type = rnn_klass_str
        self.output_dropout = torch.nn.Dropout(dropout)
        if isinstance(output_size, int):
            output_size = {"output": output_size}
        self.heads = torch.nn.ModuleDict(
            {
                name: torch.nn.Linear(hidden_size, size)
                for name, size in output_size.items()
            }
        )
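
    # Construction sketch (illustrative only; the head names and sizes below are
    # assumptions for the example, not part of the library):
    #
    #     model = RNNModel(
    #         input_size=64,
    #         hidden_size=128,
    #         num_layers=2,
    #         output_size={"position": 2, "velocity": 2},  # two output heads
    #         rnn_type="LSTM",
    #     )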

    @classmethod
    def infer_config_from_state_dict(
        cls, state_dict: dict, rnn_type: str = "GRU"
    ) -> dict[str, int | dict[str, int]]:
        """
        Infer the model's constructor arguments from a saved state dict.
        This method is specific to each processor.

        Args:
            state_dict: The state dict of the model.
            rnn_type: The type of RNN used in the model
                (e.g., 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU').

        Returns:
            A dictionary of model parameters obtained from the state dict.
        """
        output_size = {
            key.split(".")[1]: param.shape[0]
            for key, param in state_dict.items()
            if key.startswith("heads.") and key.endswith(".bias")
        }
        return {
            # Infer input_size from linear_embeddings.weight (shape: [input_size, input_size]).
            "input_size": state_dict["linear_embeddings.weight"].shape[1],
            # Infer hidden_size from rnn.weight_ih_l0, whose first dimension is
            # hidden_size * gate_count (3 for GRU, 4 for LSTM, 1 for vanilla RNN).
            "hidden_size": state_dict["rnn.weight_ih_l0"].shape[0]
            // cls._get_gate_count(rnn_type),
            # Infer num_layers by counting per-layer input-hidden weights (rnn.weight_ih_l<k>).
            "num_layers": sum(1 for key in state_dict if "rnn.weight_ih_l" in key),
            "output_size": output_size,
        }
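
    # Usage sketch (illustrative; "model.pt" is a hypothetical checkpoint file
    # containing only this model's state dict):
    #
    #     state_dict = torch.load("model.pt", map_location="cpu")
    #     config = RNNModel.infer_config_from_state_dict(state_dict, rnn_type="GRU")
    #     model = RNNModel(**config, rnn_type="GRU")
    #     model.load_state_dict(state_dict)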

    @staticmethod
    def _get_gate_count(rnn_type: str) -> int:
        if rnn_type.upper() == "GRU":
            return 3
        elif rnn_type.upper() == "LSTM":
            return 4
        elif rnn_type.upper().startswith("RNN"):
            return 1
        else:
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

    def init_hidden(
        self, batch_size: int, device: torch.device
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Initialize the hidden state for the RNN.

        Args:
            batch_size (int): Size of the batch.
            device (torch.device): Device to place the hidden state on
                (e.g., 'cpu' or 'cuda').

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Initial hidden state
                for the RNN. For LSTM, returns a tuple of (h_n, c_n) where h_n is
                the hidden state and c_n is the cell state. For GRU or vanilla RNN,
                returns just h_n.
        """
        shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        if self.rnn_type == "LSTM":
            return (
                torch.zeros(shape, device=device),
                torch.zeros(shape, device=device),
            )
        else:
            return torch.zeros(shape, device=device)
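
    # Streaming sketch (illustrative): carry the hidden state across consecutive
    # chunks of one recording instead of re-initializing it on every call.
    # `chunks` is an assumed iterable of tensors shaped (1, chunk_len, input_size):
    #
    #     hx = model.init_hidden(batch_size=1, device=torch.device("cpu"))
    #     for chunk in chunks:
    #         preds, hx = model(chunk, hx=hx)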

    def forward(
        self,
        x: torch.Tensor,
        input_lens: Optional[torch.Tensor] = None,
        hx: Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor | tuple]:
        """
        Forward pass through the RNN model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
            input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each
                sequence in the batch. If provided, sequences will be packed before
                passing through the RNN.
            hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional
                initial hidden state for the RNN.

        Returns:
            tuple[dict[str, torch.Tensor], torch.Tensor | tuple]: A dictionary mapping
                head names to output tensors of shape (batch, seq_len, output_size),
                and the final hidden state: (h_n, c_n) for LSTM, or h_n for GRU and
                vanilla RNN.
        """
        x = self.linear_embeddings(x)
        x = self.dropout_input(x)
        # Remember the padded length so packed outputs can be re-padded to match it.
        total_length = x.shape[1]
        if input_lens is not None:
            x = torch.nn.utils.rnn.pack_padded_sequence(
                x, input_lens, batch_first=True, enforce_sorted=False
            )
        x_out, hx_out = self.rnn(x, hx)
        if input_lens is not None:
            x_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
                x_out, batch_first=True, total_length=total_length
            )
        x_out = self.output_dropout(x_out)
        return {name: head(x_out) for name, head in self.heads.items()}, hx_out
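

# The demo below is an illustrative sketch, not part of the library API: it
# exercises construction, a padded/packed forward pass, and round-tripping the
# configuration through infer_config_from_state_dict. All sizes are arbitrary.
if __name__ == "__main__":
    model = RNNModel(input_size=8, hidden_size=16, num_layers=2, output_size=2)

    # Batch of 3 padded sequences with true lengths 5, 4, and 2.
    x = torch.randn(3, 5, 8)
    input_lens = torch.tensor([5, 4, 2])
    hx = model.init_hidden(batch_size=3, device=x.device)
    preds, hx = model(x, input_lens=input_lens, hx=hx)
    print(preds["output"].shape)  # torch.Size([3, 5, 2])

    # Recover constructor arguments from the state dict alone and rebuild the model.
    config = RNNModel.infer_config_from_state_dict(model.state_dict(), rnn_type="GRU")
    clone = RNNModel(**config, rnn_type="GRU")
    clone.load_state_dict(model.state_dict())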