Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/tbats #816

Merged
merged 22 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing
from darts.models.forecasting.fft import FFT
from darts.models.forecasting.kalman_forecaster import KalmanForecaster
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA

Expand Down
230 changes: 230 additions & 0 deletions darts/models/forecasting/tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
"""
BATS and TBATS
--------------

(T)BATS models [1]_ stand for

* (Trigonometric)
* Box-Cox
* ARMA errors
* Trend
* Seasonal components

They are appropriate to model "complex
seasonal time series such as those with multiple
seasonal periods, high frequency seasonality,
non-integer seasonality and dual-calendar effects" [1]_.

References
----------
.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import numpy as np
from scipy.special import inv_boxcox
from tbats import BATS as tbats_BATS
from tbats import TBATS as tbats_TBATS

from darts.logging import get_logger
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.timeseries import TimeSeries

logger = get_logger(__name__)


def _seasonality_from_freq(series: TimeSeries):
"""
Infer a naive seasonality based on the frequency
"""
if series.has_range_index:
return [12]
elif series.freq_str == "B":
return [5]
elif series.freq_str == "D":
return [7]
elif series.freq_str == "W":
return [52]
elif series.freq_str in ["MS", "M"]:
return [12]
elif series.freq_str == ["Q", "BQ", "QS", "BQS"]:
return [4]
elif series.freq_str == ["H"]:
return [24]
return None


def _compute_samples(model, predictions, n_samples):
"""
This function is drawn from Model._calculate_confidence_intervals() in tbats.
We have to implement our own version here in order to compute the samples before
the inverse boxcox transform.
"""

# In the deterministic case we return the analytic mean
if n_samples == 1:
return np.expand_dims(predictions, axis=1)

F = model.matrix.make_F_matrix()
g = model.matrix.make_g_vector()
w = model.matrix.make_w_vector()

c = np.asarray([1.0] * len(predictions))
f_running = np.identity(F.shape[1])
for step in range(1, len(predictions)):
c[step] = w @ f_running @ g
f_running = f_running @ F
variance_multiplier = np.cumsum(c * c)

base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len(
model.y
)
variance_boxcox = base_variance_boxcox * variance_multiplier
std_boxcox = np.sqrt(variance_boxcox)

# get the samples before inverse boxcoxing
samples = np.random.normal(
loc=model._boxcox(predictions),
scale=std_boxcox,
size=(n_samples, len(predictions)),
).T
samples = np.expand_dims(samples, axis=1)

# apply inverse boxcox if needed
boxcox_lambda = model.params.box_cox_lambda
if boxcox_lambda is not None:
samples = inv_boxcox(samples, boxcox_lambda)

return samples


class _BaseBatsTbatsModel(ForecastingModel, ABC):
def __init__(
self,
use_box_cox: Optional[bool] = None,
box_cox_bounds: Tuple = (0, 1),
use_trend: Optional[bool] = None,
use_damped_trend: Optional[bool] = None,
seasonal_periods: Optional[List] = "freq",
use_arma_errors: Optional[bool] = True,
show_warnings: bool = False,
n_jobs: Optional[int] = None,
multiprocessing_start_method: Optional[str] = "spawn",
random_state: int = 0,
**kwargs,
):

"""
This is a wrapper around
`tbats
<https://github.com/intive-DataScience/tbats>`_.

This implementation also provides naive frequency inference (when "freq"
is provided for ``seasonal_periods``),
as well as Darts-compatible sampling of the resulting normal distribution.

For convenience, the tbats documentation of the parameters is reported here.

