-
Notifications
You must be signed in to change notification settings - Fork 918
/
Copy pathtbats.py
243 lines (197 loc) · 7.73 KB
/
tbats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""
BATS and TBATS
--------------

(T)BATS models [1]_ stand for

* (Trigonometric)
* Box-Cox
* ARMA errors
* Trend
* Seasonal components

They are appropriate to model "complex
seasonal time series such as those with multiple
seasonal periods, high frequency seasonality,
non-integer seasonality and dual-calendar effects" [1]_.

References
----------
.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import numpy as np
from scipy.special import inv_boxcox
from tbats import BATS as tbats_BATS
from tbats import TBATS as tbats_TBATS
from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.timeseries import TimeSeries
# Module-level logger obtained from ``darts.logging.get_logger``, named after this module.
logger = get_logger(__name__)
def _seasonality_from_freq(series: TimeSeries):
    """
    Infer a naive seasonality based on the frequency of ``series``.

    Returns a single-element list holding the number of observations in one
    "natural" cycle of the series frequency (e.g. ``[12]`` for monthly data).
    Returns ``None`` in three cases: the series has a range index, it has no
    frequency string, or the frequency is not recognised. Yearly data and
    multiples such as ``"2H"`` fall in the last case.

    Both legacy pandas aliases (``"H"``, ``"BH"``, ``"CBH"``, ``"S"``) and their
    pandas >= 2.2 lowercase successors (``"h"``, ``"bh"``, ``"cbh"``, ``"s"``)
    are recognised. Anchored weekly aliases such as ``"W-SUN"`` are also
    recognised.
    """
    if series.has_range_index:
        return None

    freq = series.freq_str
    if not freq:
        # Without a frequency string there is nothing to infer from.
        return None

    if freq in ("B", "C"):
        return [5]  # (custom) business day -> working week
    if freq == "D":
        return [7]  # day -> week
    # pandas reports weekly offsets with their anchor (e.g. "W-SUN" for freq="W"),
    # so a plain equality test against "W" would miss them.
    if freq == "W" or freq.startswith("W-"):
        return [52]  # week -> year
    # The prefixes also cover start/end variants: "MS", "ME", "BMS", "SMS", ...
    if freq.startswith(("M", "BM", "BS", "CBM", "SM")):
        return [12]  # month
    if freq.startswith(("Q", "BQ", "REQ")):
        return [4]  # quarter
    if freq in ("H", "h", "BH", "bh", "CBH", "cbh"):
        return [24]  # hour
    if freq in ("T", "min"):
        return [60]  # minute
    if freq in ("S", "s"):
        return [60]  # second
    return None
