ezmsg.learn.model.transformer#
Classes
- class TransformerModel(input_size, hidden_size, encoder_layers=1, decoder_layers=0, output_size=2, dropout=0.3, attention_heads=4, max_seq_len=512, autoregressive_head=None)[source]#
Bases: Module
Transformer-based encoder (optional decoder) neural network.
If decoder_layers > 0, the model includes a Transformer decoder. In this case, the tgt argument must be provided: during training, it is typically the ground-truth target sequence (i.e. teacher forcing); during inference, it can be constructed autoregressively from previous predictions.
- hidden_size#
Dimensionality of the transformer model.
- Type:
int
- output_size#
Number of output features or classes for a single-head output, or a dictionary mapping head names to output sizes for a multi-head output. Default is 2 (single head).
- Type:
int or dict, optional
- dropout#
Dropout rate applied after the input and the transformer output. Default is 0.3.
- Type:
float, optional
- Returns:
Dictionary of decoded predictions mapping head names to tensors of shape (batch, seq_len, output_size). For a single-head output, the key is "output".
- Return type:
dict
- Parameters:
input_size (int) – Number of input features per timestep.
hidden_size (int) – Dimensionality of the transformer model.
encoder_layers (int) – Number of Transformer encoder layers. Default is 1.
decoder_layers (int) – Number of Transformer decoder layers. Default is 0 (encoder only).
output_size (int or dict) – Output size for a single head, or a dictionary mapping head names to output sizes. Default is 2.
dropout (float) – Dropout rate. Default is 0.3.
attention_heads (int) – Number of attention heads. Default is 4.
max_seq_len (int) – Maximum supported sequence length. Default is 512.
autoregressive_head (optional) – Default is None.
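A minimal usage sketch based on the constructor and forward signatures documented above. The import path is assumed from the module name (ezmsg.learn.model.transformer), and the single-head key "output" and the printed shapes follow the documented defaults rather than verified behavior.

import torch
from ezmsg.learn.model.transformer import TransformerModel  # path assumed from the module name

# Encoder-only model (decoder_layers=0) with a single output head.
model = TransformerModel(
    input_size=32,      # features per timestep
    hidden_size=64,     # dimensionality of the transformer model
    encoder_layers=2,
    decoder_layers=0,   # no decoder, so tgt is not required
    output_size=2,      # single head -> predictions keyed as "output"
)

src = torch.randn(8, 100, 32)      # (batch, seq_len, input_size)
preds = model(src)                 # dict mapping head names to tensors
print(preds["output"].shape)       # expected: torch.Size([8, 100, 2])

# Multi-head output: pass a dict mapping head names to output sizes.
multi = TransformerModel(input_size=32, hidden_size=64,
                         output_size={"position": 2, "velocity": 2})
out = multi(src)
print(sorted(out.keys()))          # expected: ['position', 'velocity']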
- __init__(input_size, hidden_size, encoder_layers=1, decoder_layers=0, output_size=2, dropout=0.3, attention_heads=4, max_seq_len=512, autoregressive_head=None)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(src, tgt=None, src_mask=None, tgt_mask=None, start_pos=0)[source]#
Forward pass through the transformer model.
- Parameters:
src (torch.Tensor) – Input tensor of shape (batch, seq_len, input_size).
tgt (Optional[torch.Tensor]) – Target tensor for the decoder, shape (batch, seq_len, input_size). Required if decoder_layers > 0. During training this can be the ground-truth target sequence (i.e. teacher forcing); during inference it is constructed autoregressively from previous predictions.
src_mask (Optional[torch.Tensor]) – Optional attention mask for the encoder input. Should be broadcastable to shape (batch, seq_len, seq_len) or (seq_len, seq_len).
tgt_mask (Optional[torch.Tensor]) – Optional attention mask for the decoder input. Used to enforce causal decoding (i.e. autoregressive generation) during training or inference.
start_pos (int) – Starting offset for positional embeddings. Used for streaming inference to maintain correct positional indices. Default is 0.
- Returns:
Dictionary of output tensors, one per output head, each with shape (batch, seq_len, output_size).
- Return type:
dict
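A sketch of decoder usage under the semantics documented above: tgt is required when decoder_layers > 0, and tgt_mask enforces causal decoding. The additive float mask built with torch.triu is a common PyTorch convention; whether this model expects a float or boolean mask, and the import path, are assumptions.

import torch
from ezmsg.learn.model.transformer import TransformerModel  # path assumed from the module name

model = TransformerModel(
    input_size=32,
    hidden_size=64,
    encoder_layers=2,
    decoder_layers=2,   # > 0, so forward() requires tgt
    output_size=2,
)

batch, seq_len = 8, 100
src = torch.randn(batch, seq_len, 32)   # (batch, seq_len, input_size)
tgt = torch.randn(batch, seq_len, 32)   # ground-truth targets (teacher forcing)

# Additive causal mask: position i may not attend to positions j > i.
tgt_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

preds = model(src, tgt=tgt, tgt_mask=tgt_mask)
print(preds["output"].shape)            # expected: torch.Size([8, 100, 2])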
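One plausible streaming-inference pattern suggested by the start_pos argument, in which positional indices continue across consecutive chunks. This is a sketch under assumptions: the chunk size is arbitrary, the import path is inferred from the module name, and no claim is made about any state kept across calls.

import torch
from ezmsg.learn.model.transformer import TransformerModel  # path assumed from the module name

model = TransformerModel(input_size=32, hidden_size=64, max_seq_len=512)
model.eval()

stream = torch.randn(1, 512, 32)   # a long recording, processed in chunks
chunk_len = 64
outputs = []

with torch.no_grad():
    for start in range(0, stream.shape[1], chunk_len):
        chunk = stream[:, start:start + chunk_len, :]
        # start_pos keeps the positional embeddings aligned with each
        # chunk's absolute position in the ongoing stream.
        out = model(chunk, start_pos=start)
        outputs.append(out["output"])

full = torch.cat(outputs, dim=1)   # (1, 512, output_size)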