Skip to content

Commit 28ca88d

Browse files
authored
adapt prophet calls with vectorized=True (#1208)
1 parent 3cff4a6 commit 28ca88d

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

darts/models/forecasting/prophet_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _predict(
155155
predict_df = self._generate_predict_df(n=n, future_covariates=future_covariates)
156156

157157
if num_samples == 1:
158-
forecast = self.model.predict(predict_df)["yhat"].values
158+
forecast = self.model.predict(predict_df, vectorized=True)["yhat"].values
159159
else:
160160
forecast = np.expand_dims(
161161
self._stochastic_samples(predict_df, n_samples=num_samples), axis=1
@@ -203,7 +203,7 @@ def _stochastic_samples(self, predict_df, n_samples) -> np.ndarray:
203203

204204
predict_df["trend"] = self.model.predict_trend(predict_df)
205205

206-
forecast = self.model.sample_posterior_predictive(predict_df)
206+
forecast = self.model.sample_posterior_predictive(predict_df, vectorized=True)
207207

208208
# reset default number of uncertainty_samples
209209
self.model.uncertainty_samples = n_samples_default
@@ -221,7 +221,7 @@ def predict_raw(
221221

222222
predict_df = self._generate_predict_df(n=n, future_covariates=future_covariates)
223223

224-
return self.model.predict(predict_df)
224+
return self.model.predict(predict_df, vectorized=True)
225225

226226
def add_seasonality(
227227
self,

requirements/core.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ nfoursid>=1.0.0
88
numpy>=1.19.0
99
pandas>=1.0.5
1010
pmdarima>=1.8.0
11-
prophet>=1.1
11+
prophet>=1.1.1
1212
requests>=2.22.0
1313
scikit-learn>=1.0.1
1414
scipy>=1.3.2

0 commit comments

Comments
 (0)