def _compute_samples(model, predictions, n_samples):
"""
This function is drawn from Model._calculate_confidence_intervals() in tbats.
We have to implement our own version here in order to compute the samples before
the inverse boxcox transform.
"""
# In the deterministic case we return the analytic mean
if n_samples == 1:
return np.expand_dims(predictions, axis=1)
F = model.matrix.make_F_matrix()
g = model.matrix.make_g_vector()
w = model.matrix.make_w_vector()
c = np.asarray([1.0] * len(predictions))
f_running = np.identity(F.shape[1])
for step in range(1, len(predictions)):
c[step] = w @ f_running @ g
f_running = f_running @ F
variance_multiplier = np.cumsum(c * c)
base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len(
model.y
)
variance_boxcox = base_variance_boxcox * variance_multiplier
std_boxcox = np.sqrt(variance_boxcox)
# get the samples before inverse boxcoxing
samples = np.random.normal(
loc=model._boxcox(predictions),
scale=std_boxcox,
size=(n_samples, len(predictions)),
).T
samples = np.expand_dims(samples, axis=1)
# apply inverse boxcox if needed
boxcox_lambda = model.params.box_cox_lambda
if boxcox_lambda is not None:
samples = inv_boxcox(samples, boxcox_lambda)
return samples
class _BaseBatsTbatsModel(ForecastingModel, ABC):
    def __init__(
        self,
        use_box_cox: Optional[bool] = None,
        box_cox_bounds: Tuple = (0, 1),
        use_trend: Optional[bool] = None,
        use_damped_trend: Optional[bool] = None,
        seasonal_periods: Optional[Union[str, List]] = "freq",
        use_arma_errors: Optional[bool] = True,
        show_warnings: bool = False,
        n_jobs: Optional[int] = None,
        multiprocessing_start_method: Optional[str] = "spawn",
        random_state: int = 0,
    ):
        """
        This is a wrapper around
        `tbats
        <https://github.com/intive-DataScience/tbats>`_.

        This implementation also provides naive frequency inference (when "freq"
        is provided for ``seasonal_periods``),
        as well as Darts-compatible sampling of the resulting normal distribution.

        For convenience, the tbats documentation of the parameters is reported here.

        Parameters
        ----------
        use_box_cox
            Whether the Box-Cox transformation should be applied to the original series.
            When ``None``, both cases are considered and the better one is selected by AIC.
        box_cox_bounds
            Minimal and maximal Box-Cox parameter values.
        use_trend
            Indicates whether to include a trend or not.
            When ``None``, both cases are considered and the better one is selected by AIC.
        use_damped_trend
            Indicates whether to include a damping parameter in the trend or not.
            Applies only when a trend is used.
            When ``None``, both cases are considered and the better one is selected by AIC.
        seasonal_periods
            Length of each of the periods (number of observations in each period).
            TBATS accepts int and float values here.
            BATS accepts only int values.
            When ``None`` or an empty array, a non-seasonal model is fitted.
            If set to ``"freq"``, a single "naive" seasonality
            based on the series frequency will be used (e.g. [12] for monthly series).
            In this latter case, the seasonality is recomputed every time the model is fit.
        use_arma_errors
            When ``True``, BATS will try to improve the model by modelling residuals with ARMA.
            The best model will be selected by AIC.
            If ``False``, ARMA residuals modeling will not be considered.
        show_warnings
            Whether warnings should be shown or not.
        n_jobs
            How many jobs to run in parallel when fitting the model.
            When not provided, tbats will try to utilize all available CPU cores.
        multiprocessing_start_method
            How worker processes should be started.
            See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
        random_state
            Sets the underlying random seed at model initialization time.
        """
        super().__init__()

        # Keyword arguments forwarded verbatim to the tbats estimator constructor.
        self.kwargs = {
            "use_box_cox": use_box_cox,
            "box_cox_bounds": box_cox_bounds,
            "use_trend": use_trend,
            "use_damped_trend": use_damped_trend,
            "seasonal_periods": seasonal_periods,
            "use_arma_errors": use_arma_errors,
            "show_warnings": show_warnings,
            "n_jobs": n_jobs,
            "multiprocessing_start_method": multiprocessing_start_method,
        }

        self.seasonal_periods = seasonal_periods
        # Only compare strings against "freq". For an array-like (e.g. a numpy
        # array), ``==`` is element-wise and would produce an array whose truth
        # value is ambiguous when tested in ``fit``.
        self.infer_seasonal_periods = (
            isinstance(seasonal_periods, str) and seasonal_periods == "freq"
        )
        self.model = None
        # NOTE: this seeds NumPy's *global* RNG. Sampling in ``predict`` draws
        # from ``np.random``, and any other code using it is affected too.
        np.random.seed(random_state)

    def __str__(self):
        return "(T)BATS"

    @abstractmethod
    def _create_model(self):
        """Return a new, unfitted tbats estimator configured from ``self.kwargs``."""
        pass

    def fit(self, series: TimeSeries):
        """
        Fit the underlying tbats model on ``series``.

        When ``seasonal_periods`` was ``"freq"``, the seasonality is re-inferred
        from the series frequency on every call. The inferred value (possibly
        ``None``) is stored both in ``self.seasonal_periods`` and in the kwargs
        forwarded to tbats.
        """
        super().fit(series)
        series = self.training_series

        if self.infer_seasonal_periods:
            seasonality = _seasonality_from_freq(series)
            self.kwargs["seasonal_periods"] = seasonality
            self.seasonal_periods = seasonality

        model = self._create_model()
        self.model = model.fit(series.values())
        return self

    def predict(self, n, num_samples=1):
        """
        Forecast ``n`` steps ahead with the fitted tbats model.

        With ``num_samples == 1`` the point forecast is returned. Otherwise
        ``num_samples`` samples are drawn from a normal distribution in the
        Box-Cox space and then inverse-transformed.
        """
        super().predict(n, num_samples)

        yhat = self.model.forecast(steps=n)
        samples = _compute_samples(self.model, yhat, num_samples)

        return self._build_forecast_series(samples)

    def _is_probabilistic(self) -> bool:
        return True

    @property
    def min_train_series_length(self) -> int:
        """
        Minimum number of observations required to fit the model.

        Returns twice the longest seasonal period, rounded up because TBATS
        accepts non-integer periods, when that period is greater than 1.
        Otherwise returns 3.

        ``seasonal_periods`` is normally a list: either user-provided or the
        ``[k]`` inferred from the frequency. Lists, array-likes and a single
        number are all handled. ``None``, ``"freq"`` (not yet inferred) and
        empty sequences fall back to 3.
        """
        periods = self.seasonal_periods
        if periods is None or isinstance(periods, str):
            return 3
        periods = np.atleast_1d(np.asarray(periods, dtype=float))
        if periods.size == 0:
            return 3
        longest = periods.max()
        if longest > 1:
            return int(np.ceil(2 * longest))
        return 3
class TBATS(_BaseBatsTbatsModel):
    """
    TBATS forecasting model, backed by ``tbats.TBATS``.

    ``seasonal_periods`` may contain int and float values.
    """

    def _create_model(self):
        """Return a new, unfitted ``tbats.TBATS`` estimator configured from ``self.kwargs``."""
        estimator = tbats_TBATS(**self.kwargs)
        return estimator
class BATS(_BaseBatsTbatsModel):
    """
    BATS forecasting model, backed by ``tbats.BATS``.

    ``seasonal_periods`` must contain only int values.
    """

    def _create_model(self):
        """Return a new, unfitted ``tbats.BATS`` estimator configured from ``self.kwargs``."""
        estimator = tbats_BATS(**self.kwargs)
        return estimator