diff --git a/darts/models/__init__.py b/darts/models/__init__.py index 01de7d37ea..d39845c0d7 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -14,6 +14,7 @@ from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing from darts.models.forecasting.fft import FFT from darts.models.forecasting.kalman_forecaster import KalmanForecaster +from darts.models.forecasting.tbats import BATS, TBATS from darts.models.forecasting.theta import FourTheta, Theta from darts.models.forecasting.varima import VARIMA diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py new file mode 100644 index 0000000000..619361acd3 --- /dev/null +++ b/darts/models/forecasting/tbats.py @@ -0,0 +1,243 @@ +""" +BATS and TBATS +-------------- + +(T)BATS models [1]_ stand for + +* (Trigonometric) +* Box-Cox +* ARMA errors +* Trend +* Seasonal components + +They are appropriate to model "complex +seasonal time series such as those with multiple +seasonal periods, high frequency seasonality, +non-integer seasonality and dual-calendar effects" [1]_. + +References +---------- +.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf +""" + +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import numpy as np +from scipy.special import inv_boxcox +from tbats import BATS as tbats_BATS +from tbats import TBATS as tbats_TBATS + +from darts.logging import get_logger +from darts.models.forecasting.forecasting_model import ForecastingModel +from darts.timeseries import TimeSeries + +logger = get_logger(__name__) + + +def _seasonality_from_freq(series: TimeSeries): + """ + Infer a naive seasonality based on the frequency + """ + + if series.has_range_index: + return None + + freq = series.freq_str + + if freq in ["B", "C"]: + return [5] + elif freq == "D": + return [7] + elif freq == "W": + return [52] + elif freq in ["M", "BM", "CBM", "SM"] or freq.startswith( + ("M", "BM", "BS", "CBM", "SM") + ): + return [12] # month + elif freq in ["Q", "BQ", "REQ"] or freq.startswith(("Q", "BQ", "REQ")): + return [4] # quarter + elif freq in ["H", "BH", "CBH"]: + return [24] # hour + elif freq in ["T", "min"]: + return [60] # minute + elif freq == "S": + return [60] # second + + return None + + +def _compute_samples(model, predictions, n_samples): + """ + This function is drawn from Model._calculate_confidence_intervals() in tbats. + We have to implement our own version here in order to compute the samples before + the inverse boxcox transform. + """ + + # In the deterministic case we return the analytic mean + if n_samples == 1: + return np.expand_dims(predictions, axis=1) + + F = model.matrix.make_F_matrix() + g = model.matrix.make_g_vector() + w = model.matrix.make_w_vector() + + c = np.asarray([1.0] * len(predictions)) + f_running = np.identity(F.shape[1]) + for step in range(1, len(predictions)): + c[step] = w @ f_running @ g + f_running = f_running @ F + variance_multiplier = np.cumsum(c * c) + + base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len( + model.y + ) + variance_boxcox = base_variance_boxcox * variance_multiplier + std_boxcox = np.sqrt(variance_boxcox) + + # get the samples before inverse boxcoxing + samples = np.random.normal( + loc=model._boxcox(predictions), + scale=std_boxcox, + size=(n_samples, len(predictions)), + ).T + samples = np.expand_dims(samples, axis=1) + + # apply inverse boxcox if needed + boxcox_lambda = model.params.box_cox_lambda + if boxcox_lambda is not None: + samples = inv_boxcox(samples, boxcox_lambda) + + return samples + + +class _BaseBatsTbatsModel(ForecastingModel, ABC): + def __init__( + self, + use_box_cox: Optional[bool] = None, + box_cox_bounds: Tuple = (0, 1), + use_trend: Optional[bool] = None, + use_damped_trend: Optional[bool] = None, + seasonal_periods: Optional[Union[str, List]] = "freq", + use_arma_errors: Optional[bool] = True, + show_warnings: bool = False, + n_jobs: Optional[int] = None, + multiprocessing_start_method: Optional[str] = "spawn", + random_state: int = 0, + ): + + """ + This is a wrapper around + `tbats + `_. + + This implementation also provides naive frequency inference (when "freq" + is provided for ``seasonal_periods``), + as well as Darts-compatible sampling of the resulting normal distribution. + + For convenience, the tbats documentation of the parameters is reported here. + + Parameters + ---------- + use_box_cox + If Box-Cox transformation of original series should be applied. + When ``None`` both cases shall be considered and better is selected by AIC. + box_cox_bounds + Minimal and maximal Box-Cox parameter values. + use_trend + Indicates whether to include a trend or not. + When ``None``, both cases shall be considered and the better one is selected by AIC. + use_damped_trend + Indicates whether to include a damping parameter in the trend or not. + Applies only when trend is used. + When ``None``, both cases shall be considered and the better one is selected by AIC. + seasonal_periods + Length of each of the periods (amount of observations in each period). + TBATS accepts int and float values here. + BATS accepts only int values. + When ``None`` or empty array, non-seasonal model shall be fitted. + If set to ``"freq"``, a single "naive" seasonality + based on the series frequency will be used (e.g. [12] for monthly series). + In this latter case, the seasonality will be recomputed every time the model is fit. + use_arma_errors + When True BATS will try to improve the model by modelling residuals with ARMA. + Best model will be selected by AIC. + If ``False``, ARMA residuals modeling will not be considered. + show_warnings + If warnings should be shown or not. + n_jobs + How many jobs to run in parallel when fitting BATS model. + When not provided BATS shall try to utilize all available cpu cores. + multiprocessing_start_method + How threads should be started. + See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + random_state + Sets the underlying random seed at model initialization time. + """ + super().__init__() + + self.kwargs = { + "use_box_cox": use_box_cox, + "box_cox_bounds": box_cox_bounds, + "use_trend": use_trend, + "use_damped_trend": use_damped_trend, + "seasonal_periods": seasonal_periods, + "use_arma_errors": use_arma_errors, + "show_warnings": show_warnings, + "n_jobs": n_jobs, + "multiprocessing_start_method": multiprocessing_start_method, + } + + self.seasonal_periods = seasonal_periods + self.infer_seasonal_periods = seasonal_periods == "freq" + self.model = None + np.random.seed(random_state) + + def __str__(self): + return "(T)BATS" + + @abstractmethod + def _create_model(self): + pass + + def fit(self, series: TimeSeries): + super().fit(series) + series = self.training_series + + if self.infer_seasonal_periods: + seasonality = _seasonality_from_freq(series) + self.kwargs["seasonal_periods"] = seasonality + self.seasonal_periods = seasonality + + model = self._create_model() + fitted_model = model.fit(series.values()) + self.model = fitted_model + + return self + + def predict(self, n, num_samples=1): + super().predict(n, num_samples) + + yhat = self.model.forecast(steps=n) + samples = _compute_samples(self.model, yhat, num_samples) + + return self._build_forecast_series(samples) + + def _is_probabilistic(self) -> bool: + return True + + @property + def min_train_series_length(self) -> int: + if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1: + return 2 * self.seasonal_periods + return 3 + + +class TBATS(_BaseBatsTbatsModel): + def _create_model(self): + return tbats_TBATS(**self.kwargs) + + +class BATS(_BaseBatsTbatsModel): + def _create_model(self): + return tbats_BATS(**self.kwargs) diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 12c7f66bbd..b8cf77b763 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -74,9 +74,11 @@ logger.warning("Prophet not installed - will be skipping Prophet tests") try: - from darts.models import AutoARIMA + from darts.models import BATS, TBATS, AutoARIMA models.append((AutoARIMA(), 12.2)) + models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0)) + models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0)) dual_models.append(AutoARIMA()) PMDARIMA_AVAILABLE = True except ImportError: diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index 16b04af8ba..9ca5041a82 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -10,6 +10,14 @@ logger = get_logger(__name__) +try: + from darts.models import BATS, TBATS + + PMDARIMA_AVAILABLE = True +except ImportError: + logger.warning("pmdarima not available. BATS/TBATS probabilistic tests skipped.") + PMDARIMA_AVAILABLE = False + try: import torch @@ -46,10 +54,36 @@ TORCH_AVAILABLE = False models_cls_kwargs_errs = [ - (ExponentialSmoothing, {}, 0.4), - (ARIMA, {"p": 1, "d": 0, "q": 1}, 0.17), + (ExponentialSmoothing, {}, 0.3), + (ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03), ] +if PMDARIMA_AVAILABLE: + models_cls_kwargs_errs += [ + ( + BATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), + ( + TBATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), + ] + if TORCH_AVAILABLE: models_cls_kwargs_errs += [ ( @@ -125,11 +159,11 @@ def test_fit_predict_determinism(self): # whether the first predictions of two models initiated with the same random state are the same model = model_cls(**model_kwargs) - model.fit(self.constant_ts) + model.fit(self.constant_noisy_ts) pred1 = model.predict(n=10, num_samples=2).values() model = model_cls(**model_kwargs) - model.fit(self.constant_ts) + model.fit(self.constant_noisy_ts) pred2 = model.predict(n=10, num_samples=2).values() self.assertTrue((pred1 == pred2).all()) @@ -210,7 +244,7 @@ def helper_test_probabilistic_forecast_accuracy( (ExponentialLikelihood(), real_pos_series, 0.3, 2), (DirichletLikelihood(), simplex_series, 0.3, 0.3), (GeometricLikelihood(), discrete_pos_series, 1, 1), - (CauchyLikelihood(), real_series, 3, 10), + (CauchyLikelihood(), real_series, 3, 11), (ContinuousBernoulliLikelihood(), bounded_series, 0.1, 0.1), (HalfNormalLikelihood(), real_pos_series, 0.3, 8), (LogNormalLikelihood(), real_pos_series, 0.3, 1), diff --git a/requirements/pmdarima.txt b/requirements/pmdarima.txt index 5fa71e5333..e73f98849e 100644 --- a/requirements/pmdarima.txt +++ b/requirements/pmdarima.txt @@ -1 +1,2 @@ pmdarima>=1.8.0 +tbats>=1.1.0