From 7dff28a413a6cd28c120d67a0bf6abedb990e9ff Mon Sep 17 00:00:00 2001 From: madtoinou Date: Thu, 10 Aug 2023 14:57:44 +0200 Subject: [PATCH 01/11] fix: comparing the parameters stored in .model_params (saved and loading models) to make it more robust. this check can be skipped (not recommended). --- .../forecasting/torch_forecasting_model.py | 121 ++++++++++++------ 1 file changed, 81 insertions(+), 40 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 316d664a3f..509d00613c 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -1743,6 +1743,7 @@ def load_weights_from_checkpoint( best: bool = True, strict: bool = True, load_encoders: bool = True, + skip_checks: bool = False, **kwargs, ): """ @@ -1758,6 +1759,9 @@ def load_weights_from_checkpoint( For manually saved model, consider using :meth:`load() ` or :meth:`load_weights() ` instead. + Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders + and perform sanity checks on the model parameters. + Parameters ---------- model_name @@ -1777,6 +1781,9 @@ def load_weights_from_checkpoint( load_encoders If set, will load the encoders from the model to enable direct call of fit() or predict(). Default: ``True``. + skip_checks + If set, will disable the loading of the encoders and the sanity checks on model parameters + (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``. **kwargs Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a different device than the one from which it was saved. 
@@ -1790,6 +1797,13 @@ def load_weights_from_checkpoint( logger, ) + raise_if( + skip_checks and load_encoders, + "`skip-checks` and `load_encoders` are mutually exclusive parameters and cannot be both " + "set to `True`.", + logger, + ) + # use the name of the model being loaded with the saved weights if model_name is None: model_name = self.model_name @@ -1816,39 +1830,6 @@ def load_weights_from_checkpoint( ckpt_path = os.path.join(checkpoint_dir, file_name) ckpt = torch.load(ckpt_path, **kwargs) - ckpt_hyper_params = ckpt["hyper_parameters"] - - # verify that the arguments passed to the constructor match those of the checkpoint - # add_encoders is checked in _load_encoders() - skipped_params = list( - inspect.signature(TorchForecastingModel.__init__).parameters.keys() - ) + [ - "loss_fn", - "torch_metrics", - "optimizer_cls", - "optimizer_kwargs", - "lr_scheduler_cls", - "lr_scheduler_kwargs", - ] - for param_key, param_value in self.model_params.items(): - # TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers - if ( - param_key in ckpt_hyper_params.keys() - and param_key not in skipped_params - ): - # some parameters must be converted - if isinstance(ckpt_hyper_params[param_key], list) and not isinstance( - param_value, list - ): - param_value = [param_value] * len(ckpt_hyper_params[param_key]) - - raise_if( - param_value != ckpt_hyper_params[param_key], - f"The values of the hyper parameter {param_key} should be identical between " - f"the instantiated model ({param_value}) and the loaded checkpoint " - f"({ckpt_hyper_params[param_key]}). 
Please adjust the model accordingly.", - logger, - ) # indicate to the user than checkpoints generated with darts <= 0.23.1 are not supported raise_if_not( @@ -1867,17 +1848,32 @@ def load_weights_from_checkpoint( ] self.train_sample = tuple(mock_train_sample) - # updating model attributes before self._init_model() which create new ckpt - tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name) - with open(tfm_save_file_path, "rb") as tfm_save_file: - tfm_save: TorchForecastingModel = torch.load( - tfm_save_file, map_location=kwargs.get("map_location", None) - ) + if not skip_checks: + # path to the tfm checkpoint (darts model, .pt extension) + tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name) + if not os.path.exists(tfm_save_file_path): + raise_log( + FileNotFoundError( + f"Could not find {tfm_save_file_path}, necessary to load the encoders " + f"and run sanity checks on the model parameters." + ), + logger, + ) + + # updating model attributes before self._init_model() which create new tfm ckpt + with open(tfm_save_file_path, "rb") as tfm_save_file: + tfm_save: TorchForecastingModel = torch.load( + tfm_save_file, map_location=kwargs.get("map_location", None) + ) + # encoders are necessary for direct inference self.encoders, self.add_encoders = self._load_encoders( tfm_save, load_encoders ) + # meaningful error message if parameters are incompatible with the ckpt weights + self._check_ckpt_parameters(tfm_save) + # instanciate the model without having to call `fit_from_dataset` self.model = self._init_model() # cast model precision to correct type @@ -1887,10 +1883,15 @@ def load_weights_from_checkpoint( # update the fit_called attribute to allow for direct inference self._fit_called = True - def load_weights(self, path: str, load_encoders: bool = True, **kwargs): + def load_weights( + self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs + ): """ Loads the weights from a manually saved model (saved 
with :meth:`save() `). + Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders + and perform sanity checks on the model parameters. + Parameters ---------- path @@ -1899,6 +1900,9 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs): load_encoders If set, will load the encoders from the model to enable direct call of fit() or predict(). Default: ``True``. + skip_checks + If set, will disable the loading of the encoders and the sanity checks on model parameters + (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``. **kwargs Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a different device than the one from which it was saved. @@ -1916,6 +1920,7 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs): self.load_weights_from_checkpoint( file_name=path_ptl_ckpt, load_encoders=load_encoders, + skip_checks=skip_checks, **kwargs, ) @@ -2058,6 +2063,42 @@ def _load_encoders( return new_encoders, new_add_encoders + def _check_ckpt_parameters(self, tfm_save): + """ + Check that the parameters used to instantiate the new model loading the weights match those of + of the saved model, to return meaningful messages in case of discrepancies. 
+ """ + # parameters unrelated to the weights shape + skipped_params = list( + inspect.signature(TorchForecastingModel.__init__).parameters.keys() + ) + [ + "loss_fn", + "torch_metrics", + "optimizer_cls", + "optimizer_kwargs", + "lr_scheduler_cls", + "lr_scheduler_kwargs", + ] + ckpt_model_params = tfm_save.model_params + for param_key, param_value in self.model_params.items(): + if ( + param_key in ckpt_model_params.keys() + and param_key not in skipped_params + ): + # some parameters must be converted + if isinstance(ckpt_model_params[param_key], list) and not isinstance( + param_value, list + ): + param_value = [param_value] * len(ckpt_model_params[param_key]) + + raise_if( + param_value != ckpt_model_params[param_key], + f"The values of the hyper parameter {param_key} should be identical between " + f"the instantiated model ({param_value}) and the loaded checkpoint " + f"({ckpt_model_params[param_key]}). Please adjust the model accordingly.", + logger, + ) + def __getstate__(self): # do not pickle the PyTorch LightningModule, and Trainer return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE} From dd97bd5136313119f2a0f020c477067e43200423 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 13:47:21 +0200 Subject: [PATCH 02/11] feat: nicer message, all the hp discrepancies are listed at once --- .../forecasting/torch_forecasting_model.py | 58 +++++++++++++------ 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 509d00613c..f3890ee945 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -2065,7 +2065,7 @@ def _load_encoders( def _check_ckpt_parameters(self, tfm_save): """ - Check that the parameters used to instantiate the new model loading the weights match those of + Check that the positional parameters used to instantiate the new 
model loading the weights match those of the saved model, to return meaningful messages in case of discrepancies. """ # parameters unrelated to the weights shape @@ -2079,26 +2079,46 @@ def _check_ckpt_parameters(self, tfm_save): "lr_scheduler_cls", "lr_scheduler_kwargs", ] - ckpt_model_params = tfm_save.model_params - for param_key, param_value in self.model_params.items(): - if ( - param_key in ckpt_model_params.keys() - and param_key not in skipped_params - ): - # some parameters must be converted - if isinstance(ckpt_model_params[param_key], list) and not isinstance( - param_value, list - ): - param_value = [param_value] * len(ckpt_model_params[param_key]) - - raise_if( - param_value != ckpt_model_params[param_key], - f"The values of the hyper parameter {param_key} should be identical between " - f"the instantiated model ({param_value}) and the loaded checkpoint " - f"({ckpt_model_params[param_key]}). Please adjust the model accordingly.", - logger, + params_to_check = set(tfm_save.model_params.keys()) - set(skipped_params) + + incorrect_params = [] + missing_params = [] + for param_key in params_to_check: + if param_key not in self.model_params.keys(): + # param name, expected value + missing_params.append((param_key, tfm_save.model_params[param_key])) + elif self.model_params[param_key] != tfm_save.model_params[param_key]: + # param name, expected value, current value + incorrect_params.append( + ( + param_key, + tfm_save.model_params[param_key], + self.model_params[param_key], + ) ) + # at least one discrepancy was detected + if len(missing_params) + len(incorrect_params) > 0: + msg = [ + "The values of the hyper-parameters in the model and loaded checkpoint should be identical." 
+ ] + + # warning messages formated to facilate copy-pasting + if len(missing_params) > 0: + msg += ["missing :"] + msg += [ + f"\t - {param}={exp_val}" for (param, exp_val) in missing_params + ] + + if len(incorrect_params) > 0: + msg += ["incorrect :"] + msg += [ + f"\t - found {param}={cur_val}, should be {param}={exp_val}" + for (param, exp_val, cur_val) in incorrect_params + ] + + raise_log(ValueError("\n".join(msg)), logger) + def __getstate__(self): # do not pickle the PyTorch LightningModule, and Trainer return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE} From c08583088dbc8886141edda9b008096ea53d45fb Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 14:29:10 +0200 Subject: [PATCH 03/11] feat: repr method for LikelihoodModel --- darts/utils/likelihood_models.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/darts/utils/likelihood_models.py b/darts/utils/likelihood_models.py index 9abb11fd91..29487d41c2 100644 --- a/darts/utils/likelihood_models.py +++ b/darts/utils/likelihood_models.py @@ -31,6 +31,7 @@ """ import collections.abc +import inspect from abc import ABC, abstractmethod from typing import List, Optional, Tuple, Union @@ -232,6 +233,20 @@ def __eq__(self, other) -> bool: else: return False + def __repr__(self, include_default_params: bool = True) -> str: + """Return the class and parameters of the instance in a nice format""" + cls_name = self.__class__.__name__ + # only display the constructor parameters as user cannot change the other attributes + init_signature = inspect.signature(self.__class__.__init__) + params_string = ", ".join( + [ + f"{str(v)}" + for _, v in init_signature.parameters.items() + if str(v) != "self" + ] + ) + return f"{cls_name}({params_string})" + class GaussianLikelihood(Likelihood): def __init__( From 8d9d830e6d700aeba40e897c0397651e8fd57290 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 14:29:57 +0200 Subject: [PATCH 04/11] fix: removed unused 
param from LikelihoodModel.__repr__ --- darts/utils/likelihood_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/utils/likelihood_models.py b/darts/utils/likelihood_models.py index 29487d41c2..900c47687d 100644 --- a/darts/utils/likelihood_models.py +++ b/darts/utils/likelihood_models.py @@ -233,7 +233,7 @@ def __eq__(self, other) -> bool: else: return False - def __repr__(self, include_default_params: bool = True) -> str: + def __repr__(self) -> str: """Return the class and parameters of the instance in a nice format""" cls_name = self.__class__.__name__ # only display the constructor parameters as user cannot change the other attributes From aa10741593752a8425c09f9ff8b56bdd05e7a785 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 15:27:44 +0200 Subject: [PATCH 05/11] fix: better sanity check of the kwargs during weights loading from ckpt --- .../forecasting/torch_forecasting_model.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index f3890ee945..5e27f3f504 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -2079,16 +2079,29 @@ def _check_ckpt_parameters(self, tfm_save): "lr_scheduler_cls", "lr_scheduler_kwargs", ] - params_to_check = set(tfm_save.model_params.keys()) - set(skipped_params) + # model_params can be missing some kwargs + params_to_check = set(tfm_save.model_params.keys()).union( + self.model_params.keys() + ) - set(skipped_params) incorrect_params = [] missing_params = [] for param_key in params_to_check: + # param was not used at loading model creation if param_key not in self.model_params.keys(): - # param name, expected value missing_params.append((param_key, tfm_save.model_params[param_key])) + # new param was used at loading model creation + elif param_key not in 
tfm_save.model_params.keys(): + incorrect_params.append( + ( + param_key, + None, + self.model_params[param_key], + ) + ) + # param was different at loading model creation elif self.model_params[param_key] != tfm_save.model_params[param_key]: - # param name, expected value, current value + # NOTE: for TFTModel, default is None but converted to `QuantileRegression()` incorrect_params.append( ( param_key, @@ -2107,13 +2120,13 @@ def _check_ckpt_parameters(self, tfm_save): if len(missing_params) > 0: msg += ["missing :"] msg += [ - f"\t - {param}={exp_val}" for (param, exp_val) in missing_params + f" - {param}={exp_val}" for (param, exp_val) in missing_params ] if len(incorrect_params) > 0: msg += ["incorrect :"] msg += [ - f"\t - found {param}={cur_val}, should be {param}={exp_val}" + f" - found {param}={cur_val}, should be {param}={exp_val}" for (param, exp_val, cur_val) in incorrect_params ] From 9b153fad6f908a2976e95f3dc0826cbb98b735bd Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 15:38:39 +0200 Subject: [PATCH 06/11] fix: removed copy-paste leftovers from unittests --- .../forecasting/test_torch_forecasting_model.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index c041eb1303..d0803afe0f 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -653,8 +653,6 @@ def create_DLinearModel( def test_create_instance_new_model_no_name_set(self): RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs) # no exception is raised - RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs) - # no exception is raised def test_create_instance_existing_model_with_name_no_fit(self): model_name = "test_model" @@ -669,17 +667,6 @@ def test_create_instance_existing_model_with_name_no_fit(self): ) # no 
exception is raised - RNNModel( - 12, - "RNN", - 10, - 10, - work_dir=self.temp_work_dir, - model_name=model_name, - **tfm_kwargs, - ) - # no exception is raised - @patch( "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model" ) From 5882a5bd610f53dce1e21d8e258da44501f2fb47 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 15:48:38 +0200 Subject: [PATCH 07/11] fix: removing redundant helper code --- .../test_torch_forecasting_model.py | 128 +++++++++--------- 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index d0803afe0f..5542341620 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from typing import Any, Dict +from typing import Any, Dict, Optional from unittest.mock import patch import numpy as np @@ -277,24 +277,6 @@ def test_save_and_load_weights_w_encoders(self): Note: Using DLinear since it supports both past and future covariates """ - def create_DLinearModel( - model_name: str, - save_checkpoints: bool = False, - add_encoders: Dict = None, - ): - return DLinearModel( - input_chunk_length=4, - output_chunk_length=1, - kernel_size=5, - model_name=model_name, - add_encoders=add_encoders, - work_dir=self.temp_work_dir, - save_checkpoints=save_checkpoints, - random_state=42, - force_reset=True, - **tfm_kwargs, - ) - model_dir = os.path.join(self.temp_work_dir) manual_name = "save_manual" auto_name = "save_auto" @@ -332,18 +314,18 @@ def create_DLinearModel( "transformer": Scaler(), } - model_auto_save = create_DLinearModel( + model_auto_save = self.helper_create_DLinearModel( auto_name, save_checkpoints=True, add_encoders=encoders_past ) model_auto_save.fit(self.series, epochs=1) - model_manual_save = create_DLinearModel( +
model_manual_save = self.helper_create_DLinearModel( manual_name, save_checkpoints=False, add_encoders=encoders_past ) model_manual_save.fit(self.series, epochs=1) model_manual_save.save(model_path_manual) - model_auto_save_other = create_DLinearModel( + model_auto_save_other = self.helper_create_DLinearModel( auto_name_other, save_checkpoints=True, add_encoders=encoders_other_past ) model_auto_save_other.fit(self.series, epochs=1) @@ -355,7 +337,9 @@ def create_DLinearModel( ) # model with undeclared encoders - model_no_enc = create_DLinearModel("no_encoder", add_encoders=None) + model_no_enc = self.helper_create_DLinearModel( + "no_encoder", add_encoders=None + ) # weights were trained with encoders, new model must be instantiated with encoders with self.assertRaises(ValueError): model_no_enc.load_weights_from_checkpoint( @@ -386,7 +370,7 @@ def create_DLinearModel( ) # model with identical encoders (fittable) - model_same_enc_noload = create_DLinearModel( + model_same_enc_noload = self.helper_create_DLinearModel( "same_encoder_noload", add_encoders=encoders_past ) model_same_enc_noload.load_weights( @@ -398,7 +382,7 @@ def create_DLinearModel( with self.assertRaises(ValueError): model_same_enc_noload.predict(n=4, series=self.series) - model_same_enc_load = create_DLinearModel( + model_same_enc_load = self.helper_create_DLinearModel( "same_encoder_load", add_encoders=encoders_past ) model_same_enc_load.load_weights( @@ -412,7 +396,7 @@ def create_DLinearModel( ) # model with different encoders (fittable) - model_other_enc_load = create_DLinearModel( + model_other_enc_load = self.helper_create_DLinearModel( "other_encoder_load", add_encoders=encoders_other_past ) # cannot overwritte different declared encoders @@ -424,7 +408,7 @@ def create_DLinearModel( ) # model with different encoders but same dimensions (fittable) - model_other_enc_noload = create_DLinearModel( + model_other_enc_noload = self.helper_create_DLinearModel( "other_encoder_noload", 
add_encoders=encoders_other_past ) model_other_enc_noload.load_weights( @@ -451,7 +435,7 @@ def create_DLinearModel( model_other_enc_noload.predict(n=4, series=self.series) # model with same encoders but no scaler (non-fittable) - model_new_enc_noscaler_noload = create_DLinearModel( + model_new_enc_noscaler_noload = self.helper_create_DLinearModel( "same_encoder_noscaler", add_encoders=encoders_past_noscaler ) model_new_enc_noscaler_noload.load_weights( @@ -470,7 +454,7 @@ def create_DLinearModel( model_new_enc_noscaler_noload.predict(n=4, series=self.series) # model with same encoders but different transformer (fittable) - model_new_enc_other_transformer = create_DLinearModel( + model_new_enc_other_transformer = self.helper_create_DLinearModel( "same_encoder_other_transform", add_encoders=encoders_past_other_transformer, ) @@ -496,7 +480,7 @@ def create_DLinearModel( model_new_enc_other_transformer.predict(n=4, series=self.series) # model with encoders containing more components (fittable) - model_new_enc_2_past = create_DLinearModel( + model_new_enc_2_past = self.helper_create_DLinearModel( "encoder_2_components_past", add_encoders=encoders_2_past ) # cannot overwritte different declared encoders @@ -515,7 +499,7 @@ def create_DLinearModel( ) # model with encoders containing past and future covs (fittable) - model_new_enc_past_n_future = create_DLinearModel( + model_new_enc_past_n_future = self.helper_create_DLinearModel( "encoder_past_n_future", add_encoders=encoders_past_n_future ) # cannot overwritte different declared encoders @@ -541,25 +525,6 @@ def test_save_and_load_weights_w_likelihood(self): for all but one test. 
Note: Using DLinear since it supports both past and future covariates """ - - def create_DLinearModel( - model_name: str, - save_checkpoints: bool = False, - likelihood: Likelihood = None, - ): - return DLinearModel( - input_chunk_length=4, - output_chunk_length=1, - kernel_size=5, - model_name=model_name, - work_dir=self.temp_work_dir, - save_checkpoints=save_checkpoints, - likelihood=likelihood, - random_state=42, - force_reset=True, - **tfm_kwargs, - ) - model_dir = os.path.join(self.temp_work_dir) manual_name = "save_manual" auto_name = "save_auto" @@ -571,7 +536,7 @@ def create_DLinearModel( checkpoint_path_manual, checkpoint_file_name ) - model_auto_save = create_DLinearModel( + model_auto_save = self.helper_create_DLinearModel( auto_name, save_checkpoints=True, likelihood=GaussianLikelihood(prior_mu=0.5), @@ -579,7 +544,7 @@ def create_DLinearModel( model_auto_save.fit(self.series, epochs=1) pred_auto = model_auto_save.predict(n=4, series=self.series) - model_manual_save = create_DLinearModel( + model_manual_save = self.helper_create_DLinearModel( manual_name, save_checkpoints=False, likelihood=GaussianLikelihood(prior_mu=0.5), @@ -592,7 +557,7 @@ def create_DLinearModel( self.assertTrue(np.array_equal(pred_auto.values(), pred_manual.values())) # model with identical likelihood - model_same_likelihood = create_DLinearModel( + model_same_likelihood = self.helper_create_DLinearModel( "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5) ) model_same_likelihood.load_weights(model_path_manual, map_location="cpu") @@ -600,7 +565,7 @@ def create_DLinearModel( # cannot check predictions since this model is not fitted, random state is different # loading models weights with respective methods - model_manual_same_likelihood = create_DLinearModel( + model_manual_same_likelihood = self.helper_create_DLinearModel( "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5) ) model_manual_same_likelihood.load_weights( @@ -610,7 +575,7 @@ def 
create_DLinearModel( n=4, series=self.series ) - model_auto_same_likelihood = create_DLinearModel( + model_auto_same_likelihood = self.helper_create_DLinearModel( "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5) ) model_auto_same_likelihood.load_weights_from_checkpoint( @@ -623,7 +588,9 @@ def create_DLinearModel( self.assertTrue(preds_manual_from_weights == preds_auto_from_weights) # model with no likelihood - model_no_likelihood = create_DLinearModel("no_likelihood", likelihood=None) + model_no_likelihood = self.helper_create_DLinearModel( + "no_likelihood", likelihood=None + ) with self.assertRaises(ValueError): model_no_likelihood.load_weights_from_checkpoint( auto_name, @@ -633,7 +600,7 @@ def create_DLinearModel( ) # model with a different likelihood - model_other_likelihood = create_DLinearModel( + model_other_likelihood = self.helper_create_DLinearModel( "other_likelihood", likelihood=LaplaceLikelihood() ) with self.assertRaises(ValueError): @@ -642,7 +609,7 @@ def create_DLinearModel( ) # model with the same likelihood but different parameters - model_same_likelihood_other_prior = create_DLinearModel( + model_same_likelihood_other_prior = self.helper_create_DLinearModel( "same_likelihood_other_prior", likelihood=GaussianLikelihood() ) with self.assertRaises(ValueError): @@ -650,6 +617,9 @@ def create_DLinearModel( model_path_manual, map_location="cpu" ) + def test_load_weights_from_checkpoint_params_check(self): + pass + def test_create_instance_new_model_no_name_set(self): RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs) # no exception is raised @@ -1252,7 +1222,11 @@ def test_encoders(self): # 1 == output_chunk_length, 3 > output_chunk_length ns = [1, 3] - model = self.helper_create_DLinearModel() + model = self.helper_create_DLinearModel( + add_encoders={ + "datetime_attribute": {"past": ["hour"], "future": ["month"]} + } + ) model.fit(series) for n in ns: _ = model.predict(n=n) @@ -1263,7 +1237,11 @@ def 
test_encoders(self): with pytest.raises(ValueError): _ = model.predict(n=n, past_covariates=pc, future_covariates=fc) - model = self.helper_create_DLinearModel() + model = self.helper_create_DLinearModel( + add_encoders={ + "datetime_attribute": {"past": ["hour"], "future": ["month"]} + } + ) for n in ns: model.fit(series, past_covariates=pc) _ = model.predict(n=n) @@ -1273,7 +1251,11 @@ def test_encoders(self): with pytest.raises(ValueError): _ = model.predict(n=n, past_covariates=pc, future_covariates=fc) - model = self.helper_create_DLinearModel() + model = self.helper_create_DLinearModel( + add_encoders={ + "datetime_attribute": {"past": ["hour"], "future": ["month"]} + } + ) for n in ns: model.fit(series, future_covariates=fc) _ = model.predict(n=n) @@ -1283,7 +1265,11 @@ def test_encoders(self): with pytest.raises(ValueError): _ = model.predict(n=n, past_covariates=pc, future_covariates=fc) - model = self.helper_create_DLinearModel() + model = self.helper_create_DLinearModel( + add_encoders={ + "datetime_attribute": {"past": ["hour"], "future": ["month"]} + } + ) for n in ns: model.fit(series, past_covariates=pc, future_covariates=fc) _ = model.predict(n=n) @@ -1334,13 +1320,23 @@ def helper_create_RNNModel(self, model_name: str): **tfm_kwargs, ) - def helper_create_DLinearModel(self): + def helper_create_DLinearModel( + self, + model_name: str = "unitest_model", + add_encoders: Optional[Dict] = None, + save_checkpoints: bool = False, + likelihood: Likelihood = None, + ): return DLinearModel( input_chunk_length=4, output_chunk_length=1, - add_encoders={ - "datetime_attribute": {"past": ["hour"], "future": ["month"]} - }, + model_name=model_name, + add_encoders=add_encoders, + work_dir=self.temp_work_dir, + save_checkpoints=save_checkpoints, + random_state=42, + force_reset=True, n_epochs=1, + likelihood=likelihood, **tfm_kwargs, ) From 294b7123fff72e26293fa9f41a909529482747dd Mon Sep 17 00:00:00 2001 From: madtoinou Date: Fri, 11 Aug 2023 16:23:22 +0200 
Subject: [PATCH 08/11] test: adding tests, made the assertion more specific to distinguish incorrect/missing --- .../test_torch_forecasting_model.py | 128 +++++++++++++++++- 1 file changed, 122 insertions(+), 6 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 5542341620..f92aa8debf 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -587,38 +587,154 @@ def test_save_and_load_weights_w_likelihood(self): # check that weights from checkpoint give identical predictions as weights from manual save self.assertTrue(preds_manual_from_weights == preds_auto_from_weights) - # model with no likelihood + # model with explicitely no likelihood model_no_likelihood = self.helper_create_DLinearModel( "no_likelihood", likelihood=None ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError) as error_msg: model_no_likelihood.load_weights_from_checkpoint( auto_name, work_dir=self.temp_work_dir, best=False, map_location="cpu", ) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" + ) + ) + + # model with missing likelihood (as if user forgot them) + model_no_likelihood_bis = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + model_name="no_likelihood_bis", + add_encoders=None, + work_dir=self.temp_work_dir, + save_checkpoints=False, + random_state=42, + force_reset=True, + n_epochs=1, + # likelihood=likelihood, + **tfm_kwargs, + ) + with pytest.raises(ValueError) as error_msg: + model_no_likelihood_bis.load_weights_from_checkpoint( + auto_name, + work_dir=self.temp_work_dir, + best=False, + map_location="cpu", + ) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint 
should be identical.\n" + "missing" + ) + ) # model with a different likelihood model_other_likelihood = self.helper_create_DLinearModel( "other_likelihood", likelihood=LaplaceLikelihood() ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError) as error_msg: model_other_likelihood.load_weights( model_path_manual, map_location="cpu" ) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" + ) + ) # model with the same likelihood but different parameters model_same_likelihood_other_prior = self.helper_create_DLinearModel( "same_likelihood_other_prior", likelihood=GaussianLikelihood() ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError) as error_msg: model_same_likelihood_other_prior.load_weights( model_path_manual, map_location="cpu" ) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" + ) + ) def test_load_weights_from_checkpoint_params_check(self): - pass + """ + Verify that the method comparing the parameters between the saved model and the loading model + behave as expected, used to return meaningful error message instead of the torch.load ones. 
+ """ + model_name = "params_check" + ckpt_name = f"{model_name}.pt" + # barebone model + model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + work_dir=self.temp_work_dir, + ) + model.fit(self.series[:10]) + model.save(ckpt_name) + + # identical model + loading_model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + work_dir=self.temp_work_dir, + ) + loading_model.load_weights(ckpt_name) + + # different optimizer + loading_model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + work_dir=self.temp_work_dir, + optimizer_cls=torch.optim.AdamW, + ) + loading_model.load_weights(ckpt_name) + + # different pl_trainer_kwargs + loading_model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + work_dir=self.temp_work_dir, + pl_trainer_kwargs={"enable_model_summary": False}, + ) + loading_model.load_weights(ckpt_name) + + # different input_chunk_length (tfm parameter) + loading_model = DLinearModel( + input_chunk_length=4 + 1, + output_chunk_length=1, + work_dir=self.temp_work_dir, + ) + with pytest.raises(ValueError) as error_msg: + loading_model.load_weights(ckpt_name) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" + ) + ) + + # different kernel size (cls specific parameter) + loading_model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + kernel_size=10, + work_dir=self.temp_work_dir, + ) + with pytest.raises(ValueError) as error_msg: + loading_model.load_weights(ckpt_name) + self.assertTrue( + str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" + ) + ) def test_create_instance_new_model_no_name_set(self): RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs) @@ -1325,7 +1441,7 @@ def helper_create_DLinearModel( model_name: str = "unitest_model", add_encoders: 
Optional[Dict] = None, save_checkpoints: bool = False, - likelihood: Likelihood = None, + likelihood: Optional[Likelihood] = None, ): return DLinearModel( input_chunk_length=4, From a9811f4123bbf1fef0beae15323110085eeab02a Mon Sep 17 00:00:00 2001 From: madtoinou <32447896+madtoinou@users.noreply.github.com> Date: Mon, 14 Aug 2023 11:07:27 +0200 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Dennis Bader --- .../forecasting/test_torch_forecasting_model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index f92aa8debf..6e16edbbf2 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -598,7 +598,7 @@ def test_save_and_load_weights_w_likelihood(self): best=False, map_location="cpu", ) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" "incorrect" @@ -626,7 +626,7 @@ def test_save_and_load_weights_w_likelihood(self): best=False, map_location="cpu", ) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" "missing" @@ -641,7 +641,7 @@ def test_save_and_load_weights_w_likelihood(self): model_other_likelihood.load_weights( model_path_manual, map_location="cpu" ) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" "incorrect" @@ -656,14 +656,14 @@ def test_save_and_load_weights_w_likelihood(self): model_same_likelihood_other_prior.load_weights( model_path_manual, map_location="cpu" ) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in 
the model and loaded checkpoint should be identical.\n" "incorrect" ) ) - def test_load_weights_from_checkpoint_params_check(self): + def test_load_weights_params_check(self): """ Verify that the method comparing the parameters between the saved model and the loading model behave as expected, used to return meaningful error message instead of the torch.load ones. @@ -675,6 +675,7 @@ def test_load_weights_from_checkpoint_params_check(self): input_chunk_length=4, output_chunk_length=1, work_dir=self.temp_work_dir, + n_epochs=1, ) model.fit(self.series[:10]) model.save(ckpt_name) @@ -713,7 +714,7 @@ def test_load_weights_from_checkpoint_params_check(self): ) with pytest.raises(ValueError) as error_msg: loading_model.load_weights(ckpt_name) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" "incorrect" @@ -729,7 +730,7 @@ def test_load_weights_from_checkpoint_params_check(self): ) with pytest.raises(ValueError) as error_msg: loading_model.load_weights(ckpt_name) - self.assertTrue( + assert ( str(error_msg.value).startswith( "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" "incorrect" From b748689814c851be5139ceec6a4e217cf1e06707 Mon Sep 17 00:00:00 2001 From: madtoinou Date: Mon, 14 Aug 2023 11:15:41 +0200 Subject: [PATCH 10/11] fix: linting --- .../test_torch_forecasting_model.py | 48 +++++++------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 6e16edbbf2..bc57a66b28 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -598,11 +598,9 @@ def test_save_and_load_weights_w_likelihood(self): best=False, map_location="cpu", ) - assert ( - str(error_msg.value).startswith( - "The 
values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "incorrect" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" ) # model with missing likelihood (as if user forgot them) @@ -626,11 +624,9 @@ def test_save_and_load_weights_w_likelihood(self): best=False, map_location="cpu", ) - assert ( - str(error_msg.value).startswith( - "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "missing" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "missing" ) # model with a different likelihood @@ -641,11 +637,9 @@ def test_save_and_load_weights_w_likelihood(self): model_other_likelihood.load_weights( model_path_manual, map_location="cpu" ) - assert ( - str(error_msg.value).startswith( - "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "incorrect" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" ) # model with the same likelihood but different parameters @@ -656,11 +650,9 @@ def test_save_and_load_weights_w_likelihood(self): model_same_likelihood_other_prior.load_weights( model_path_manual, map_location="cpu" ) - assert ( - str(error_msg.value).startswith( - "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "incorrect" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" ) def test_load_weights_params_check(self): @@ -714,11 +706,9 @@ def test_load_weights_params_check(self): ) with pytest.raises(ValueError) as error_msg: loading_model.load_weights(ckpt_name) - assert ( - 
str(error_msg.value).startswith( - "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "incorrect" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" ) # different kernel size (cls specific parameter) @@ -730,11 +720,9 @@ def test_load_weights_params_check(self): ) with pytest.raises(ValueError) as error_msg: loading_model.load_weights(ckpt_name) - assert ( - str(error_msg.value).startswith( - "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" - "incorrect" - ) + assert str(error_msg.value).startswith( + "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n" + "incorrect" ) def test_create_instance_new_model_no_name_set(self): From 19b2bb36f13518198a6d323691e2ef812767f67a Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Tue, 15 Aug 2023 09:47:47 +0200 Subject: [PATCH 11/11] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03517933b4..7f86780ed7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou). - Fixed a bug in `RegressionEnsembleModel.extreme_lags` when the forecasting models have only covariates lags. [#1942](https://github.com/unit8co/darts/pull/1942) by [Antoine Madrona](https://github.com/madtoinou). - Fixed a bug when using `TFTExplainer` with a `TFTModel` running on GPU. [#1949](https://github.com/unit8co/darts/pull/1949) by [Dennis Bader](https://github.com/dennisbader). 
+- Fixed a bug in `TorchForecastingModel.load_weights()` that raised an error when loading the weights from a valid architecture. [#1952](https://github.com/unit8co/darts/pull/1952) by [Antoine Madrona](https://github.com/madtoinou). ## [0.25.0](https://github.com/unit8co/darts/tree/0.25.0) (2023-08-04)