import typing
import ezmsg.core as ez
import numpy as np
import torch
import torch.nn
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from ..model.mlp_old import MLP
class MLPSettings(ez.Settings):
    hidden_channels: list[int]
    """List of the hidden channel dimensions."""
    norm_layer: typing.Callable[..., torch.nn.Module] | None = None
    """Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used."""
    activation_layer: typing.Callable[..., torch.nn.Module] | None = torch.nn.ReLU
    """Activation function stacked on top of the normalization layer (if present),
    otherwise on top of the linear layer. If None, this layer won't be used."""
    inplace: bool | None = None
    """Parameter for the activation layer, which can optionally do the operation in-place.
    Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
    bias: bool = True
    """Whether to use bias in the linear layer."""
    dropout: float = 0.0
    """The probability for the dropout layer."""
    single_precision: bool = True
    """If True, use float32 tensors; otherwise use float64."""
    learning_rate: float = 0.001
    """Learning rate for the Adam optimizer used during partial_fit."""
    scheduler_gamma: float = 0.999
    """Learning-rate scheduler decay rate. Set to 0.0 to disable the scheduler."""
    checkpoint_path: str | None = None
    """Path to a checkpoint file containing model weights.
    If None, the model will be initialized with random weights."""
@processor_state
class MLPState:
    model: MLP | None = None
    optimizer: torch.optim.Optimizer | None = None
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
    template: AxisArray | None = None
    device: object | None = None
    chan_ax: AxisArray.CoordinateAxis | None = None  # Output channel axis built in _reset_state

class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]):
    def _hash_message(self, message: AxisArray) -> int:
        hash_items = (message.key,)
        if "ch" in message.dims:
            hash_items += (message.data.shape[message.get_axis_idx("ch")],)
        return hash(hash_items)

    def _reset_state(self, message: AxisArray) -> None:
        # Create the model
        self._state.model = MLP(
            in_channels=message.data.shape[message.get_axis_idx("ch")],
            hidden_channels=self.settings.hidden_channels,
            norm_layer=self.settings.norm_layer,
            activation_layer=self.settings.activation_layer,
            inplace=self.settings.inplace,
            bias=self.settings.bias,
            dropout=self.settings.dropout,
        )

        # Load model weights from checkpoint if specified
        if self.settings.checkpoint_path is not None:
            try:
                checkpoint = torch.load(self.settings.checkpoint_path)
                self._state.model.load_state_dict(checkpoint["model_state_dict"])
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load checkpoint from {self.settings.checkpoint_path}: {e}"
                ) from e

        # Set the model to evaluation mode by default
        self._state.model.eval()

        # Create the optimizer
        self._state.optimizer = torch.optim.Adam(
            self._state.model.parameters(), lr=self.settings.learning_rate
        )

        # Restore the optimizer state from the checkpoint if present
        if self.settings.checkpoint_path is not None:
            try:
                checkpoint = torch.load(self.settings.checkpoint_path)
                if "optimizer_state_dict" in checkpoint:
                    self._state.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load optimizer from {self.settings.checkpoint_path}: {e}"
                ) from e

        # TODO: Should the model be moved to a device before the next line?
        self._state.device = next(self._state.model.parameters()).device

        # Optionally create the learning rate scheduler
        self._state.scheduler = (
            torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
            if self.settings.scheduler_gamma > 0.0
            else None
        )

        # Create the output channel axis for reuse in each output.
        n_output_channels = self.settings.hidden_channels[-1]
        self._state.chan_ax = AxisArray.CoordinateAxis(
            data=np.array([f"ch{ix}" for ix in range(n_output_channels)]), dims=["ch"]
        )
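
    # Note (added for clarity): the BaseAdaptiveTransformer base class is expected to call
    # _hash_message on each incoming AxisArray and re-run _reset_state whenever the hash
    # changes (e.g. a new stream key or a different channel count), so the model,
    # optimizer, and output channel axis above are rebuilt lazily to match the incoming
    # data. This description is inferred from the methods above, not from the base class
    # source.
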
    def save_checkpoint(self, path: str) -> None:
        """Save the current model and optimizer state to a checkpoint file.

        Args:
            path: Path where the checkpoint will be saved.
        """
        checkpoint = {
            "model_state_dict": self._state.model.state_dict(),
            "optimizer_state_dict": self._state.optimizer.state_dict(),
        }
        torch.save(checkpoint, path)

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        dtype = torch.float32 if self.settings.single_precision else torch.float64
        return torch.tensor(data, dtype=dtype, device=self._state.device)

    def partial_fit(self, message: SampleMessage) -> None:
        # TODO: loss_fn should be determined by a setting
        loss_fn = torch.nn.functional.mse_loss

        X = self._to_tensor(message.sample.data)
        y_targ = self._to_tensor(message.trigger.value)

        with torch.set_grad_enabled(True):
            self._state.model.train()
            y_pred = self._state.model(X)
            loss = loss_fn(y_pred, y_targ)
            self._state.optimizer.zero_grad()
            loss.backward()
            self._state.optimizer.step()  # Update weights
            if self._state.scheduler is not None:
                self._state.scheduler.step()  # Update learning rate
        self._state.model.eval()

    def _process(self, message: AxisArray) -> AxisArray:
        data = message.data
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(
                data,
                dtype=torch.float32 if self.settings.single_precision else torch.float64,
            )
        with torch.no_grad():
            output = self._state.model(data.to(self._state.device))
        return replace(
            message,
            data=output.cpu().numpy(),
            axes={
                **message.axes,
                "ch": self._state.chan_ax,
            },
        )
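
# Illustrative usage sketch (assumptions: the processor can be constructed directly with
# an MLPSettings instance and called with an AxisArray, following the ezmsg.baseproc
# transformer pattern; names such as `axis_arr_in` and `sample_msg` are hypothetical).
#
#     proc = MLPProcessor(settings=MLPSettings(hidden_channels=[64, 2]))
#     axis_arr_out = proc(axis_arr_in)           # inference on an AxisArray message
#     proc.partial_fit(sample_msg)               # one adaptation step from a SampleMessage
#     proc.save_checkpoint("mlp_checkpoint.pt")  # persist model + optimizer state
#
# where `sample_msg.sample` carries the input AxisArray and `sample_msg.trigger.value`
# carries the regression target used by the MSE loss in partial_fit.
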
class MLPUnit(
    BaseAdaptiveTransformerUnit[
        MLPSettings,
        AxisArray,
        AxisArray,
        MLPProcessor,
    ]
):
    SETTINGS = MLPSettings
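
# Illustrative sketch (not from the original module): in an ezmsg graph, MLPUnit would
# typically be instantiated with MLPSettings and wired between an upstream signal source
# and a downstream consumer. The stream names referenced below are hypothetical and
# depend on BaseAdaptiveTransformerUnit.
#
#     mlp = MLPUnit(MLPSettings(hidden_channels=[64, 2], checkpoint_path=None))
#     # ... connect the unit's input/output signal streams (and the sample stream used
#     # for partial_fit) inside an ez.Collection, then launch the graph with ez.run(...).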