Skip to content

Commit 803d851

Browse files
fix: replacing lambda with named function to make model pickable (#1594)
* fix: replacing lambda with named function to make model pickable * fix: issue was also occurring with the exponential de-trending function * fix: adding typing * fix: linting --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent c80049b commit 803d851

File tree

1 file changed

+22
-9
lines changed
  • darts/models/forecasting

1 file changed

+22
-9
lines changed

darts/models/forecasting/fft.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
----------------------
44
"""
55

6-
from typing import Optional
6+
from typing import Callable, Optional
77

88
import numpy as np
99
import pandas as pd
@@ -238,7 +238,7 @@ def __init__(
238238
pd.Timestamp attributes that are relevant for the seasonality automatically.
239239
trend
240240
If set, indicates what kind of detrending will be applied before performing DFT.
241-
Possible values: 'poly' or 'exp', for polynomial trend, or exponential trend, respectively.
241+
Possible values: 'poly', 'exp' or None, for polynomial trend, exponential trend or no trend, respectively.
242242
trend_poly_degree
243243
The degree of the polynomial that will be used for detrending, if `trend='poly'`.
244244
@@ -269,6 +269,20 @@ def __str__(self):
269269
+ ")"
270270
)
271271

272+
def _exp_trend(self, x) -> Callable:
273+
"""Helper function, used to make FFT model pickable."""
274+
return np.exp(self.trend_coefficients[1]) * np.exp(
275+
self.trend_coefficients[0] * x
276+
)
277+
278+
def _poly_trend(self, trend_coefficients) -> Callable:
279+
"""Helper function, for consistency with the other trends"""
280+
return np.poly1d(trend_coefficients)
281+
282+
def _null_trend(self, x) -> Callable:
283+
"""Helper function, used to make FFT model pickable."""
284+
return 0
285+
272286
def fit(self, series: TimeSeries):
273287
series = fill_missing_values(series)
274288
super().fit(series)
@@ -277,19 +291,18 @@ def fit(self, series: TimeSeries):
277291

278292
# determine trend
279293
if self.trend == "poly":
280-
trend_coefficients = np.polyfit(
294+
self.trend_coefficients = np.polyfit(
281295
range(len(series)), series.univariate_values(), self.trend_poly_degree
282296
)
283-
self.trend_function = np.poly1d(trend_coefficients)
297+
self.trend_function = self._poly_trend(self.trend_coefficients)
284298
elif self.trend == "exp":
285-
trend_coefficients = np.polyfit(
299+
self.trend_coefficients = np.polyfit(
286300
range(len(series)), np.log(series.univariate_values()), 1
287301
)
288-
self.trend_function = lambda x: np.exp(trend_coefficients[1]) * np.exp(
289-
trend_coefficients[0] * x
290-
)
302+
self.trend_function = self._exp_trend
291303
else:
292-
self.trend_function = lambda x: 0
304+
self.trend_coefficients = None
305+
self.trend_function = self._null_trend
293306

294307
# subtract trend
295308
detrended_values = series.univariate_values() - self.trend_function(

0 commit comments

Comments
 (0)