diff --git a/CHANGELOG.md b/CHANGELOG.md
index 03517933b4..7f86780ed7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 - Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug in `RegressionEnsembleModel.extreme_lags` when the forecasting models have only covariates lags. [#1942](https://github.com/unit8co/darts/pull/1942) by [Antoine Madrona](https://github.com/madtoinou).
 - Fixed a bug when using `TFTExplainer` with a `TFTModel` running on GPU. [#1949](https://github.com/unit8co/darts/pull/1949) by [Dennis Bader](https://github.com/dennisbader).
+- Fixed a bug in `TorchForecastingModel.load_weights()` that raised an error when loading the weights from a valid architecture. [#1952](https://github.com/unit8co/darts/pull/1952) by [Antoine Madrona](https://github.com/madtoinou).
 
 ## [0.25.0](https://github.com/unit8co/darts/tree/0.25.0) (2023-08-04)
diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py
index 316d664a3f..5e27f3f504 100644
--- a/darts/models/forecasting/torch_forecasting_model.py
+++ b/darts/models/forecasting/torch_forecasting_model.py
@@ -1743,6 +1743,7 @@ def load_weights_from_checkpoint(
         best: bool = True,
         strict: bool = True,
         load_encoders: bool = True,
+        skip_checks: bool = False,
         **kwargs,
     ):
         """
@@ -1758,6 +1759,9 @@
         For manually saved model, consider using :meth:`load() <TorchForecastingModel.load>` or
         :meth:`load_weights() <TorchForecastingModel.load_weights>` instead.
 
+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         model_name
@@ -1777,6 +1781,9 @@
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model
             onto a different device than the one from which it was saved.
@@ -1790,6 +1797,13 @@
             logger,
         )
 
+        raise_if(
+            skip_checks and load_encoders,
+            "`skip_checks` and `load_encoders` are mutually exclusive parameters and cannot both "
+            "be set to `True`.",
+            logger,
+        )
+
         # use the name of the model being loaded with the saved weights
         if model_name is None:
             model_name = self.model_name
@@ -1816,39 +1830,6 @@
         ckpt_path = os.path.join(checkpoint_dir, file_name)
         ckpt = torch.load(ckpt_path, **kwargs)
-        ckpt_hyper_params = ckpt["hyper_parameters"]
-
-        # verify that the arguments passed to the constructor match those of the checkpoint
-        # add_encoders is checked in _load_encoders()
-        skipped_params = list(
-            inspect.signature(TorchForecastingModel.__init__).parameters.keys()
-        ) + [
-            "loss_fn",
-            "torch_metrics",
-            "optimizer_cls",
-            "optimizer_kwargs",
-            "lr_scheduler_cls",
-            "lr_scheduler_kwargs",
-        ]
-        for param_key, param_value in self.model_params.items():
-            # TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers
-            if (
-                param_key in ckpt_hyper_params.keys()
-                and param_key not in skipped_params
-            ):
-                # some parameters must be converted
-                if isinstance(ckpt_hyper_params[param_key], list) and not isinstance(
-                    param_value, list
-                ):
-                    param_value = [param_value] * len(ckpt_hyper_params[param_key])
-
-                raise_if(
-                    param_value != ckpt_hyper_params[param_key],
-                    f"The values of the hyper parameter {param_key} should be identical between "
-                    f"the instantiated model ({param_value}) and the loaded checkpoint "
-                    f"({ckpt_hyper_params[param_key]}). Please adjust the model accordingly.",
-                    logger,
-                )
 
         # indicate to the user than checkpoints generated with darts <= 0.23.1 are not supported
         raise_if_not(
@@ -1867,17 +1848,32 @@
             ]
         self.train_sample = tuple(mock_train_sample)
 
-        # updating model attributes before self._init_model() which create new ckpt
-        tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
-        with open(tfm_save_file_path, "rb") as tfm_save_file:
-            tfm_save: TorchForecastingModel = torch.load(
-                tfm_save_file, map_location=kwargs.get("map_location", None)
-            )
+        if not skip_checks:
+            # path to the tfm checkpoint (darts model, .pt extension)
+            tfm_save_file_path = os.path.join(tfm_save_file_dir, tfm_save_file_name)
+            if not os.path.exists(tfm_save_file_path):
+                raise_log(
+                    FileNotFoundError(
+                        f"Could not find {tfm_save_file_path}, necessary to load the encoders "
+                        f"and run sanity checks on the model parameters."
+                    ),
+                    logger,
+                )
+
+            # updating model attributes before self._init_model() which creates a new tfm ckpt
+            with open(tfm_save_file_path, "rb") as tfm_save_file:
+                tfm_save: TorchForecastingModel = torch.load(
+                    tfm_save_file, map_location=kwargs.get("map_location", None)
+                )
+            # encoders are necessary for direct inference
             self.encoders, self.add_encoders = self._load_encoders(
                 tfm_save, load_encoders
             )
+            # meaningful error message if parameters are incompatible with the ckpt weights
+            self._check_ckpt_parameters(tfm_save)
+
+        # instantiate the model without having to call `fit_from_dataset`
         self.model = self._init_model()
         # cast model precision to correct type
@@ -1887,10 +1883,15 @@
         # update the fit_called attribute to allow for direct inference
         self._fit_called = True
 
-    def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
+    def load_weights(
+        self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
+    ):
         """
         Loads the weights from a manually saved model (saved with :meth:`save() <TorchForecastingModel.save>`).
 
+        Note: This method needs to be able to access the darts model checkpoint (.pt) in order to load the encoders
+        and perform sanity checks on the model parameters.
+
         Parameters
         ----------
         path
@@ -1899,6 +1900,9 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         load_encoders
             If set, will load the encoders from the model to enable direct call of fit() or predict().
             Default: ``True``.
+        skip_checks
+            If set, will disable the loading of the encoders and the sanity checks on model parameters
+            (not recommended). Cannot be used with `load_encoders=True`. Default: ``False``.
         **kwargs
             Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model
             onto a different device than the one from which it was saved.
@@ -1916,6 +1920,7 @@ def load_weights(self, path: str, load_encoders: bool = True, **kwargs):
         self.load_weights_from_checkpoint(
             file_name=path_ptl_ckpt,
             load_encoders=load_encoders,
+            skip_checks=skip_checks,
             **kwargs,
         )
 
@@ -2058,6 +2063,75 @@ def _load_encoders(
 
         return new_encoders, new_add_encoders
 
+    def _check_ckpt_parameters(self, tfm_save):
+        """
+        Check that the hyper-parameters used to instantiate the model loading the weights match those
+        of the saved model, in order to return meaningful error messages in case of discrepancy.
+ """ + # parameters unrelated to the weights shape + skipped_params = list( + inspect.signature(TorchForecastingModel.__init__).parameters.keys() + ) + [ + "loss_fn", + "torch_metrics", + "optimizer_cls", + "optimizer_kwargs", + "lr_scheduler_cls", + "lr_scheduler_kwargs", + ] + # model_params can be missing some kwargs + params_to_check = set(tfm_save.model_params.keys()).union( + self.model_params.keys() + ) - set(skipped_params) + + incorrect_params = [] + missing_params = [] + for param_key in params_to_check: + # param was not used at loading model creation + if param_key not in self.model_params.keys(): + missing_params.append((param_key, tfm_save.model_params[param_key])) + # new param was used at loading model creation + elif param_key not in tfm_save.model_params.keys(): + incorrect_params.append( + ( + param_key, + None, + self.model_params[param_key], + ) + ) + # param was different at loading model creation + elif self.model_params[param_key] != tfm_save.model_params[param_key]: + # NOTE: for TFTModel, default is None but converted to `QuantileRegression()` + incorrect_params.append( + ( + param_key, + tfm_save.model_params[param_key], + self.model_params[param_key], + ) + ) + + # at least one discrepancy was detected + if len(missing_params) + len(incorrect_params) > 0: + msg = [ + "The values of the hyper-parameters in the model and loaded checkpoint should be identical." + ] + + # warning messages formated to facilate copy-pasting + if len(missing_params) > 0: + msg += ["missing :"] + msg += [ + f" - {param}={exp_val}" for (param, exp_val) in missing_params + ] + + if len(incorrect_params) > 0: + msg += ["incorrect :"] + msg += [ + f" - found {param}={cur_val}, should be {param}={exp_val}" + for (param, exp_val, cur_val) in incorrect_params + ] + + raise_log(ValueError("\n".join(msg)), logger) + def __getstate__(self): # do not pickle the PyTorch LightningModule, and Trainer return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE} diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index c041eb1303..bc57a66b28 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from typing import Any, Dict +from typing import Any, Dict, Optional from unittest.mock import patch import numpy as np @@ -277,24 +277,6 @@ def test_save_and_load_weights_w_encoders(self): Note: Using DLinear since it supports both past and future covariates """ - def create_DLinearModel( - model_name: str, - save_checkpoints: bool = False, - add_encoders: Dict = None, - ): - return DLinearModel( - input_chunk_length=4, - output_chunk_length=1, - kernel_size=5, - model_name=model_name, - add_encoders=add_encoders, - work_dir=self.temp_work_dir, - save_checkpoints=save_checkpoints, - random_state=42, - force_reset=True, - **tfm_kwargs, - ) - model_dir = os.path.join(self.temp_work_dir) manual_name = "save_manual" auto_name = "save_auto" @@ -332,18 +314,18 @@ def create_DLinearModel( "transformer": Scaler(), } - model_auto_save = create_DLinearModel( + model_auto_save = self.helper_create_DLinearModel( auto_name, save_checkpoints=True, add_encoders=encoders_past ) model_auto_save.fit(self.series, epochs=1) - model_manual_save = create_DLinearModel( + model_manual_save = self.helper_create_DLinearModel( manual_name, save_checkpoints=False, 
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index c041eb1303..bc57a66b28 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -1,7 +1,7 @@
 import os
 import shutil
 import tempfile
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest.mock import patch
 
 import numpy as np
@@ -277,24 +277,6 @@ def test_save_and_load_weights_w_encoders(self):
         Note: Using DLinear since it supports both past and future covariates
         """
-        def create_DLinearModel(
-            model_name: str,
-            save_checkpoints: bool = False,
-            add_encoders: Dict = None,
-        ):
-            return DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                kernel_size=5,
-                model_name=model_name,
-                add_encoders=add_encoders,
-                work_dir=self.temp_work_dir,
-                save_checkpoints=save_checkpoints,
-                random_state=42,
-                force_reset=True,
-                **tfm_kwargs,
-            )
-
         model_dir = os.path.join(self.temp_work_dir)
         manual_name = "save_manual"
         auto_name = "save_auto"
@@ -332,18 +314,18 @@ def create_DLinearModel(
             "transformer": Scaler(),
         }
 
-        model_auto_save = create_DLinearModel(
+        model_auto_save = self.helper_create_DLinearModel(
             auto_name, save_checkpoints=True, add_encoders=encoders_past
         )
         model_auto_save.fit(self.series, epochs=1)
 
-        model_manual_save = create_DLinearModel(
+        model_manual_save = self.helper_create_DLinearModel(
             manual_name, save_checkpoints=False, add_encoders=encoders_past
         )
         model_manual_save.fit(self.series, epochs=1)
         model_manual_save.save(model_path_manual)
 
-        model_auto_save_other = create_DLinearModel(
+        model_auto_save_other = self.helper_create_DLinearModel(
             auto_name_other, save_checkpoints=True, add_encoders=encoders_other_past
         )
         model_auto_save_other.fit(self.series, epochs=1)
@@ -355,7 +337,9 @@ def create_DLinearModel(
         )
 
         # model with undeclared encoders
-        model_no_enc = create_DLinearModel("no_encoder", add_encoders=None)
+        model_no_enc = self.helper_create_DLinearModel(
+            "no_encoder", add_encoders=None
+        )
         # weights were trained with encoders, new model must be instantiated with encoders
         with self.assertRaises(ValueError):
             model_no_enc.load_weights_from_checkpoint(
@@ -386,7 +370,7 @@
         )
 
         # model with identical encoders (fittable)
-        model_same_enc_noload = create_DLinearModel(
+        model_same_enc_noload = self.helper_create_DLinearModel(
             "same_encoder_noload", add_encoders=encoders_past
         )
         model_same_enc_noload.load_weights(
@@ -398,7 +382,7 @@
         with self.assertRaises(ValueError):
             model_same_enc_noload.predict(n=4, series=self.series)
 
-        model_same_enc_load = create_DLinearModel(
+        model_same_enc_load = self.helper_create_DLinearModel(
             "same_encoder_load", add_encoders=encoders_past
         )
         model_same_enc_load.load_weights(
@@ -412,7 +396,7 @@
         )
 
         # model with different encoders (fittable)
-        model_other_enc_load = create_DLinearModel(
+        model_other_enc_load = self.helper_create_DLinearModel(
             "other_encoder_load", add_encoders=encoders_other_past
         )
         # cannot overwritte different declared encoders
@@ -424,7 +408,7 @@
         # model with different encoders but same dimensions (fittable)
-        model_other_enc_noload = create_DLinearModel(
+        model_other_enc_noload = self.helper_create_DLinearModel(
             "other_encoder_noload", add_encoders=encoders_other_past
         )
         model_other_enc_noload.load_weights(
@@ -451,7 +435,7 @@
             model_other_enc_noload.predict(n=4, series=self.series)
 
         # model with same encoders but no scaler (non-fittable)
-        model_new_enc_noscaler_noload = create_DLinearModel(
+        model_new_enc_noscaler_noload = self.helper_create_DLinearModel(
             "same_encoder_noscaler", add_encoders=encoders_past_noscaler
         )
         model_new_enc_noscaler_noload.load_weights(
@@ -470,7 +454,7 @@
             model_new_enc_noscaler_noload.predict(n=4, series=self.series)
 
         # model with same encoders but different transformer (fittable)
-        model_new_enc_other_transformer = create_DLinearModel(
+        model_new_enc_other_transformer = self.helper_create_DLinearModel(
             "same_encoder_other_transform",
             add_encoders=encoders_past_other_transformer,
         )
@@ -496,7 +480,7 @@
             model_new_enc_other_transformer.predict(n=4, series=self.series)
 
         # model with encoders containing more components (fittable)
-        model_new_enc_2_past = create_DLinearModel(
+        model_new_enc_2_past = self.helper_create_DLinearModel(
             "encoder_2_components_past", add_encoders=encoders_2_past
         )
         # cannot overwritte different declared encoders
@@ -515,7 +499,7 @@
         )
 
         # model with encoders containing past and future covs (fittable)
-        model_new_enc_past_n_future = create_DLinearModel(
+        model_new_enc_past_n_future = self.helper_create_DLinearModel(
             "encoder_past_n_future", add_encoders=encoders_past_n_future
         )
         # cannot overwritte different declared encoders
@@ -541,25 +525,6 @@ def test_save_and_load_weights_w_likelihood(self):
         for all but one test.
         Note: Using DLinear since it supports both past and future covariates
         """
-
-        def create_DLinearModel(
-            model_name: str,
-            save_checkpoints: bool = False,
-            likelihood: Likelihood = None,
-        ):
-            return DLinearModel(
-                input_chunk_length=4,
-                output_chunk_length=1,
-                kernel_size=5,
-                model_name=model_name,
-                work_dir=self.temp_work_dir,
-                save_checkpoints=save_checkpoints,
-                likelihood=likelihood,
-                random_state=42,
-                force_reset=True,
-                **tfm_kwargs,
-            )
-
         model_dir = os.path.join(self.temp_work_dir)
         manual_name = "save_manual"
         auto_name = "save_auto"
@@ -571,7 +536,7 @@ def create_DLinearModel(
             checkpoint_path_manual, checkpoint_file_name
         )
 
-        model_auto_save = create_DLinearModel(
+        model_auto_save = self.helper_create_DLinearModel(
             auto_name,
             save_checkpoints=True,
             likelihood=GaussianLikelihood(prior_mu=0.5),
         )
         model_auto_save.fit(self.series, epochs=1)
         pred_auto = model_auto_save.predict(n=4, series=self.series)
 
-        model_manual_save = create_DLinearModel(
+        model_manual_save = self.helper_create_DLinearModel(
             manual_name,
             save_checkpoints=False,
             likelihood=GaussianLikelihood(prior_mu=0.5),
         )
@@ -592,7 +557,7 @@
         self.assertTrue(np.array_equal(pred_auto.values(), pred_manual.values()))
 
         # model with identical likelihood
-        model_same_likelihood = create_DLinearModel(
+        model_same_likelihood = self.helper_create_DLinearModel(
             "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5)
         )
         model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
@@ -600,7 +565,7 @@
         # cannot check predictions since this model is not fitted, random state is different
 
         # loading models weights with respective methods
-        model_manual_same_likelihood = create_DLinearModel(
+        model_manual_same_likelihood = self.helper_create_DLinearModel(
             "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5)
         )
         model_manual_same_likelihood.load_weights(
@@ -610,7 +575,7 @@
             n=4, series=self.series
         )
 
-        model_auto_same_likelihood = create_DLinearModel(
+        model_auto_same_likelihood = self.helper_create_DLinearModel(
             "same_likelihood", likelihood=GaussianLikelihood(prior_mu=0.5)
         )
         model_auto_same_likelihood.load_weights_from_checkpoint(
@@ -622,39 +587,147 @@
         # check that weights from checkpoint give identical predictions as weights from manual save
         self.assertTrue(preds_manual_from_weights == preds_auto_from_weights)
 
-        # model with no likelihood
-        model_no_likelihood = create_DLinearModel("no_likelihood", likelihood=None)
-        with self.assertRaises(ValueError):
+        # model with explicitly no likelihood
+        model_no_likelihood = self.helper_create_DLinearModel(
+            "no_likelihood", likelihood=None
+        )
+        with pytest.raises(ValueError) as error_msg:
             model_no_likelihood.load_weights_from_checkpoint(
                 auto_name,
                 work_dir=self.temp_work_dir,
                 best=False,
                 map_location="cpu",
             )
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
+
+        # model with missing likelihood (as if the user forgot it)
+        model_no_likelihood_bis = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            model_name="no_likelihood_bis",
+            add_encoders=None,
+            work_dir=self.temp_work_dir,
+            save_checkpoints=False,
+            random_state=42,
+            force_reset=True,
+            n_epochs=1,
+            # likelihood=likelihood,
+            **tfm_kwargs,
+        )
+        with pytest.raises(ValueError) as error_msg:
+            model_no_likelihood_bis.load_weights_from_checkpoint(
+                auto_name,
+                work_dir=self.temp_work_dir,
+                best=False,
+                map_location="cpu",
+            )
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "missing"
+        )
 
         # model with a different likelihood
-        model_other_likelihood = create_DLinearModel(
+        model_other_likelihood = self.helper_create_DLinearModel(
             "other_likelihood", likelihood=LaplaceLikelihood()
         )
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError) as error_msg:
             model_other_likelihood.load_weights(
                 model_path_manual, map_location="cpu"
             )
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
         # model with the same likelihood but different parameters
-        model_same_likelihood_other_prior = create_DLinearModel(
+        model_same_likelihood_other_prior = self.helper_create_DLinearModel(
             "same_likelihood_other_prior", likelihood=GaussianLikelihood()
         )
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError) as error_msg:
             model_same_likelihood_other_prior.load_weights(
                 model_path_manual, map_location="cpu"
             )
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
+
+    def test_load_weights_params_check(self):
+        """
+        Verify that the method comparing the parameters between the saved model and the loading model
+        behaves as expected; it is meant to return meaningful error messages instead of the torch.load ones.
+        """
+        model_name = "params_check"
+        ckpt_name = f"{model_name}.pt"
+        # barebone model
+        model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            work_dir=self.temp_work_dir,
+            n_epochs=1,
+        )
+        model.fit(self.series[:10])
+        model.save(ckpt_name)
+
+        # identical model
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            work_dir=self.temp_work_dir,
+        )
+        loading_model.load_weights(ckpt_name)
+
+        # different optimizer
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            work_dir=self.temp_work_dir,
+            optimizer_cls=torch.optim.AdamW,
+        )
+        loading_model.load_weights(ckpt_name)
+
+        # different pl_trainer_kwargs
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            work_dir=self.temp_work_dir,
+            pl_trainer_kwargs={"enable_model_summary": False},
+        )
+        loading_model.load_weights(ckpt_name)
+
+        # different input_chunk_length (tfm parameter)
+        loading_model = DLinearModel(
+            input_chunk_length=4 + 1,
+            output_chunk_length=1,
+            work_dir=self.temp_work_dir,
+        )
+        with pytest.raises(ValueError) as error_msg:
+            loading_model.load_weights(ckpt_name)
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
+
+        # different kernel size (cls specific parameter)
+        loading_model = DLinearModel(
+            input_chunk_length=4,
+            output_chunk_length=1,
+            kernel_size=10,
+            work_dir=self.temp_work_dir,
+        )
+        with pytest.raises(ValueError) as error_msg:
+            loading_model.load_weights(ckpt_name)
+        assert str(error_msg.value).startswith(
+            "The values of the hyper-parameters in the model and loaded checkpoint should be identical.\n"
+            "incorrect"
+        )
 
     def test_create_instance_new_model_no_name_set(self):
         RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs)
         # no exception is raised
-        RNNModel(12, "RNN", 10, 10, work_dir=self.temp_work_dir, **tfm_kwargs)
-        # no exception is raised
 
     def test_create_instance_existing_model_with_name_no_fit(self):
         model_name = "test_model"
@@ -669,17 +742,6 @@ def test_create_instance_existing_model_with_name_no_fit(self):
         )
         # no exception is raised
 
-        RNNModel(
-            12,
-            "RNN",
-            10,
-            10,
-            work_dir=self.temp_work_dir,
-            model_name=model_name,
-            **tfm_kwargs,
-        )
-        # no exception is raised
-
     @patch(
         "darts.models.forecasting.torch_forecasting_model.TorchForecastingModel.reset_model"
     )
@@ -1265,7 +1327,11 @@ def test_encoders(self):
         # 1 == output_chunk_length, 3 > output_chunk_length
         ns = [1, 3]
 
-        model = self.helper_create_DLinearModel()
+        model = self.helper_create_DLinearModel(
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            }
+        )
         model.fit(series)
         for n in ns:
             _ = model.predict(n=n)
@@ -1276,7 +1342,11 @@
             with pytest.raises(ValueError):
                 _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-        model = self.helper_create_DLinearModel()
+        model = self.helper_create_DLinearModel(
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            }
+        )
         for n in ns:
             model.fit(series, past_covariates=pc)
             _ = model.predict(n=n)
@@ -1286,7 +1356,11 @@
             with pytest.raises(ValueError):
                 _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-        model = self.helper_create_DLinearModel()
+        model = self.helper_create_DLinearModel(
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            }
+        )
         for n in ns:
             model.fit(series, future_covariates=fc)
             _ = model.predict(n=n)
@@ -1296,7 +1370,11 @@
             with pytest.raises(ValueError):
                 _ = model.predict(n=n, past_covariates=pc, future_covariates=fc)
 
-        model = self.helper_create_DLinearModel()
+        model = self.helper_create_DLinearModel(
+            add_encoders={
+                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
+            }
+        )
         for n in ns:
             model.fit(series, past_covariates=pc, future_covariates=fc)
             _ = model.predict(n=n)
@@ -1347,13 +1425,23 @@ def helper_create_RNNModel(self, model_name: str):
             **tfm_kwargs,
         )
 
-    def helper_create_DLinearModel(self):
+    def helper_create_DLinearModel(
+        self,
+        model_name: str = "unittest_model",
+        add_encoders: Optional[Dict] = None,
+        save_checkpoints: bool = False,
+        likelihood: Optional[Likelihood] = None,
+    ):
         return DLinearModel(
             input_chunk_length=4,
             output_chunk_length=1,
-            add_encoders={
-                "datetime_attribute": {"past": ["hour"], "future": ["month"]}
-            },
+            model_name=model_name,
+            add_encoders=add_encoders,
+            work_dir=self.temp_work_dir,
+            save_checkpoints=save_checkpoints,
+            random_state=42,
+            force_reset=True,
             n_epochs=1,
+            likelihood=likelihood,
             **tfm_kwargs,
         )
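As a companion to `test_load_weights_params_check` above, a minimal sketch (not part of the patch) of the failure path: an architecture mismatch now surfaces the readable `ValueError` built by `_check_ckpt_parameters` instead of a raw `torch.load` size-mismatch error. The file name and toy series are made up.

import numpy as np

from darts import TimeSeries
from darts.models import DLinearModel

series = TimeSeries.from_values(np.random.rand(40).astype(np.float32))
model = DLinearModel(input_chunk_length=4, output_chunk_length=1, n_epochs=1)
model.fit(series)
model.save("params_check.pt")

# mismatching input_chunk_length: raises a ValueError shaped roughly like
#   The values of the hyper-parameters in the model and loaded checkpoint should be identical.
#   incorrect :
#    - found input_chunk_length=5, should be input_chunk_length=4
other = DLinearModel(input_chunk_length=5, output_chunk_length=1)
other.load_weights("params_check.pt")  # raises ValueError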
f"{cls_name}({params_string})" + class GaussianLikelihood(Likelihood): def __init__(