Fix/Robuster parameters check when loading weights #1952

Merged: 12 commits, Aug 15, 2023
154 changes: 114 additions & 40 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -1743,6 +1743,7 @@ def load_weights_from_checkpoint(
best: bool = True,
strict: bool = True,
load_encoders: bool = True,
skip_checks: bool = False,
**kwargs,
):
"""
@@ -1758,6 +1759,9 @@ def load_weights_from_checkpoint(
For manually saved model, consider using :meth:`load() <TorchForecastingModel.load()>` or
:meth:`load_weights() <TorchForecastingModel.load_weights()>` instead.

Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
and perform sanity checks on the model parameters.

Parameters
----------
model_name
@@ -1777,6 +1781,9 @@ def load_weights_from_checkpoint(
load_encoders
If set, will load the encoders from the model to enable direct call of fit() or predict().
Default: ``True``.
skip_checks
If set, will disable the loading of the encoders and the sanity checks on model parameters
(not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
**kwargs
Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
different device than the one from which it was saved.
@@ -1790,6 +1797,13 @@ def load_weights_from_checkpoint(
logger,
)

raise_if(
skip_checks and load_encoders,
"`skip_checks` and `load_encoders` are mutually exclusive parameters and cannot both "
"be set to `True`.",
logger,
)
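
To make the new guard concrete, here is a minimal sketch (the model class, hyper-parameters, and model name are illustrative, not taken from this PR):

```python
from darts.models import RNNModel

model = RNNModel(input_chunk_length=12, model="LSTM", hidden_dim=25)

# rejected up front, before any file is touched:
# ValueError: `skip_checks` and `load_encoders` cannot both be True
model.load_weights_from_checkpoint(
    model_name="my_model", skip_checks=True, load_encoders=True
)

# valid combination (assuming a checkpoint for "my_model" already exists
# under the default working directory)
model.load_weights_from_checkpoint(
    model_name="my_model", skip_checks=True, load_encoders=False
)
```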

# use the name of the model being loaded with the saved weights
if model_name is None:
model_name = self.model_name
@@ -1816,39 +1830,6 @@ def load_weights_from_checkpoint(

ckpt_path = os.path.join(checkpoint_dir, file_name)
ckpt = torch.load(ckpt_path, **kwargs)
ckpt_hyper_params = ckpt["hyper_parameters"]

# verify that the arguments passed to the constructor match those of the checkpoint
# add_encoders is checked in _load_encoders()
skipped_params = list(
inspect.signature(TorchForecastingModel.__init__).parameters.keys()
) + [
"loss_fn",
"torch_metrics",
"optimizer_cls",
"optimizer_kwargs",
"lr_scheduler_cls",
"lr_scheduler_kwargs",
]
for param_key, param_value in self.model_params.items():
# TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers
if (
param_key in ckpt_hyper_params.keys()
and param_key not in skipped_params
):
# some parameters must be converted
if isinstance(ckpt_hyper_params[param_key], list) and not isinstance(
param_value, list
):
param_value = [param_value] * len(ckpt_hyper_params[param_key])

raise_if(
param_value != ckpt_hyper_params[param_key],
f"The values of the hyper parameter {param_key} should be identical between "
f"the instantiated model ({param_value}) and the loaded checkpoint "
f"({ckpt_hyper_params[param_key]}). Please adjust the model accordingly.",
logger,
)

# indicate to the user that checkpoints generated with darts <= 0.23.1 are not supported
raise_if_not(
@@ -1867,17 +1848,32 @@ def load_weights_from_checkpoint(
]
self.train_sample = tuple(mock_train_sample)

# updating model attributes before self._init_model() which create new ckpt
tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
with open(tfm_save_file_path, "rb") as tfm_save_file:
tfm_save: TorchForecastingModel = torch.load(
tfm_save_file, map_location=kwargs.get("map_location", None)
)
if not skip_checks:
# path to the tfm checkpoint (darts model, .pt extension)
tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
if not os.path.exists(tfm_save_file_path):
raise_log(
FileNotFoundError(
f"Could not find {tfm_save_file_path}, necessary to load the encoders "
f"and run sanity checks on the model parameters."
),
logger,
)

# update model attributes before self._init_model(), which creates a new tfm ckpt
with open(tfm_save_file_path, "rb") as tfm_save_file:
tfm_save: TorchForecastingModel = torch.load(
tfm_save_file, map_location=kwargs.get("map_location", None)
)

# encoders are necessary for direct inference
self.encoders, self.add_encoders = self._load_encoders(
tfm_save, load_encoders
)

# meaningful error message if parameters are incompatible with the ckpt weights
self._check_ckpt_parameters(tfm_save)

# instantiate the model without having to call `fit_from_dataset`
self.model = self._init_model()
# cast model precision to correct type
@@ -1887,10 +1883,15 @@ def load_weights_from_checkpoint(
# update the fit_called attribute to allow for direct inference
self._fit_called = True
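
For context, a self-contained sketch of the checkpoint round trip this method enables (the series, model name, and hyper-parameters are made up; it assumes training with `save_checkpoints=True`):

```python
import numpy as np
from darts import TimeSeries
from darts.models import RNNModel

series = TimeSeries.from_values(np.random.randn(100).astype(np.float32))

model = RNNModel(
    input_chunk_length=12, hidden_dim=25, model_name="tmp_model", save_checkpoints=True
)
model.fit(series, epochs=2)

# re-create the model with identical hyper-parameters, then load only the weights
loaded = RNNModel(input_chunk_length=12, hidden_dim=25)
loaded.load_weights_from_checkpoint(model_name="tmp_model", best=False)
loaded.predict(n=3, series=series)  # direct inference: train_sample and encoders restored
```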

def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
def load_weights(
self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
):
"""
Loads the weights from a manually saved model (saved with :meth:`save() <TorchForecastingModel.save()>`).

Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
and perform sanity checks on the model parameters.

Parameters
----------
path
@@ -1899,6 +1900,9 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
load_encoders
If set, will load the encoders from the model to enable direct call of fit() or predict().
Default: ``True``.
skip_checks
If set, will disable the loading of the encoders and the sanity checks on model parameters
(not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
**kwargs
Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
different device than the one from which it was saved.
@@ -1916,6 +1920,7 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
self.load_weights_from_checkpoint(
file_name=path_ptl_ckpt,
load_encoders=load_encoders,
skip_checks=skip_checks,
**kwargs,
)
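
The manual-save counterpart can be sketched the same way (paths are illustrative; this assumes `save(path)` writes the darts object at `path` and the Lightning weights at `path + ".ckpt"`):

```python
model.save("model.pt")  # writes "model.pt" (darts object) and "model.pt.ckpt" (weights)

same_arch = RNNModel(input_chunk_length=12, hidden_dim=25)
same_arch.load_weights("model.pt")  # loads the encoders and runs the parameter checks

# if "model.pt" is unavailable, the checks cannot run; they can be bypassed
# explicitly (not recommended)
weights_only = RNNModel(input_chunk_length=12, hidden_dim=25)
weights_only.load_weights("model.pt", load_encoders=False, skip_checks=True)
```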

@@ -2058,6 +2063,75 @@ def _load_encoders(

return new_encoders, new_add_encoders

def _check_ckpt_parameters(self, tfm_save):
"""
Check that the hyper-parameters used to instantiate the model receiving the weights match those
of the saved model, in order to return meaningful messages in case of discrepancies.
"""
# parameters unrelated to the weights shape
skipped_params = list(
inspect.signature(TorchForecastingModel.__init__).parameters.keys()
) + [
"loss_fn",
"torch_metrics",
"optimizer_cls",
"optimizer_kwargs",
"lr_scheduler_cls",
"lr_scheduler_kwargs",
]
# model_params can be missing some kwargs
params_to_check = set(tfm_save.model_params.keys()).union(
self.model_params.keys()
) - set(skipped_params)

incorrect_params = []
missing_params = []
for param_key in params_to_check:
# param was not used at loading model creation
if param_key not in self.model_params.keys():
missing_params.append((param_key, tfm_save.model_params[param_key]))
# new param was used at loading model creation
elif param_key not in tfm_save.model_params.keys():
incorrect_params.append(
(
param_key,
None,
self.model_params[param_key],
)
)
# param was different at loading model creation
elif self.model_params[param_key] != tfm_save.model_params[param_key]:
# NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
Review comment (Collaborator): is this still relevant?

Suggested change (remove the line):
# NOTE: for TFTModel, default is None but converted to `QuantileRegression()`

Reply (Collaborator, author): Yes, it means that if the user manually uses `QuantileRegression()` as the likelihood for the loading model while the initial model was created with `likelihood=None`, an error will be raised even though the resulting model would be acceptable. It's probably a corner case that won't occur often, but we might want to prevent it?
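
A short sketch of that corner case (illustrative only, not part of the PR):

```python
from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression

# TFTModel stores likelihood=None in model_params, but internally falls back
# to QuantileRegression() at training time
original = TFTModel(input_chunk_length=12, output_chunk_length=3)
loading = TFTModel(
    input_chunk_length=12, output_chunk_length=3, likelihood=QuantileRegression()
)
# _check_ckpt_parameters would report the likelihood as "incorrect", even though
# both models end up with the same weight shapes
```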

incorrect_params.append(
(
param_key,
tfm_save.model_params[param_key],
self.model_params[param_key],
)
)

# at least one discrepancy was detected
if len(missing_params) + len(incorrect_params) > 0:
msg = [
"The values of the hyper-parameters in the model and loaded checkpoint should be identical."
]

# error messages formatted to facilitate copy-pasting
if len(missing_params) > 0:
msg += ["missing :"]
msg += [
f" - {param}={exp_val}" for (param, exp_val) in missing_params
]

if len(incorrect_params) > 0:
msg += ["incorrect :"]
msg += [
f" - found {param}={cur_val}, should be {param}={exp_val}"
for (param, exp_val, cur_val) in incorrect_params
]

raise_log(ValueError("\n".join(msg)), logger)
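
For illustration, a hypothetical mismatch (a checkpoint trained with `hidden_dim=25`, loaded into a model that omits `hidden_dim` and passes an extra `dropout=0.1`) would produce a `ValueError` along these lines:

```
The values of the hyper-parameters in the model and loaded checkpoint should be identical.
missing :
 - hidden_dim=25
incorrect :
 - found dropout=0.1, should be dropout=None
```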

def __getstate__(self):
# do not pickle the PyTorch LightningModule, and Trainer
return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}