From 827404f248711bcf94552912cc4079f9f1f6fa92 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sun, 9 Jan 2022 13:40:08 +0100 Subject: [PATCH 1/2] added check for untrained models for copying of model parameters --- darts/models/forecasting/ensemble_model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/darts/models/forecasting/ensemble_model.py b/darts/models/forecasting/ensemble_model.py index 8548a62f2e..6f3fa93151 100644 --- a/darts/models/forecasting/ensemble_model.py +++ b/darts/models/forecasting/ensemble_model.py @@ -26,6 +26,11 @@ class EnsembleModel(GlobalForecastingModel): List of forecasting models whose predictions to ensemble """ def __init__(self, models: Union[List[ForecastingModel], List[GlobalForecastingModel]]): + raise_if(any([m._fit_called for m in models]), + "Cannot instantiate EnsembleModel with trained/fitted models. " + "Consider resetting all models with `my_model.untrained_model()`", + logger) + raise_if_not(isinstance(models, list) and models, "Cannot instantiate EnsembleModel with an empty list of models", logger) From 93d2a395c0c94fcb30f56f24c97e1ff3b354eb0a Mon Sep 17 00:00:00 2001 From: dennisbader Date: Sun, 9 Jan 2022 14:43:19 +0100 Subject: [PATCH 2/2] added test and changed position of checks due to failing tests --- darts/models/forecasting/ensemble_model.py | 11 ++++++----- .../tests/models/forecasting/test_ensemble_models.py | 9 +++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/ensemble_model.py b/darts/models/forecasting/ensemble_model.py index 6f3fa93151..a8f921b066 100644 --- a/darts/models/forecasting/ensemble_model.py +++ b/darts/models/forecasting/ensemble_model.py @@ -26,11 +26,6 @@ class EnsembleModel(GlobalForecastingModel): List of forecasting models whose predictions to ensemble """ def __init__(self, models: Union[List[ForecastingModel], List[GlobalForecastingModel]]): - raise_if(any([m._fit_called for m in models]), - "Cannot instantiate EnsembleModel with trained/fitted models. " - "Consider resetting all models with `my_model.untrained_model()`", - logger) - raise_if_not(isinstance(models, list) and models, "Cannot instantiate EnsembleModel with an empty list of models", logger) @@ -42,6 +37,12 @@ def __init__(self, models: Union[List[ForecastingModel], List[GlobalForecastingM raise_if_not(is_local_ensemble or self.is_global_ensemble, "All models must either be GlobalForecastingModel instances, or none of them should be.", logger) + + raise_if(any([m._fit_called for m in models]), + "Cannot instantiate EnsembleModel with trained/fitted models. " + "Consider resetting all models with `my_model.untrained_model()`", + logger) + super().__init__() self.models = models self.is_single_series = None diff --git a/darts/tests/models/forecasting/test_ensemble_models.py b/darts/tests/models/forecasting/test_ensemble_models.py index d1b73af0ac..0f447cfd37 100644 --- a/darts/tests/models/forecasting/test_ensemble_models.py +++ b/darts/tests/models/forecasting/test_ensemble_models.py @@ -33,6 +33,15 @@ class EnsembleModelsTestCase(DartsBaseTestClass): seq1 = [_make_ts(0), _make_ts(10), _make_ts(20)] cov1 = [_make_ts(5), _make_ts(15), _make_ts(25)] + def test_untrained_models(self): + model = NaiveDrift() + _ = NaiveEnsembleModel([model]) + + # trained models should raise error + model.fit(self.series1) + with self.assertRaises(ValueError): + NaiveEnsembleModel([model]) + def test_input_models_local_models(self): with self.assertRaises(ValueError): NaiveEnsembleModel([])