3
3
----------------------
4
4
"""
5
5
6
- from typing import Optional
6
+ from typing import Callable , Optional
7
7
8
8
import numpy as np
9
9
import pandas as pd
@@ -238,7 +238,7 @@ def __init__(
238
238
pd.Timestamp attributes that are relevant for the seasonality automatically.
239
239
trend
240
240
If set, indicates what kind of detrending will be applied before performing DFT.
241
- Possible values: 'poly' or 'exp', for polynomial trend, or exponential trend, respectively.
241
+ Possible values: 'poly', 'exp' or None , for polynomial trend, exponential trend or no trend, respectively.
242
242
trend_poly_degree
243
243
The degree of the polynomial that will be used for detrending, if `trend='poly'`.
244
244
@@ -269,6 +269,20 @@ def __str__(self):
269
269
+ ")"
270
270
)
271
271
272
+ def _exp_trend (self , x ) -> Callable :
273
+ """Helper function, used to make FFT model pickable."""
274
+ return np .exp (self .trend_coefficients [1 ]) * np .exp (
275
+ self .trend_coefficients [0 ] * x
276
+ )
277
+
278
+ def _poly_trend (self , trend_coefficients ) -> Callable :
279
+ """Helper function, for consistency with the other trends"""
280
+ return np .poly1d (trend_coefficients )
281
+
282
+ def _null_trend (self , x ) -> Callable :
283
+ """Helper function, used to make FFT model pickable."""
284
+ return 0
285
+
272
286
def fit (self , series : TimeSeries ):
273
287
series = fill_missing_values (series )
274
288
super ().fit (series )
@@ -277,19 +291,18 @@ def fit(self, series: TimeSeries):
277
291
278
292
# determine trend
279
293
if self .trend == "poly" :
280
- trend_coefficients = np .polyfit (
294
+ self . trend_coefficients = np .polyfit (
281
295
range (len (series )), series .univariate_values (), self .trend_poly_degree
282
296
)
283
- self .trend_function = np . poly1d ( trend_coefficients )
297
+ self .trend_function = self . _poly_trend ( self . trend_coefficients )
284
298
elif self .trend == "exp" :
285
- trend_coefficients = np .polyfit (
299
+ self . trend_coefficients = np .polyfit (
286
300
range (len (series )), np .log (series .univariate_values ()), 1
287
301
)
288
- self .trend_function = lambda x : np .exp (trend_coefficients [1 ]) * np .exp (
289
- trend_coefficients [0 ] * x
290
- )
302
+ self .trend_function = self ._exp_trend
291
303
else :
292
- self .trend_function = lambda x : 0
304
+ self .trend_coefficients = None
305
+ self .trend_function = self ._null_trend
293
306
294
307
# subtract trend
295
308
detrended_values = series .univariate_values () - self .trend_function (
0 commit comments