Parameters
----------
use_box_cox
If Box-Cox transformation of original series should be applied.
When None both cases shall be considered and better is selected by AIC.
box_cox_bounds
Minimal and maximal Box-Cox parameter values.
use_trend
Indicates whether to include a trend or not.
When None both cases shall be considered and better is selected by AIC.
use_damped_trend
Indicates whether to include a damping parameter in the trend or not.
Applies only when trend is used.
When None both cases shall be considered and better is selected by AIC.
seasonal_periods
Length of each of the periods (amount of observations in each period).
TBATS accepts int and float values here.
BATS accepts only int values.
When ``None`` or empty array, non-seasonal model shall be fitted.
If set to ``"freq"``, a single "naive" seasonality
based on the series frequency will be used (e.g. [12] for monthly series).
In this latter case, the seasonality will be recomputed every time the model is fit.
use_arma_errors
When True BATS will try to improve the model by modelling residuals with ARMA.
Best model will be selected by AIC.
If False, ARMA residuals modeling will not be considered.
show_warnings
If warnings should be shown or not.
n_jobs: int, optional (default=None)
How many jobs to run in parallel when fitting BATS model.
When not provided BATS shall try to utilize all available cpu cores.
multiprocessing_start_method: str, optional (default='spawn')
How threads should be started.
See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
"""
super().__init__()

self.kwargs = {
"use_box_cox": use_box_cox,
"box_cox_bounds": box_cox_bounds,
"use_trend": use_trend,
"use_damped_trend": use_damped_trend,
"seasonal_periods": seasonal_periods,
"use_arma_errors": use_arma_errors,
"show_warnings": show_warnings,
"n_jobs": n_jobs,
"multiprocessing_start_method": multiprocessing_start_method,
}

self.seasonal_periods = seasonal_periods
self.infer_seasonal_periods = seasonal_periods == "freq"
self.model = None
np.random.seed(random_state)

def __str__(self):
return "(T)BATS"

@abstractmethod
def _create_model(self):
pass
# return tbats_TBATS(**self.kwargs)

def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series

if self.infer_seasonal_periods:
self.kwargs["seasonal_periods"] = _seasonality_from_freq(series)

model = self._create_model()
fitted_model = model.fit(series.values())
self.model = fitted_model

return self

def predict(self, n, num_samples=1):
super().predict(n, num_samples)

yhat = self.model.forecast(steps=n)
samples = _compute_samples(self.model, yhat, num_samples)

return self._build_forecast_series(samples)

def _is_probabilistic(self) -> bool:
return True

@property
def min_train_series_length(self) -> int:
if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1:
return 2 * self.seasonal_periods
return 3


class TBATS(_BaseBatsTbatsModel):
def _create_model(self):
return tbats_TBATS(**self.kwargs)


class BATS(_BaseBatsTbatsModel):
def _create_model(self):
return tbats_BATS(**self.kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from darts.metrics import mape
from darts.models import (
ARIMA,
BATS,
FFT,
TBATS,
VARIMA,
ExponentialSmoothing,
FourTheta,
Expand Down Expand Up @@ -44,6 +46,8 @@
(FourTheta(trend_mode=TrendMode.EXPONENTIAL), 5.5),
(FourTheta(model_mode=ModelMode.MULTIPLICATIVE), 11.4),
(FourTheta(season_mode=SeasonalityMode.ADDITIVE), 14.2),
(TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0),
(BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0),
(FFT(trend="poly"), 11.4),
(NaiveSeasonal(), 32.4),
(KalmanForecaster(dim_x=3), 17.0),
Expand Down
34 changes: 28 additions & 6 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from darts import TimeSeries
from darts.logging import get_logger
from darts.metrics import mae
from darts.models import ARIMA, ExponentialSmoothing
from darts.models import ARIMA, BATS, TBATS, ExponentialSmoothing
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
from darts.tests.base_test_class import DartsBaseTestClass
from darts.utils import timeseries_generation as tg
Expand Down Expand Up @@ -46,8 +46,30 @@
TORCH_AVAILABLE = False

models_cls_kwargs_errs = [
(ExponentialSmoothing, {}, 0.4),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.17),
(ExponentialSmoothing, {}, 0.3),
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03),
(
BATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
(
TBATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
]

if TORCH_AVAILABLE:
Expand Down Expand Up @@ -125,11 +147,11 @@ def test_fit_predict_determinism(self):

# whether the first predictions of two models initiated with the same random state are the same
model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred1 = model.predict(n=10, num_samples=2).values()

model = model_cls(**model_kwargs)
model.fit(self.constant_ts)
model.fit(self.constant_noisy_ts)
pred2 = model.predict(n=10, num_samples=2).values()

self.assertTrue((pred1 == pred2).all())
Expand Down Expand Up @@ -210,7 +232,7 @@ def helper_test_probabilistic_forecast_accuracy(
(ExponentialLikelihood(), real_pos_series, 0.3, 2),
(DirichletLikelihood(), simplex_series, 0.3, 0.3),
(GeometricLikelihood(), discrete_pos_series, 1, 1),
(CauchyLikelihood(), real_series, 3, 10),
(CauchyLikelihood(), real_series, 3, 11),
(ContinuousBernoulliLikelihood(), bounded_series, 0.1, 0.1),
(HalfNormalLikelihood(), real_pos_series, 0.3, 8),
(LogNormalLikelihood(), real_pos_series, 0.3, 1),
Expand Down
1 change: 1 addition & 0 deletions requirements/core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ requests>=2.22.0
scikit-learn>=1.0.1
scipy>=1.3.2
statsmodels>=0.13.0
tbats>=1.1.0
tqdm>=4.60.0
xarray>=0.17.0