
Commit b69b8ca

Fix/Robuster parameters check when loading weights (#1952)
* fix: comparing the parameters stored in .model_params (saved and loading models) to make it more robust. this check can be skipped (not recommended).
* feat: nicer message, all the hp discrepancies are listed at once
* feat: repr method for LikelihoodModel
* fix: removed unused param from LikelihoodModel.__repr__
* fix: better sanity check of the kwargs during weights loading from ckpt
* fix: removed copy-paste leftovers from unittests
* fix: removing redudant helper code
* test: adding tests, made the assertion more specific to distinguish incorrect/missing
* Apply suggestions from code review

Co-authored-by: Dennis Bader <[email protected]>

* fix: linting
* Update CHANGELOG.md

---------

Co-authored-by: Dennis Bader <[email protected]>
1 parent b3463ea commit b69b8ca
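For context, a minimal sketch of the workflow this commit hardens, assuming a darts 0.25-style API; the model class, hyper-parameters and file names below are arbitrary examples, not taken from the diff:

```python
from darts.models import DLinearModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)

# train and manually save a model; this writes the darts model file ("dlinear.pt")
# and the PyTorch Lightning checkpoint holding the weights ("dlinear.pt.ckpt")
model = DLinearModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)
model.fit(series)
model.save("dlinear.pt")

# re-create a model with the same architecture and load only the weights;
# the check made more robust here compares the constructor arguments stored in
# `.model_params` of the saved model and of the loading model
same_arch = DLinearModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)
same_arch.load_weights("dlinear.pt")
prediction = same_arch.predict(n=6)
```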

File tree: 4 files changed, +301 −123 lines

CHANGELOG.md

+1

@@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug in `RegressionEnsembleModel.extreme_lags` when the forecasting models have only covariates lags. [#1942](https://github.com/unit8co/darts/pull/1942) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug when using `TFTExplainer` with a `TFTModel` running on GPU. [#1949](https://github.com/unit8co/darts/pull/1949) by [Dennis Bader](https://github.com/dennisbader).
+- Fixed a bug in `TorchForecastingModel.load_weights()` that raised an error when loading the weights from a valid architecture. [#1952](https://github.com/unit8co/darts/pull/1952) by [Antoine Madrona](https://github.com/madtoinou).


 ## [0.25.0](https://github.com/unit8co/darts/tree/0.25.0) (2023-08-04)

darts/models/forecasting/torch_forecasting_model.py

+114 −40

@@ -1743,6 +1743,7 @@ def load_weights_from_checkpoint(
         best: bool = True,
         strict: bool = True,
         load_encoders: bool = True,
+        skip_checks: bool = False,
         **kwargs,
     ):
         """
@@ -1758,6 +1759,9 @@ def load_weights_from_checkpoint(
         For manually saved model, consider using :meth:`load() <TorchForecastingModel.load()>` or
         :meth:`load_weights() <TorchForecastingModel.load_weights()>` instead.

+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         model_name
@@ -1777,6 +1781,9 @@ def load_weights_from_checkpoint(
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
             different device than the one from which it was saved.
@@ -1790,6 +1797,13 @@ def load_weights_from_checkpoint(
             logger,
         )

+        raise_if(
+            skip_checks and load_encoders,
+            "`skip-checks` and `load_encoders` are mutually exclusive parameters and cannot be both "
+            "set to `True`.",
+            logger,
+        )
+
         # use the name of the model being loaded with the saved weights
         if model_name is None:
             model_name = self.model_name
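A hedged usage sketch of the new flag (the checkpoint name is hypothetical and assumes a model previously trained with `save_checkpoints=True`): `skip_checks=True` is only accepted together with `load_encoders=False`, as enforced by the `raise_if` above.

```python
from darts.models import DLinearModel

model = DLinearModel(input_chunk_length=12, output_chunk_length=6)

# allowed: disable both encoder loading and the parameter sanity checks
model.load_weights_from_checkpoint(
    model_name="my_model", best=False, load_encoders=False, skip_checks=True
)

# rejected by the check above, since `load_encoders` defaults to True:
# model.load_weights_from_checkpoint(model_name="my_model", best=False, skip_checks=True)
```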
@@ -1816,39 +1830,6 @@ def load_weights_from_checkpoint(

         ckpt_path = os.path.join(checkpoint_dir, file_name)
         ckpt = torch.load(ckpt_path, **kwargs)
-        ckpt_hyper_params = ckpt["hyper_parameters"]
-
-        # verify that the arguments passed to the constructor match those of the checkpoint
-        # add_encoders is checked in _load_encoders()
-        skipped_params = list(
-            inspect.signature(TorchForecastingModel.__init__).parameters.keys()
-        ) + [
-            "loss_fn",
-            "torch_metrics",
-            "optimizer_cls",
-            "optimizer_kwargs",
-            "lr_scheduler_cls",
-            "lr_scheduler_kwargs",
-        ]
-        for param_key, param_value in self.model_params.items():
-            # TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers
-            if (
-                param_key in ckpt_hyper_params.keys()
-                and param_key not in skipped_params
-            ):
-                # some parameters must be converted
-                if isinstance(ckpt_hyper_params[param_key], list) and not isinstance(
-                    param_value, list
-                ):
-                    param_value = [param_value] * len(ckpt_hyper_params[param_key])
-
-                raise_if(
-                    param_value != ckpt_hyper_params[param_key],
-                    f"The values of the hyper parameter {param_key} should be identical between "
-                    f"the instantiated model ({param_value}) and the loaded checkpoint "
-                    f"({ckpt_hyper_params[param_key]}). Please adjust the model accordingly.",
-                    logger,
-                )

         # indicate to the user than checkpoints generated with darts <= 0.23.1 are not supported
         raise_if_not(
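The removed block compared the constructor kwargs in `model_params` against the Lightning `hyper_parameters` stored in the `.ckpt`. As its own TODO notes, the two do not use the same parameter names or value shapes, which is how valid architectures could end up rejected. An illustrative sketch with made-up values (not actual darts internals):

```python
# constructor kwargs as recorded by darts for a model (illustrative values)
model_params = {"hidden_dim": 25, "n_rnn_layers": 2}
# hyper-parameters as saved by PyTorch Lightning in the .ckpt (illustrative values)
ckpt_hyper_params = {"hidden_dim": 25, "num_layers": 2, "target_size": 1, "nr_params": 1}

# only keys present in both dicts were compared, so renamed parameters
# (e.g. n_rnn_layers vs num_layers) were silently skipped, while internally
# converted values (e.g. scalars expanded to lists) could trigger false mismatches
compared = set(model_params) & set(ckpt_hyper_params)
print(compared)  # {'hidden_dim'}
```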
@@ -1867,17 +1848,32 @@ def load_weights_from_checkpoint(
         ]
         self.train_sample = tuple(mock_train_sample)

-        # updating model attributes before self._init_model() which create new ckpt
-        tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
-        with open(tfm_save_file_path, "rb") as tfm_save_file:
-            tfm_save: TorchForecastingModel = torch.load(
-                tfm_save_file, map_location=kwargs.get("map_location", None)
-            )
+        if not skip_checks:
+            # path to the tfm checkpoint (darts model, .pt extension)
+            tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
+            if not os.path.exists(tfm_save_file_path):
+                raise_log(
+                    FileNotFoundError(
+                        f"Could not find {tfm_save_file_path}, necessary to load the encoders "
+                        f"and run sanity checks on the model parameters."
+                    ),
+                    logger,
+                )
+
+            # updating model attributes before self._init_model() which create new tfm ckpt
+            with open(tfm_save_file_path, "rb") as tfm_save_file:
+                tfm_save: TorchForecastingModel = torch.load(
+                    tfm_save_file, map_location=kwargs.get("map_location", None)
+                )
+
             # encoders are necessary for direct inference
             self.encoders, self.add_encoders = self._load_encoders(
                 tfm_save, load_encoders
             )

+            # meaningful error message if parameters are incompatible with the ckpt weights
+            self._check_ckpt_parameters(tfm_save)
+
         # instanciate the model without having to call `fit_from_dataset`
         self.model = self._init_model()
         # cast model precision to correct type
@@ -1887,10 +1883,15 @@ def load_weights_from_checkpoint(
         # update the fit_called attribute to allow for direct inference
         self._fit_called = True

-    def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
+    def load_weights(
+        self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
+    ):
         """
         Loads the weights from a manually saved model (saved with :meth:`save() <TorchForecastingModel.save()>`).

+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         path
@@ -1899,6 +1900,9 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
             different device than the one from which it was saved.
@@ -1916,6 +1920,7 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         self.load_weights_from_checkpoint(
             file_name=path_ptl_ckpt,
             load_encoders=load_encoders,
+            skip_checks=skip_checks,
             **kwargs,
         )

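A sketch of the corresponding `load_weights()` call for the case where only the Lightning checkpoint is at hand and the darts `.pt` sidecar required by the checks is unavailable (file names hypothetical; skipping the checks is not recommended in general):

```python
from darts.models import DLinearModel

model = DLinearModel(input_chunk_length=12, output_chunk_length=6)

# reads "dlinear.pt.ckpt" only; skips loading the encoders and the
# hyper-parameter sanity checks that would need the "dlinear.pt" file
model.load_weights("dlinear.pt", load_encoders=False, skip_checks=True)
```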
@@ -2058,6 +2063,75 @@ def _load_encoders(

         return new_encoders, new_add_encoders

+    def _check_ckpt_parameters(self, tfm_save):
+        """
+        Check that the positional parameters used to instantiate the new model loading the weights match those
+        of the saved model, to return meaningful messages in case of discrepancies.
+        """
+        # parameters unrelated to the weights shape
+        skipped_params = list(
+            inspect.signature(TorchForecastingModel.__init__).parameters.keys()
+        ) + [
+            "loss_fn",
+            "torch_metrics",
+            "optimizer_cls",
+            "optimizer_kwargs",
+            "lr_scheduler_cls",
+            "lr_scheduler_kwargs",
+        ]
+        # model_params can be missing some kwargs
+        params_to_check = set(tfm_save.model_params.keys()).union(
+            self.model_params.keys()
+        ) - set(skipped_params)
+
+        incorrect_params = []
+        missing_params = []
+        for param_key in params_to_check:
+            # param was not used at loading model creation
+            if param_key not in self.model_params.keys():
+                missing_params.append((param_key, tfm_save.model_params[param_key]))
+            # new param was used at loading model creation
+            elif param_key not in tfm_save.model_params.keys():
+                incorrect_params.append(
+                    (
+                        param_key,
+                        None,
+                        self.model_params[param_key],
+                    )
+                )
+            # param was different at loading model creation
+            elif self.model_params[param_key] != tfm_save.model_params[param_key]:
+                # NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
+                incorrect_params.append(
+                    (
+                        param_key,
+                        tfm_save.model_params[param_key],
+                        self.model_params[param_key],
+                    )
+                )
+
+        # at least one discrepancy was detected
+        if len(missing_params) + len(incorrect_params) > 0:
+            msg = [
+                "The values of the hyper-parameters in the model and loaded checkpoint should be identical."
+            ]
+
+            # warning messages formated to facilate copy-pasting
+            if len(missing_params) > 0:
+                msg += ["missing :"]
+                msg += [
+                    f"   - {param}={exp_val}" for (param, exp_val) in missing_params
+                ]
+
+            if len(incorrect_params) > 0:
+                msg += ["incorrect :"]
+                msg += [
+                    f"   - found {param}={cur_val}, should be {param}={exp_val}"
+                    for (param, exp_val, cur_val) in incorrect_params
+                ]
+
+            raise_log(ValueError("\n".join(msg)), logger)
+
     def __getstate__(self):
         # do not pickle the PyTorch LightningModule, and Trainer
         return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}
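To illustrate the new consolidated error, a sketch with hypothetical hyper-parameters: loading weights into a model whose architecture differs should now list every discrepancy at once, roughly in the format built above.

```python
from darts.models import RNNModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)

trained = RNNModel(
    model="LSTM", input_chunk_length=12, hidden_dim=25, n_rnn_layers=2, n_epochs=1
)
trained.fit(series)
trained.save("rnn.pt")

# hidden_dim differs and n_rnn_layers is not passed at all
mismatching = RNNModel(model="LSTM", input_chunk_length=12, hidden_dim=32)
try:
    mismatching.load_weights("rnn.pt")
except ValueError as err:
    print(err)
    # expected to report, roughly:
    #   The values of the hyper-parameters in the model and loaded checkpoint should be identical.
    #   missing :
    #      - n_rnn_layers=2
    #   incorrect :
    #      - found hidden_dim=32, should be hidden_dim=25
```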
