Source code for ezmsg.learn.process.torch

import importlib
import typing

import ezmsg.core as ez
import numpy as np
import torch
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import ModelInitMixin


class TorchSimpleSettings(ez.Settings):
    model_class: str
    """
    Fully qualified class path of the model to be used, e.g. "my_module.MyModelClass".
    This class should inherit from `torch.nn.Module`.
    """

    checkpoint_path: str | None = None
    """
    Path to a checkpoint file containing model weights. If None, the model will be
    initialized with random weights. If parameters inferred from the weight sizes
    conflict with the settings / config, then the inferred parameters will take
    priority and a warning will be logged.
    """

    config_path: str | None = None
    """
    Path to a config file containing model parameters. Parameters loaded from the
    config file will take priority over settings. If settings differ from config
    parameters, a warning will be logged. If `checkpoint_path` is provided, any
    parameters inferred from the weights will take priority over the config
    parameters.
    """

    single_precision: bool = True
    """Use single precision (float32) instead of double precision (float64)."""

    device: str | None = None
    """
    Device to use for the model. If None, the device will be determined
    automatically, with preference for cuda > mps > cpu.
    """

    model_kwargs: dict[str, typing.Any] | None = None
    """
    Additional keyword arguments to pass to the model constructor, e.g.
    `input_size`, `output_size`, etc. If a config file is provided, these
    parameters will be updated with the config values. If a checkpoint file is
    provided, they will be updated with the values inferred from the model
    weights.
    """

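# Usage sketch (illustrative only; "my_module.MyModelClass" and the kwargs below
# are hypothetical placeholders, not part of this package):
#
#     settings = TorchSimpleSettings(
#         model_class="my_module.MyModelClass",
#         checkpoint_path=None,  # random initialization
#         device="cpu",          # skip automatic device selection
#         model_kwargs={"input_size": 32, "output_size": 4},
#     )
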
class TorchModelSettings(TorchSimpleSettings):
    learning_rate: float = 0.001
    """Learning rate for the optimizer."""

    weight_decay: float = 0.0001
    """Weight decay for the optimizer."""

    loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None
    """
    Loss function(s) for the decoder. If using multiple heads, this should be a
    dictionary mapping head names to loss functions. The keys must match the
    output head names.
    """

    loss_weights: dict[str, float] | None = None
    """
    Weights for each loss function if using multiple heads. The keys must match
    the output head names. If None, or if keys are missing or mismatched, losses
    will be unweighted.
    """

    scheduler_gamma: float = 0.999
    """Learning-rate scheduler decay rate. Set to 0.0 to disable the scheduler."""

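# Multi-head sketch (illustrative only; the head names "pos" and "vel" are
# hypothetical): with multiple output heads, `loss_fn` maps head name to loss
# function, and `loss_weights` scales each head's contribution to the total loss.
#
#     settings = TorchModelSettings(
#         model_class="my_module.MyMultiHeadModel",
#         loss_fn={"pos": torch.nn.MSELoss(), "vel": torch.nn.MSELoss()},
#         loss_weights={"pos": 1.0, "vel": 0.5},
#         learning_rate=1e-3,
#     )
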
@processor_state
class TorchSimpleState:
    model: torch.nn.Module | None = None
    device: torch.device | None = None
    chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None

class TorchModelState(TorchSimpleState):
    optimizer: torch.optim.Optimizer | None = None
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None

P = typing.TypeVar("P", bound=BaseStatefulTransformer)
class TorchProcessorMixin:
    """Mixin with shared functionality for torch processors."""

    def _import_model(self, class_path: str) -> type[torch.nn.Module]:
        """Dynamically import a model class from its fully qualified path."""
        if class_path is None:
            raise ValueError("Model class path must be provided in settings.")
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    def _infer_output_sizes(self: P, model: torch.nn.Module, n_input: int) -> dict[str, int]:
        """Run a dummy forward pass to infer output channel sizes. Override if needed."""
        dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
        with torch.no_grad():
            output = model(dummy_input)
        if isinstance(output, dict):
            return {k: v.shape[-1] for k, v in output.items()}
        return {"output": output.shape[-1]}

    def _init_optimizer(self) -> None:
        self._state.optimizer = torch.optim.AdamW(
            self._state.model.parameters(),
            lr=self.settings.learning_rate,
            weight_decay=self.settings.weight_decay,
        )
        self._state.scheduler = (
            torch.optim.lr_scheduler.ExponentialLR(
                self._state.optimizer, gamma=self.settings.scheduler_gamma
            )
            if self.settings.scheduler_gamma > 0.0
            else None
        )

    def _validate_loss_keys(self, output_keys: list[str]) -> None:
        if isinstance(self.settings.loss_fn, dict):
            missing = [k for k in output_keys if k not in self.settings.loss_fn]
            if missing:
                raise ValueError(f"Missing loss function(s) for output keys: {missing}")

    def _to_tensor(self: P, data: np.ndarray | torch.Tensor) -> torch.Tensor:
        dtype = torch.float32 if self.settings.single_precision else torch.float64
        if isinstance(data, torch.Tensor):
            return data.detach().clone().to(device=self._state.device, dtype=dtype)
        return torch.tensor(data, dtype=dtype, device=self._state.device)

    def save_checkpoint(self: P, path: str) -> None:
        """Save the current model state to a checkpoint file."""
        if self._state.model is None:
            raise RuntimeError("Model must be initialized before saving a checkpoint.")
        checkpoint = {
            "model_state_dict": self._state.model.state_dict(),
            "config": self.settings.model_kwargs or {},
        }
        # Add optimizer state if available
        if hasattr(self._state, "optimizer") and self._state.optimizer is not None:
            checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict()
        torch.save(checkpoint, path)

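    # Reloading a checkpoint written by `save_checkpoint` (a sketch; assumes the
    # same model class and that "config" holds its constructor kwargs):
    #
    #     checkpoint = torch.load(path, map_location="cpu")
    #     model = MyModelClass(**checkpoint["config"])
    #     model.load_state_dict(checkpoint["model_state_dict"])
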
    def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]:
        """
        Ensure the tensor has a batch dimension.

        Returns the potentially modified tensor and a flag indicating whether a
        dimension was added.
        """
        if tensor.ndim == 2:
            return tensor.unsqueeze(0), True
        return tensor, False

    def _common_process(self: P, message: AxisArray) -> list[AxisArray]:
        data = self._to_tensor(message.data)
        # Add batch dimension if missing
        data, added_batch_dim = self._ensure_batched(data)
        with torch.no_grad():
            output = self._state.model(data)
        if isinstance(output, dict):
            return [
                replace(
                    message,
                    data=value.cpu().numpy().squeeze(0) if added_batch_dim else value.cpu().numpy(),
                    axes={
                        **message.axes,
                        "ch": self._state.chan_ax[key],
                    },
                    key=key,
                )
                for key, value in output.items()
            ]
        return [
            replace(
                message,
                data=output.cpu().numpy().squeeze(0) if added_batch_dim else output.cpu().numpy(),
                axes={
                    **message.axes,
                    "ch": self._state.chan_ax["output"],
                },
            )
        ]

    def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None:
        n_input = message.data.shape[message.get_axis_idx("ch")]
        if "input_size" in model_kwargs:
            if model_kwargs["input_size"] != n_input:
                raise ValueError(
                    f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} "
                    f"and input data channels={n_input}"
                )
        else:
            model_kwargs["input_size"] = n_input

        device = "cuda" if torch.cuda.is_available() else (
            "mps" if torch.backends.mps.is_available() else "cpu"
        )
        device = self.settings.device or device
        self._state.device = torch.device(device)

        model_class = self._import_model(self.settings.model_class)
        self._state.model = self._init_model(
            model_class=model_class,
            params=model_kwargs,
            config_path=self.settings.config_path,
            checkpoint_path=self.settings.checkpoint_path,
            device=device,
        )
        self._state.model.eval()

        output_sizes = self._infer_output_sizes(self._state.model, n_input)
        self._state.chan_ax = {
            head: AxisArray.CoordinateAxis(
                data=np.array([f"{head}_ch{ix}" for ix in range(size)]),
                dims=["ch"],
            )
            for head, size in output_sizes.items()
        }

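# Data layout note: `_infer_output_sizes` probes the model with a
# (batch=1, time=1, channels=n_input) tensor, and `_common_process` unsqueezes
# 2D (time, channels) inputs to (1, time, channels) before the forward pass, so
# wrapped models are expected to accept (batch, time, channels) input. A
# minimal compatible model (illustrative only, not part of this package):
#
#     class LinearModel(torch.nn.Module):
#         def __init__(self, input_size: int, output_size: int = 2):
#             super().__init__()
#             self.proj = torch.nn.Linear(input_size, output_size)
#
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             # x: (batch, time, channels) -> (batch, time, output_size)
#             return self.proj(x)
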
class TorchSimpleProcessor(
    BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)

    def _process(self, message: AxisArray) -> list[AxisArray]:
        """Process the input message and return the output messages."""
        return self._common_process(message)

class TorchSimpleUnit(
    BaseTransformerUnit[
        TorchSimpleSettings,
        AxisArray,
        AxisArray,
        TorchSimpleProcessor,
    ]
):
    SETTINGS = TorchSimpleSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result

class TorchModelProcessor(
    BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

    def _process(self, message: AxisArray) -> list[AxisArray]:
        return self._common_process(message)

    def partial_fit(self, message: SampleMessage) -> None:
        self._state.model.train()
        X = self._to_tensor(message.sample.data)
        X, batched = self._ensure_batched(X)

        y_targ = message.trigger.value
        if not isinstance(y_targ, dict):
            y_targ = {"output": y_targ}
        y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
        if batched:
            for key in y_targ:
                y_targ[key] = y_targ[key].unsqueeze(0)

        loss_fns = self.settings.loss_fn
        if loss_fns is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(loss_fns, dict):
            loss_fns = {k: loss_fns for k in y_targ.keys()}
        weights = self.settings.loss_weights or {}

        with torch.set_grad_enabled(True):
            y_pred = self._state.model(X)
            if not isinstance(y_pred, dict):
                y_pred = {"output": y_pred}

            losses = []
            for key in y_targ.keys():
                loss_fn = loss_fns.get(key)
                if loss_fn is None:
                    raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
                if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                    # CrossEntropyLoss expects (batch, classes, time) logits
                    # and integer class-index targets.
                    loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
                else:
                    loss = loss_fn(y_pred[key], y_targ[key])
                weight = weights.get(key, 1.0)
                losses.append(loss * weight)

            total_loss = sum(losses)
            self._state.optimizer.zero_grad()
            total_loss.backward()
            self._state.optimizer.step()
            if self._state.scheduler is not None:
                self._state.scheduler.step()

        self._state.model.eval()

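# partial_fit sketch (illustrative only; `axis_array_msg` and `sample_msg` are
# hypothetical messages, where `sample_msg.sample.data` holds the input array and
# `sample_msg.trigger.value` holds the target, either a single array or a dict
# keyed by head name):
#
#     processor = TorchModelProcessor(settings)
#     outputs = processor(axis_array_msg)  # first message initializes model + optimizer
#     processor.partial_fit(sample_msg)    # one gradient step on the labeled sample
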
class TorchModelUnit(
    BaseAdaptiveTransformerUnit[
        TorchModelSettings,
        AxisArray,
        AxisArray,
        TorchModelProcessor,
    ]
):
    SETTINGS = TorchModelSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result

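# Pipeline sketch (illustrative only; the source component, its stream names, and
# "my_module.MyModel" are hypothetical): TorchModelUnit subscribes to AxisArray
# messages on INPUT_SIGNAL, republishes model outputs on OUTPUT_SIGNAL, and
# adapts online from SampleMessages delivered to the sample-input stream assumed
# to be inherited from BaseAdaptiveTransformerUnit.
#
#     components = {
#         "SOURCE": MySignalSource(),
#         "DECODER": TorchModelUnit(
#             TorchModelSettings(
#                 model_class="my_module.MyModel",
#                 loss_fn=torch.nn.MSELoss(),
#             )
#         ),
#     }
#     ez.run(
#         components=components,
#         connections=(
#             (components["SOURCE"].OUTPUT_SIGNAL, components["DECODER"].INPUT_SIGNAL),
#         ),
#     )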