@@ -691,17 +691,16 @@ def historical_forecasts(
             retrain_func = _retrain_wrapper(
                 lambda counter: counter % int(retrain) == 0 if retrain else False
             )
-
         elif isinstance(retrain, Callable):
             retrain_func = _retrain_wrapper(retrain)
-
         else:
             raise_log(
                 ValueError(
                     "`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
                 ),
                 logger,
             )
+
         retrain_func_signature = tuple(
             inspect.signature(retrain_func).parameters.keys()
         )
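For context on the hunk above: the `retrain` argument is normalized into a predicate that decides, per iteration, whether the model should be refitted. A minimal standalone sketch of that mapping follows; the helper name `make_retrain_func` is hypothetical, and the real code additionally wraps the predicate in `_retrain_wrapper` and inspects its signature.

def make_retrain_func(retrain):
    # bool or positive int -> retrain every int(retrain) iterations (never when False)
    if isinstance(retrain, (bool, int)):
        return lambda counter: counter % int(retrain) == 0 if retrain else False
    # custom callable returning bool -> used as-is
    if callable(retrain):
        return retrain
    raise ValueError(
        "`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
    )

should_retrain = make_retrain_func(3)
print([should_retrain(c) for c in range(6)])  # [True, False, False, True, False, False]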
@@ -728,7 +727,6 @@ def historical_forecasts(
 
         forecasts_list = []
         for idx, series_ in enumerate(outer_iterator):
-
             past_covariates_ = past_covariates[idx] if past_covariates else None
             future_covariates_ = future_covariates[idx] if future_covariates else None
 
@@ -765,15 +763,12 @@ def historical_forecasts(
 
             # prepare the start parameter -> pd.Timestamp
             if start is not None:
-
                 historical_forecasts_time_index = drop_before_index(
                     historical_forecasts_time_index,
                     series_.get_timestamp_at_point(start),
                 )
-
             else:
                 if (retrain is not False) or (not self._fit_called):
-
                     if train_length:
                         historical_forecasts_time_index = drop_before_index(
                             historical_forecasts_time_index,
@@ -804,9 +799,9 @@ def historical_forecasts(
                 (not self._fit_called)
                 and (retrain is False)
                 and (not train_length),
-                " The model has not been fitted yet, and `start` and train_length are not specified. "
-                " The model is not retraining during the historical forecasts. Hence the "
-                "the first and only training would be done on 2 samples.",
+                "The model has not been fitted yet, and `start` and train_length are not specified. "
+                "The model is not retraining during the historical forecasts. Hence the "
+                "first and only training would be done on 2 samples.",
                 logger,
             )
 
@@ -837,7 +832,6 @@ def historical_forecasts(
 
             # iterate and forecast
             for _counter, pred_time in enumerate(iterator):
-
                 # build the training series
                 if min_timestamp > series_.time_index[0]:
                     train_series = series_.drop_before(
@@ -866,13 +860,17 @@ def historical_forecasts(
                     if future_covariates_
                     else None,
                 ):
-                    self._fit_wrapper(
+                    # avoid fitting the same model multiple times
+                    model = self.untrained_model()
+                    model._fit_wrapper(
                         series=train_series,
                         past_covariates=past_covariates_,
                         future_covariates=future_covariates_,
                     )
+                else:
+                    model = self
 
-                forecast = self._predict_wrapper(
+                forecast = model._predict_wrapper(
                     n=forecast_horizon,
                     series=train_series,
                     past_covariates=past_covariates_,
@@ -901,7 +899,6 @@ def historical_forecasts(
                         hierarchy=series_.hierarchy,
                     )
                 )
-
             else:
                 forecasts_list.append(forecasts)
 
@@ -1526,7 +1523,7 @@ def _extract_model_creation_params(self):
         return model_params
 
     def untrained_model(self):
-        return self.__class__(**self.model_params)
+        return self.__class__(**copy.deepcopy(self.model_params))
 
     @property
    def model_params(self) -> dict:
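Taken together, the two behavioural changes in this diff are: `historical_forecasts` now fits a fresh copy obtained from `untrained_model()` instead of refitting `self`, and `untrained_model()` deep-copies the stored creation parameters so the copy shares no mutable state with the original. A minimal sketch with a hypothetical `ToyModel` (not the darts API) illustrating both effects:

import copy

class ToyModel:
    def __init__(self, lags):
        self._model_params = {"lags": lags}
        self.fitted = False

    @property
    def model_params(self):
        return self._model_params

    def untrained_model(self):
        # deep copy so mutable constructor arguments are not shared with `self`
        return self.__class__(**copy.deepcopy(self.model_params))

    def fit(self):
        self.fitted = True
        return self

original = ToyModel(lags=[1, 2])
fresh = original.untrained_model().fit()  # fit the copy, as historical_forecasts now does
print(original.fitted)                    # False -> the user's model instance is left untouched
fresh.model_params["lags"].append(99)     # mutate the copy's params...
print(original.model_params["lags"])      # ...[1, 2]: the deep copy prevented shared state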