Skip to content

Commit 4a97140

Browse files
authored
added check for untrained models for copying of model parameters (#728)
* added check for untrained models for copying of model parameters * added test and changed position of checks due to failing tests
1 parent 8fa7dbf commit 4a97140

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

darts/models/forecasting/ensemble_model.py

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def __init__(self, models: Union[List[ForecastingModel], List[GlobalForecastingM
3737
raise_if_not(is_local_ensemble or self.is_global_ensemble,
3838
"All models must either be GlobalForecastingModel instances, or none of them should be.",
3939
logger)
40+
41+
raise_if(any([m._fit_called for m in models]),
42+
"Cannot instantiate EnsembleModel with trained/fitted models. "
43+
"Consider resetting all models with `my_model.untrained_model()`",
44+
logger)
45+
4046
super().__init__()
4147
self.models = models
4248
self.is_single_series = None

darts/tests/models/forecasting/test_ensemble_models.py

+9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ class EnsembleModelsTestCase(DartsBaseTestClass):
3333
seq1 = [_make_ts(0), _make_ts(10), _make_ts(20)]
3434
cov1 = [_make_ts(5), _make_ts(15), _make_ts(25)]
3535

36+
def test_untrained_models(self):
37+
model = NaiveDrift()
38+
_ = NaiveEnsembleModel([model])
39+
40+
# trained models should raise error
41+
model.fit(self.series1)
42+
with self.assertRaises(ValueError):
43+
NaiveEnsembleModel([model])
44+
3645
def test_input_models_local_models(self):
3746
with self.assertRaises(ValueError):
3847
NaiveEnsembleModel([])

0 commit comments

Comments
 (0)