Source code for ezmsg.learn.process.base
import inspect
import json
import typing
from pathlib import Path
import ezmsg.core as ez
import torch
class ModelInitMixin:
"""
Mixin class to support model initialization from:
1. Setting parameters
2. Config file
3. Checkpoint file
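
    Example:
        A minimal sketch of a subclass using the mixin; ``SimpleModel`` is a
        hypothetical ``torch.nn.Module`` subclass::

            class SimpleDecoder(ModelInitMixin):
                def load(self, checkpoint_path: str) -> torch.nn.Module:
                    return self._init_model(
                        SimpleModel,
                        params={"hidden_size": 128},
                        checkpoint_path=checkpoint_path,
                        device="cpu",
                    )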
"""
@staticmethod
    def _merge_config(model_kwargs: dict, config: dict) -> None:
"""
Mutate the model_kwargs dictionary with the config parameters.
Args:
model_kwargs: Original to-be-mutated model kwargs.
config: Update config parameters.
Returns:
None because model_kwargs is mutated in place.
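
        Example:
            A small illustration with hypothetical keys::

                kwargs = {"hidden_size": 128}
                ModelInitMixin._merge_config(kwargs, {"hidden_size": 256, "num_layers": 2})
                # kwargs == {"hidden_size": 256, "num_layers": 2}; a warning is logged
                # for the changed key and for the previously missing one.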
"""
if "model_params" in config:
config = config["model_params"]
# Update model_kwargs with config parameters
for key, value in config.items():
if key in model_kwargs:
if model_kwargs[key] != value:
ez.logger.warning(f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]}).")
else:
ez.logger.warning(f"Config parameter {key} is not in model_kwargs.")
model_kwargs[key] = value

    def _filter_model_kwargs(self, model_class, kwargs: dict) -> dict:
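        """
        Return a copy of ``kwargs`` restricted to the parameters accepted by
        ``model_class.__init__``, logging a warning for any keys that are dropped.
        """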
valid_params = inspect.signature(model_class.__init__).parameters
filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"}
if filtered_out:
ez.logger.warning(
"Ignoring unexpected model parameters not accepted by"
f"{model_class.__name__} constructor: {sorted(filtered_out)}"
)
        # Keep all valid parameters even if their value is None; any values still None
        # after merging are dropped later in _init_model before construction.
return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"}

    def _init_model(
self,
model_class,
params: dict[str, typing.Any] | None = None,
config_path: str | None = None,
checkpoint_path: str | None = None,
device: str = "cpu",
state_dict_prefix: str | None = None,
weights_only: bool | None = None,
) -> torch.nn.Module:
"""
Args:
model_class: The class of the model to be initialized.
params: A dictionary of setting parameters to be used for model initialization.
config_path: Path to a JSON config file to update model parameters.
checkpoint_path: Path to a checkpoint file to load model weights and possibly config.
Returns:
The initialized model.
The model will be initialized with the correct config and weights.
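
        Example:
            A minimal sketch; ``SimpleModel`` and its parameters are hypothetical. The
            explicit ``hidden_size`` is overridden if the config file or checkpoint
            specifies a different value, and ``num_layers`` is dropped before
            construction unless one of them fills it in::

                model = self._init_model(
                    SimpleModel,
                    params={"hidden_size": 128, "num_layers": None},
                    config_path="config.json",
                    checkpoint_path="checkpoint.pt",
                )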
"""
# Model parameters are taken from multiple sources, in ascending priority:
# 1. Setting parameters
# 2. Config file if provided
# 3. "config" entry in checkpoint file if checkpoint file provided and config present
# 4. Sizes of weights in checkpoint file if provided
        # Start from the settings params; copy so the caller's dict is not mutated by the merges below.
        model_kwargs = dict(params) if params else {}
state_dict = None
# Check if a config file is provided and if so use that to update kwargs (with warnings).
if config_path:
config_path = Path(config_path)
if not config_path.exists():
ez.logger.error(f"Config path {config_path} does not exist.")
raise FileNotFoundError(f"Config path {config_path} does not exist.")
try:
with open(config_path, "r") as f:
config = json.load(f)
self._merge_config(model_kwargs, config)
except Exception as e:
raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
# If a checkpoint file is provided, load it.
if checkpoint_path:
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.")
raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist.")
try:
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
if "config" in checkpoint:
config = checkpoint["config"]
self._merge_config(model_kwargs, config)
# Load the model weights and infer the config.
state_dict = checkpoint
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
# This is for backward compatibility with older checkpoints
# that used "state_dict" instead of "model_state_dict"
state_dict = checkpoint["state_dict"]
                infer_config = getattr(
                    model_class,
                    "infer_config_from_state_dict",
                    lambda _state_dict, **_kwargs: {},  # Default to an empty dict if the hook is not defined
                )
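                # The optional hook is looked up on the class and called as
                # infer_config_from_state_dict(state_dict, **kwargs) -> dict, so it is
                # typically defined as a staticmethod or classmethod on the model.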
infer_kwargs = {"rnn_type": model_kwargs["rnn_type"]} if "rnn_type" in model_kwargs else {}
self._merge_config(
model_kwargs,
infer_config(state_dict, **infer_kwargs),
)
except Exception as e:
raise RuntimeError(f"Failed to load checkpoint from {checkpoint_path}: {str(e)}")
# Filter model_kwargs to only include valid parameters for the model class
filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs)
        # Remove remaining None values so they are not passed to the model constructor;
        # a value is still None only if it was not set explicitly, by the config, or by the checkpoint.
final_kwargs = {k: v for k, v in filtered_kwargs.items() if v is not None}
# Create the model with the final kwargs
model = model_class(**final_kwargs)
# Finally, load the weights.
if state_dict:
if state_dict_prefix:
                # Keep only the keys that start with the prefix and strip it before loading
state_dict = {
k[len(state_dict_prefix) :]: v for k, v in state_dict.items() if k.startswith(state_dict_prefix)
}
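                # e.g. with state_dict_prefix="model.", a hypothetical key "model.fc.weight" is loaded as "fc.weight"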
# Load the model weights
missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
if missing or unexpected:
ez.logger.warning(f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}")
model.to(device)
return model