Source code for ezmsg.learn.model.transformer

from typing import Optional

import torch


class TransformerModel(torch.nn.Module):
    """
    Transformer-based encoder (optional decoder) neural network.

    If `decoder_layers > 0`, the model includes a Transformer decoder and the `tgt`
    argument is used: during training, it is typically the ground-truth target sequence
    (i.e. teacher forcing); during inference, it can be constructed autoregressively from
    previous predictions (when `tgt` is None, a learned start token is used instead).

    Attributes:
        input_size (int): Number of input features per time step.
        hidden_size (int): Dimensionality of the transformer model.
        encoder_layers (int, optional): Number of transformer encoder layers. Default is 1.
        decoder_layers (int, optional): Number of transformer decoder layers. Default is 0.
        output_size (int | dict[str, int], optional): Number of output features or classes
            for single-head output, or a dictionary mapping head names to output sizes for
            multi-head output. Default is 2 (single head).
        dropout (float, optional): Dropout rate applied after input and transformer output.
            Default is 0.3.
        attention_heads (int, optional): Number of attention heads in the transformer.
            Default is 4.
        max_seq_len (int, optional): Maximum sequence length for positional embeddings.
            Default is 512.
        autoregressive_head (str | None, optional): For multi-head output, name of the head
            whose output size is used for the decoder's autoregressive input (start token
            and `output_to_hidden` projection). Default is None, which falls back to the
            first head's size.

    Returns:
        dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to
            tensors of shape (batch, seq_len, output_size). For single-head output, the
            key is "output".
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        encoder_layers: int = 1,
        decoder_layers: int = 0,
        output_size: int | dict[str, int] = 2,
        dropout: float = 0.3,
        attention_heads: int = 4,
        max_seq_len: int = 512,
        autoregressive_head: str | None = None,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.hidden_size = hidden_size

        # Feature size of the sequence fed back into the decoder (start token /
        # previous predictions): the single output size, or the size of the
        # selected autoregressive head for multi-head output.
        if isinstance(output_size, int):
            autoregressive_size = output_size
        else:
            autoregressive_size = list(output_size.values())[0]
        if isinstance(output_size, dict):
            autoregressive_size = output_size.get(
                autoregressive_head, autoregressive_size
            )
        self.start_token = torch.nn.Parameter(torch.zeros(1, 1, autoregressive_size))
        self.output_to_hidden = torch.nn.Linear(autoregressive_size, hidden_size)

        self.input_proj = torch.nn.Linear(input_size, hidden_size)
        self.pos_embedding = torch.nn.Embedding(max_seq_len, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)

        self.encoder = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=attention_heads,
                dim_feedforward=hidden_size * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=encoder_layers,
        )

        self.decoder = None
        if decoder_layers > 0:
            self.decoder = torch.nn.TransformerDecoder(
                torch.nn.TransformerDecoderLayer(
                    d_model=hidden_size,
                    nhead=attention_heads,
                    dim_feedforward=hidden_size * 4,
                    dropout=dropout,
                    batch_first=True,
                ),
                num_layers=decoder_layers,
            )

        if isinstance(output_size, int):
            output_size = {"output": output_size}
        self.heads = torch.nn.ModuleDict(
            {
                name: torch.nn.Linear(hidden_size, out_dim)
                for name, out_dim in output_size.items()
            }
        )

    @classmethod
    def infer_config_from_state_dict(
        cls, state_dict: dict
    ) -> dict[str, int | dict[str, int]]:
        # Infer output size from heads.<name>.bias (shape: [output_size])
        output_size = {
            key.split(".")[1]: param.shape[0]
            for key, param in state_dict.items()
            if key.startswith("heads.") and key.endswith(".bias")
        }
        return {
            # Infer input_size from input_proj.weight (shape: [hidden_size, input_size])
            "input_size": state_dict["input_proj.weight"].shape[1],
            # Infer hidden_size from input_proj.weight (shape: [hidden_size, input_size])
            "hidden_size": state_dict["input_proj.weight"].shape[0],
            "output_size": output_size,
            # Infer encoder_layers from the unique encoder layer indices in state_dict
            "encoder_layers": len(
                {k.split(".")[2] for k in state_dict if k.startswith("encoder.layers")}
            ),
            # Infer decoder_layers from the unique decoder layer indices in state_dict
            "decoder_layers": len(
                {k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")}
            )
            if any(k.startswith("decoder.layers") for k in state_dict)
            else 0,
        }

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass through the transformer model.

        Args:
            src (torch.Tensor): Input tensor of shape (batch, seq_len, input_size).
            tgt (Optional[torch.Tensor]): Target tensor for the decoder, of shape
                (batch, tgt_seq_len, output size of the autoregressive head). Only used
                if `decoder_layers > 0`. In training, this can be the ground-truth target
                sequence (i.e. teacher forcing). During inference, this is constructed
                autoregressively; if None, the learned start token is used as the initial
                decoder input.
            src_mask (Optional[torch.Tensor]): Optional attention mask for the encoder
                input. Should be broadcastable to shape (batch, seq_len, seq_len) or
                (seq_len, seq_len).
            tgt_mask (Optional[torch.Tensor]): Optional attention mask for the decoder
                input. Used to enforce causal decoding (i.e. autoregressive generation)
                during training or inference.
            start_pos (int): Starting offset for positional embeddings. Used for
                streaming inference to maintain correct positional indices. Default is 0.

        Returns:
            dict[str, torch.Tensor]: Dictionary of output tensors for each output head,
                each with shape (batch, seq_len, output_size).
        """
        B, T, _ = src.shape
        device = src.device

        # Project inputs to the model dimension and add positional embeddings.
        x = self.input_proj(src)
        pos_ids = torch.arange(start_pos, start_pos + T, device=device).expand(B, T)
        x = x + self.pos_embedding(pos_ids)
        x = self.dropout(x)

        memory = self.encoder(x, mask=src_mask)

        if self.decoder is not None:
            if tgt is None:
                tgt = self.start_token.expand(B, -1, -1).to(device)
            tgt_proj = self.output_to_hidden(tgt)
            tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand(
                B, tgt.shape[1]
            )
            tgt_proj = tgt_proj + self.pos_embedding(tgt_pos_ids)
            tgt_proj = self.dropout(tgt_proj)
            out = self.decoder(
                tgt_proj,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=src_mask,
            )
        else:
            out = memory

        return {name: head(out) for name, head in self.heads.items()}
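
# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative example of how TransformerModel might be used; it is
# not part of the module above. The sizes and head names below ("output",
# "velocity") are hypothetical, chosen only for demonstration. It shows an
# encoder-only pass, an encoder-decoder pass with teacher forcing and a causal
# target mask, and rebuilding a model from a saved state_dict via
# `infer_config_from_state_dict`.
if __name__ == "__main__":
    # Encoder-only model with a single output head ("output").
    model = TransformerModel(input_size=32, hidden_size=64, output_size=2)
    src = torch.randn(8, 100, 32)  # (batch, seq_len, input_size)
    preds = model(src)
    assert preds["output"].shape == (8, 100, 2)

    # Encoder-decoder model with a named head; `tgt` is the ground-truth target
    # sequence (teacher forcing), masked causally so each step only attends to
    # earlier steps.
    seq2seq = TransformerModel(
        input_size=32,
        hidden_size=64,
        decoder_layers=2,
        output_size={"velocity": 2},
        autoregressive_head="velocity",
    )
    tgt = torch.randn(8, 20, 2)  # (batch, tgt_seq_len, autoregressive output size)
    causal_mask = torch.triu(torch.ones(20, 20, dtype=torch.bool), diagonal=1)
    out = seq2seq(src, tgt=tgt, tgt_mask=causal_mask)
    assert out["velocity"].shape == (8, 20, 2)

    # Rebuild an identical architecture from a checkpoint's state_dict alone.
    state_dict = seq2seq.state_dict()
    config = TransformerModel.infer_config_from_state_dict(state_dict)
    restored = TransformerModel(**config)
    restored.load_state_dict(state_dict)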