diff --git a/darts/models/forecasting/ensemble_model.py b/darts/models/forecasting/ensemble_model.py index 8548a62f2e..a8f921b066 100644 --- a/darts/models/forecasting/ensemble_model.py +++ b/darts/models/forecasting/ensemble_model.py @@ -37,6 +37,12 @@ def __init__(self, models: Union[List[ForecastingModel], List[GlobalForecastingM raise_if_not(is_local_ensemble or self.is_global_ensemble, "All models must either be GlobalForecastingModel instances, or none of them should be.", logger) + + raise_if(any([m._fit_called for m in models]), + "Cannot instantiate EnsembleModel with trained/fitted models. " + "Consider resetting all models with `my_model.untrained_model()`", + logger) + super().__init__() self.models = models self.is_single_series = None diff --git a/darts/tests/models/forecasting/test_ensemble_models.py b/darts/tests/models/forecasting/test_ensemble_models.py index d1b73af0ac..0f447cfd37 100644 --- a/darts/tests/models/forecasting/test_ensemble_models.py +++ b/darts/tests/models/forecasting/test_ensemble_models.py @@ -33,6 +33,15 @@ class EnsembleModelsTestCase(DartsBaseTestClass): seq1 = [_make_ts(0), _make_ts(10), _make_ts(20)] cov1 = [_make_ts(5), _make_ts(15), _make_ts(25)] + def test_untrained_models(self): + model = NaiveDrift() + _ = NaiveEnsembleModel([model]) + + # trained models should raise error + model.fit(self.series1) + with self.assertRaises(ValueError): + NaiveEnsembleModel([model]) + def test_input_models_local_models(self): with self.assertRaises(ValueError): NaiveEnsembleModel([])