From a7122656cc457bdc8f8404364598d51aea3abfa7 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 12:03:12 +0100 Subject: [PATCH 01/15] Base version of TBATS --- darts/models/__init__.py | 1 + darts/models/forecasting/tbats.py | 193 ++++++++++++++++++++++++++++++ requirements/core.txt | 1 + 3 files changed, 195 insertions(+) create mode 100644 darts/models/forecasting/tbats.py diff --git a/darts/models/__init__.py b/darts/models/__init__.py index 01de7d37ea..d65fc3b730 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -14,6 +14,7 @@ from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing from darts.models.forecasting.fft import FFT from darts.models.forecasting.kalman_forecaster import KalmanForecaster +from darts.models.forecasting.tbats import TBATS from darts.models.forecasting.theta import FourTheta, Theta from darts.models.forecasting.varima import VARIMA diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py new file mode 100644 index 0000000000..981054e39f --- /dev/null +++ b/darts/models/forecasting/tbats.py @@ -0,0 +1,193 @@ +""" +BATS and TBATS +-------------- + +(T)BATS models [1]_ stand for + +* (Trigonometric) +* Box-Cox +* ARMA errors +* Trend +* Seasonal components + +They are appropriate to model "complex +seasonal time series such as those with multiple +seasonal periods, high frequency seasonality, +non-integer seasonality and dual-calendar effects" [1]_. + +References +---------- +.. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf +""" + +from typing import List, Optional + +import numpy as np +from scipy.special import inv_boxcox +from tbats import TBATS as tbats_TBATS + +from darts.logging import get_logger +from darts.models.forecasting.forecasting_model import ForecastingModel +from darts.timeseries import TimeSeries + +logger = get_logger(__name__) + + +class TBATS(ForecastingModel): + def __init__( + self, + seasonal_periods: Optional[List[int]] = "freq", + use_arma_errors: Optional[bool] = None, + use_box_cox: Optional[bool] = None, + use_trend: Optional[bool] = None, + use_damped_trend: Optional[bool] = None, + random_state: int = 0, + **kwargs, + ): + + """TBATS + + This is a wrapper around + `tbats TBATS model + `_; + we refer to this link for the documentation on the parameters. + + Parameters + ---------- + seasonal_periods + A list of seasonal periods. If ``None``, no seasonality will be set. + If set to ``"freq"``, a single "naive" seasonality + based on the series frequency will be used (e.g. [12] for monthly series). + use_arma_errors + Whether to use ARMA errors (``None``: try with and without) + use_box_cox + Whether to use BoxCox transform (``None``: try with and without) + use_trend + Whether to use trend (``None``: try with and without) + use_damped_trend + Whether to use damped trend (``None``: try with and without) + kwargs + Other optional keyword arguments that will be used to call + :class:`tbats.TBATS`. + """ + super().__init__() + self.seasonal_periods = seasonal_periods + self.use_arma_errors = use_arma_errors + self.use_box_cox = use_box_cox + self.use_trend = use_trend + self.use_damped_trend = use_damped_trend + self.tbats_kwargs = kwargs + + self.infer_seasonal_periods = seasonal_periods == "freq" + self.model = None + np.random.seed(random_state) + + def __str__(self): + return ( + f"TBATS(periods={self.seasonal_periods}, arma_errs={self.use_arma_errors}, " + f"boxcox={self.use_box_cox}, trend={self.use_trend}, damped_trend={self.use_damped_trend}" + ) + + @staticmethod + def _infer_naive_seasonality(series: TimeSeries): + """ + Infer a naive seasonality based on the frequency + """ + if series.has_range_index: + return [12] + elif series.freq_str == "B": + return [5] + elif series.freq_str == "D": + return [7] + elif series.freq_str == "W": + return [52] + elif series.freq_str in ["MS", "M"]: + return [12] + elif series.freq_str == ["Q", "BQ", "QS", "BQS"]: + return [4] + elif series.freq_str == ["H"]: + return [24] + return None + + @staticmethod + def _darts_calculate_confidence_intervals(model, predictions, n_samples): + """ + This function is drawn from Model._calculate_confidence_intervals() in tbats. + We have to implement our own version here in order to compute the samples before + the inverse boxcox transform. + """ + F = model.matrix.make_F_matrix() + g = model.matrix.make_g_vector() + w = model.matrix.make_w_vector() + + c = np.asarray([1.0] * len(predictions)) + f_running = np.identity(F.shape[1]) + for step in range(1, len(predictions)): + c[step] = w @ f_running @ g + f_running = f_running @ F + variance_multiplier = np.cumsum(c * c) + + base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len( + model.y + ) + variance_boxcox = base_variance_boxcox * variance_multiplier + std_boxcox = np.sqrt(variance_boxcox) + + # get the samples before inverse boxcoxing + samples = np.random.normal( + loc=model._boxcox(predictions), + scale=std_boxcox, + size=(n_samples, len(predictions)), + ).T + samples = np.expand_dims(samples, axis=1) + + # apply inverse boxcox if needed + boxcox_lambda = model.params.box_cox_lambda + if boxcox_lambda is not None: + samples = inv_boxcox(samples, boxcox_lambda) + + return samples + + def fit(self, series: TimeSeries): + super().fit(series) + series = self.training_series + + if self.infer_seasonal_periods: + seasonal_periods = TBATS._infer_naive_seasonality(series) + else: + seasonal_periods = self.seasonal_periods + + model = tbats_TBATS( + seasonal_periods=seasonal_periods, + use_arma_errors=self.use_arma_errors, + use_box_cox=self.use_box_cox, + use_trend=self.use_trend, + use_damped_trend=self.use_damped_trend, + show_warnings=False, + **self.tbats_kwargs, + ) + fitted_model = model.fit(series.values()) + self.model = fitted_model + + return self + + def predict(self, n, num_samples=1): + super().predict(n, num_samples) + + yhat = self.model.forecast(steps=n) + if num_samples == 1: + samples = yhat.view(len(yhat), 1) + else: + samples = TBATS._darts_calculate_confidence_intervals( + self.model, yhat, num_samples + ) + return self._build_forecast_series(samples) + + def _is_probabilistic(self) -> bool: + return True + + @property + def min_train_series_length(self) -> int: + if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1: + return 2 * self.seasonal_periods + return 3 diff --git a/requirements/core.txt b/requirements/core.txt index 739f073a95..fe1bc6df71 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -10,5 +10,6 @@ requests>=2.22.0 scikit-learn>=1.0.1 scipy>=1.3.2 statsmodels>=0.13.0 +tbats>=1.1.0 tqdm>=4.60.0 xarray>=0.17.0 From 618bf4a91f3c9f361834e34321c5b253a0667c84 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 17:43:03 +0100 Subject: [PATCH 02/15] Add both BATS and TBATS --- darts/models/forecasting/tbats.py | 257 +++++++++++++++++------------- 1 file changed, 150 insertions(+), 107 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 981054e39f..841f408a19 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -20,10 +20,12 @@ .. [1] https://robjhyndman.com/papers/ComplexSeasonality.pdf """ -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple import numpy as np from scipy.special import inv_boxcox +from tbats import BATS as tbats_BATS from tbats import TBATS as tbats_TBATS from darts.logging import get_logger @@ -33,139 +35,169 @@ logger = get_logger(__name__) -class TBATS(ForecastingModel): +def _seasonality_from_freq(series: TimeSeries): + """ + Infer a naive seasonality based on the frequency + """ + if series.has_range_index: + return [12] + elif series.freq_str == "B": + return [5] + elif series.freq_str == "D": + return [7] + elif series.freq_str == "W": + return [52] + elif series.freq_str in ["MS", "M"]: + return [12] + elif series.freq_str == ["Q", "BQ", "QS", "BQS"]: + return [4] + elif series.freq_str == ["H"]: + return [24] + return None + + +def _compute_samples(model, predictions, n_samples): + """ + This function is drawn from Model._calculate_confidence_intervals() in tbats. + We have to implement our own version here in order to compute the samples before + the inverse boxcox transform. + """ + + # In the deterministic case we return the analytic mean + if n_samples == 1: + return predictions.view(len(predictions), 1) + + F = model.matrix.make_F_matrix() + g = model.matrix.make_g_vector() + w = model.matrix.make_w_vector() + + c = np.asarray([1.0] * len(predictions)) + f_running = np.identity(F.shape[1]) + for step in range(1, len(predictions)): + c[step] = w @ f_running @ g + f_running = f_running @ F + variance_multiplier = np.cumsum(c * c) + + base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len( + model.y + ) + variance_boxcox = base_variance_boxcox * variance_multiplier + std_boxcox = np.sqrt(variance_boxcox) + + # get the samples before inverse boxcoxing + samples = np.random.normal( + loc=model._boxcox(predictions), + scale=std_boxcox, + size=(n_samples, len(predictions)), + ).T + samples = np.expand_dims(samples, axis=1) + + # apply inverse boxcox if needed + boxcox_lambda = model.params.box_cox_lambda + if boxcox_lambda is not None: + samples = inv_boxcox(samples, boxcox_lambda) + + return samples + + +class _BaseBatsTbatsModel(ForecastingModel, ABC): def __init__( self, - seasonal_periods: Optional[List[int]] = "freq", - use_arma_errors: Optional[bool] = None, use_box_cox: Optional[bool] = None, + box_cox_bounds: Tuple = (0, 1), use_trend: Optional[bool] = None, use_damped_trend: Optional[bool] = None, + seasonal_periods: Optional[List] = "freq", + use_arma_errors: Optional[bool] = True, + show_warnings: bool = False, + n_jobs: Optional[int] = None, + multiprocessing_start_method: Optional[str] = "spawn", random_state: int = 0, **kwargs, ): - """TBATS + """BATS & TBATS This is a wrapper around - `tbats TBATS model - `_; - we refer to this link for the documentation on the parameters. + `tbats BATS or TBATS model + `_. + + This implementation also provides naive frequency inference (when "freq" + is provided for ``seasonal_periods``), + as well as Darts-compatible sampling of the resulting normal distribution. + + For convenience, the tbats documentation of the parameters is reported here. Parameters ---------- + use_box_cox + If Box-Cox transformation of original series should be applied. + When None both cases shall be considered and better is selected by AIC. + box_cox_bounds + Minimal and maximal Box-Cox parameter values. + use_trend + Indicates whether to include a trend or not. + When None both cases shall be considered and better is selected by AIC. + use_damped_trend + Indicates whether to include a damping parameter in the trend or not. + Applies only when trend is used. + When None both cases shall be considered and better is selected by AIC. seasonal_periods - A list of seasonal periods. If ``None``, no seasonality will be set. + Length of each of the periods (amount of observations in each period). + TBATS accepts int and float values here. + BATS accepts only int values. + When ``None`` or empty array, non-seasonal model shall be fitted. If set to ``"freq"``, a single "naive" seasonality based on the series frequency will be used (e.g. [12] for monthly series). + In this latter case, the seasonality will be recomputed every time the model is fit. use_arma_errors - Whether to use ARMA errors (``None``: try with and without) - use_box_cox - Whether to use BoxCox transform (``None``: try with and without) - use_trend - Whether to use trend (``None``: try with and without) - use_damped_trend - Whether to use damped trend (``None``: try with and without) - kwargs - Other optional keyword arguments that will be used to call - :class:`tbats.TBATS`. + When True BATS will try to improve the model by modelling residuals with ARMA. + Best model will be selected by AIC. + If False, ARMA residuals modeling will not be considered. + show_warnings + If warnings should be shown or not. + n_jobs: int, optional (default=None) + How many jobs to run in parallel when fitting BATS model. + When not provided BATS shall try to utilize all available cpu cores. + multiprocessing_start_method: str, optional (default='spawn') + How threads should be started. + See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods """ super().__init__() - self.seasonal_periods = seasonal_periods - self.use_arma_errors = use_arma_errors - self.use_box_cox = use_box_cox - self.use_trend = use_trend - self.use_damped_trend = use_damped_trend - self.tbats_kwargs = kwargs + self.kwargs = { + "use_box_cox": use_box_cox, + "box_cox_bounds": box_cox_bounds, + "use_trend": use_trend, + "use_damped_trend": use_damped_trend, + "seasonal_periods": seasonal_periods, + "use_arma_errors": use_arma_errors, + "show_warnings": show_warnings, + "n_jobs": n_jobs, + "multiprocessing_start_method": multiprocessing_start_method, + } + + self.seasonal_periods = seasonal_periods self.infer_seasonal_periods = seasonal_periods == "freq" self.model = None np.random.seed(random_state) def __str__(self): - return ( - f"TBATS(periods={self.seasonal_periods}, arma_errs={self.use_arma_errors}, " - f"boxcox={self.use_box_cox}, trend={self.use_trend}, damped_trend={self.use_damped_trend}" - ) + return "(T)BATS" - @staticmethod - def _infer_naive_seasonality(series: TimeSeries): - """ - Infer a naive seasonality based on the frequency - """ - if series.has_range_index: - return [12] - elif series.freq_str == "B": - return [5] - elif series.freq_str == "D": - return [7] - elif series.freq_str == "W": - return [52] - elif series.freq_str in ["MS", "M"]: - return [12] - elif series.freq_str == ["Q", "BQ", "QS", "BQS"]: - return [4] - elif series.freq_str == ["H"]: - return [24] - return None - - @staticmethod - def _darts_calculate_confidence_intervals(model, predictions, n_samples): - """ - This function is drawn from Model._calculate_confidence_intervals() in tbats. - We have to implement our own version here in order to compute the samples before - the inverse boxcox transform. - """ - F = model.matrix.make_F_matrix() - g = model.matrix.make_g_vector() - w = model.matrix.make_w_vector() - - c = np.asarray([1.0] * len(predictions)) - f_running = np.identity(F.shape[1]) - for step in range(1, len(predictions)): - c[step] = w @ f_running @ g - f_running = f_running @ F - variance_multiplier = np.cumsum(c * c) - - base_variance_boxcox = np.sum(model.resid_boxcox * model.resid_boxcox) / len( - model.y - ) - variance_boxcox = base_variance_boxcox * variance_multiplier - std_boxcox = np.sqrt(variance_boxcox) - - # get the samples before inverse boxcoxing - samples = np.random.normal( - loc=model._boxcox(predictions), - scale=std_boxcox, - size=(n_samples, len(predictions)), - ).T - samples = np.expand_dims(samples, axis=1) - - # apply inverse boxcox if needed - boxcox_lambda = model.params.box_cox_lambda - if boxcox_lambda is not None: - samples = inv_boxcox(samples, boxcox_lambda) - - return samples + @abstractmethod + def _create_model(self): + pass + # return tbats_TBATS(**self.kwargs) def fit(self, series: TimeSeries): super().fit(series) series = self.training_series if self.infer_seasonal_periods: - seasonal_periods = TBATS._infer_naive_seasonality(series) - else: - seasonal_periods = self.seasonal_periods - - model = tbats_TBATS( - seasonal_periods=seasonal_periods, - use_arma_errors=self.use_arma_errors, - use_box_cox=self.use_box_cox, - use_trend=self.use_trend, - use_damped_trend=self.use_damped_trend, - show_warnings=False, - **self.tbats_kwargs, - ) + self.kwargs["seasonal_periods"] = _seasonality_from_freq(series) + + model = self._create_model() fitted_model = model.fit(series.values()) self.model = fitted_model @@ -175,12 +207,8 @@ def predict(self, n, num_samples=1): super().predict(n, num_samples) yhat = self.model.forecast(steps=n) - if num_samples == 1: - samples = yhat.view(len(yhat), 1) - else: - samples = TBATS._darts_calculate_confidence_intervals( - self.model, yhat, num_samples - ) + samples = _compute_samples(self.model, yhat, num_samples) + return self._build_forecast_series(samples) def _is_probabilistic(self) -> bool: @@ -191,3 +219,18 @@ def min_train_series_length(self) -> int: if isinstance(self.seasonal_periods, int) and self.seasonal_periods > 1: return 2 * self.seasonal_periods return 3 + + +class TBATS(_BaseBatsTbatsModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _create_model(self): + pass + return tbats_TBATS(**self.kwargs) + + +class BATS(_BaseBatsTbatsModel): + def _create_model(self): + pass + return tbats_BATS(**self.kwargs) From 8f66cd0e4095ae924c0a9cc1c42aac60b6427f56 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 22:11:33 +0100 Subject: [PATCH 03/15] import both bats and tbats --- darts/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/__init__.py b/darts/models/__init__.py index d65fc3b730..d39845c0d7 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -14,7 +14,7 @@ from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing from darts.models.forecasting.fft import FFT from darts.models.forecasting.kalman_forecaster import KalmanForecaster -from darts.models.forecasting.tbats import TBATS +from darts.models.forecasting.tbats import BATS, TBATS from darts.models.forecasting.theta import FourTheta, Theta from darts.models.forecasting.varima import VARIMA From 21e71fd105c1406a09326b64b95cb4d5834a7ef0 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 22:11:51 +0100 Subject: [PATCH 04/15] fix an issue --- darts/models/forecasting/tbats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 841f408a19..0507b00293 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -65,7 +65,7 @@ def _compute_samples(model, predictions, n_samples): # In the deterministic case we return the analytic mean if n_samples == 1: - return predictions.view(len(predictions), 1) + return np.expand_dims(predictions, axis=1) F = model.matrix.make_F_matrix() g = model.matrix.make_g_vector() From 5311a2297a20a2c23e1fb4ed12bb8a7baa9f5dd7 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 22:13:20 +0100 Subject: [PATCH 05/15] added unit tests for bats and tbats probabilistic --- .../forecasting/test_probabilistic_models.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index cabf30040f..8e6227048f 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -3,7 +3,7 @@ from darts import TimeSeries from darts.logging import get_logger from darts.metrics import mae -from darts.models import ARIMA, ExponentialSmoothing +from darts.models import ARIMA, BATS, TBATS, ExponentialSmoothing from darts.models.forecasting.forecasting_model import GlobalForecastingModel from darts.tests.base_test_class import DartsBaseTestClass from darts.utils import timeseries_generation as tg @@ -46,8 +46,30 @@ TORCH_AVAILABLE = False models_cls_kwargs_errs = [ - (ExponentialSmoothing, {}, 0.4), - (ARIMA, {"p": 1, "d": 0, "q": 1}, 0.17), + (ExponentialSmoothing, {}, 0.3), + (ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03), + ( + BATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), + ( + TBATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), ] if TORCH_AVAILABLE: @@ -125,11 +147,11 @@ def test_fit_predict_determinism(self): # whether the first predictions of two models initiated with the same random state are the same model = model_cls(**model_kwargs) - model.fit(self.constant_ts) + model.fit(self.constant_noisy_ts) pred1 = model.predict(n=10, num_samples=2).values() model = model_cls(**model_kwargs) - model.fit(self.constant_ts) + model.fit(self.constant_noisy_ts) pred2 = model.predict(n=10, num_samples=2).values() self.assertTrue((pred1 == pred2).all()) @@ -210,7 +232,7 @@ def helper_test_probabilistic_forecast_accuracy( (ExponentialLikelihood(), real_pos_series, 0.3, 2), (DirichletLikelihood(), simplex_series, 0.3, 0.3), (GeometricLikelihood(), discrete_pos_series, 1, 1), - (CauchyLikelihood(), real_series, 3, 10), + (CauchyLikelihood(), real_series, 3, 11), (ContinuousBernoulliLikelihood(), bounded_series, 0.1, 0.1), (HalfNormalLikelihood(), real_pos_series, 0.3, 8), (LogNormalLikelihood(), real_pos_series, 0.3, 1), From d4e41c7b3d9787199b4031281bf14a9b98ed8da3 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 22 Feb 2022 22:17:36 +0100 Subject: [PATCH 06/15] add accuracy unit tests --- .../tests/models/forecasting/test_local_forecasting_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 12c7f66bbd..5795c77e4f 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -6,7 +6,9 @@ from darts.metrics import mape from darts.models import ( ARIMA, + BATS, FFT, + TBATS, VARIMA, ExponentialSmoothing, FourTheta, @@ -44,6 +46,8 @@ (FourTheta(trend_mode=TrendMode.EXPONENTIAL), 5.5), (FourTheta(model_mode=ModelMode.MULTIPLICATIVE), 11.4), (FourTheta(season_mode=SeasonalityMode.ADDITIVE), 14.2), + (TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0), + (BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0), (FFT(trend="poly"), 11.4), (NaiveSeasonal(), 32.4), (KalmanForecaster(dim_x=3), 17.0), From c8e52c844ec3eed00bb87958e6504d3d794c1cb0 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Wed, 23 Feb 2022 10:38:26 +0100 Subject: [PATCH 07/15] Remove useless lines --- darts/models/forecasting/tbats.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 0507b00293..b1c5b9292a 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -226,11 +226,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _create_model(self): - pass return tbats_TBATS(**self.kwargs) class BATS(_BaseBatsTbatsModel): def _create_model(self): - pass return tbats_BATS(**self.kwargs) From d6883a6491f06ceae7121bf2ea50c2e70a597d7b Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Wed, 23 Feb 2022 10:40:05 +0100 Subject: [PATCH 08/15] Improve doc --- darts/models/forecasting/tbats.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index b1c5b9292a..e51cefb694 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -116,10 +116,9 @@ def __init__( **kwargs, ): - """BATS & TBATS - + """ This is a wrapper around - `tbats BATS or TBATS model + `tbats `_. This implementation also provides naive frequency inference (when "freq" From ee62a0c1c0903df12353d24adc995672e7600790 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Wed, 23 Feb 2022 13:35:38 +0100 Subject: [PATCH 09/15] Small fix --- darts/models/forecasting/tbats.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index e51cefb694..6f9964730b 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -221,9 +221,6 @@ def min_train_series_length(self) -> int: class TBATS(_BaseBatsTbatsModel): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def _create_model(self): return tbats_TBATS(**self.kwargs) From 5e21ad30d10f5b2ecef3d1c5a9f1baca12c6e066 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 14:47:31 +0100 Subject: [PATCH 10/15] Add BATS/TBATS to pmdarima flavour --- .../test_local_forecasting_models.py | 8 +-- .../forecasting/test_probabilistic_models.py | 58 +++++++++++-------- requirements/core.txt | 1 - requirements/pmdarima.txt | 1 + 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 5795c77e4f..b8cf77b763 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -6,9 +6,7 @@ from darts.metrics import mape from darts.models import ( ARIMA, - BATS, FFT, - TBATS, VARIMA, ExponentialSmoothing, FourTheta, @@ -46,8 +44,6 @@ (FourTheta(trend_mode=TrendMode.EXPONENTIAL), 5.5), (FourTheta(model_mode=ModelMode.MULTIPLICATIVE), 11.4), (FourTheta(season_mode=SeasonalityMode.ADDITIVE), 14.2), - (TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0), - (BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0), (FFT(trend="poly"), 11.4), (NaiveSeasonal(), 32.4), (KalmanForecaster(dim_x=3), 17.0), @@ -78,9 +74,11 @@ logger.warning("Prophet not installed - will be skipping Prophet tests") try: - from darts.models import AutoARIMA + from darts.models import BATS, TBATS, AutoARIMA models.append((AutoARIMA(), 12.2)) + models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0)) + models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0)) dual_models.append(AutoARIMA()) PMDARIMA_AVAILABLE = True except ImportError: diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index 4d980a060d..9ca5041a82 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -3,13 +3,21 @@ from darts import TimeSeries from darts.logging import get_logger from darts.metrics import mae -from darts.models import ARIMA, BATS, TBATS, ExponentialSmoothing +from darts.models import ARIMA, ExponentialSmoothing from darts.models.forecasting.forecasting_model import GlobalForecastingModel from darts.tests.base_test_class import DartsBaseTestClass from darts.utils import timeseries_generation as tg logger = get_logger(__name__) +try: + from darts.models import BATS, TBATS + + PMDARIMA_AVAILABLE = True +except ImportError: + logger.warning("pmdarima not available. BATS/TBATS probabilistic tests skipped.") + PMDARIMA_AVAILABLE = False + try: import torch @@ -48,30 +56,34 @@ models_cls_kwargs_errs = [ (ExponentialSmoothing, {}, 0.3), (ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03), - ( - BATS, - { - "use_trend": False, - "use_damped_trend": False, - "use_box_cox": True, - "use_arma_errors": False, - "random_state": 42, - }, - 0.3, - ), - ( - TBATS, - { - "use_trend": False, - "use_damped_trend": False, - "use_box_cox": True, - "use_arma_errors": False, - "random_state": 42, - }, - 0.3, - ), ] +if PMDARIMA_AVAILABLE: + models_cls_kwargs_errs += [ + ( + BATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), + ( + TBATS, + { + "use_trend": False, + "use_damped_trend": False, + "use_box_cox": True, + "use_arma_errors": False, + "random_state": 42, + }, + 0.3, + ), + ] + if TORCH_AVAILABLE: models_cls_kwargs_errs += [ ( diff --git a/requirements/core.txt b/requirements/core.txt index fe1bc6df71..739f073a95 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -10,6 +10,5 @@ requests>=2.22.0 scikit-learn>=1.0.1 scipy>=1.3.2 statsmodels>=0.13.0 -tbats>=1.1.0 tqdm>=4.60.0 xarray>=0.17.0 diff --git a/requirements/pmdarima.txt b/requirements/pmdarima.txt index 5fa71e5333..e73f98849e 100644 --- a/requirements/pmdarima.txt +++ b/requirements/pmdarima.txt @@ -1 +1,2 @@ pmdarima>=1.8.0 +tbats>=1.1.0 From 1bfd84f6ff60a77eddf3fbfef75def10df7e304e Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 14:59:40 +0100 Subject: [PATCH 11/15] better frequency support --- darts/models/forecasting/tbats.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 6f9964730b..be40cc07ba 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -39,20 +39,31 @@ def _seasonality_from_freq(series: TimeSeries): """ Infer a naive seasonality based on the frequency """ + if series.has_range_index: - return [12] - elif series.freq_str == "B": + return None + + freq = series.freq_str + + if freq in ["B", "C"]: return [5] - elif series.freq_str == "D": + elif freq == "D": return [7] - elif series.freq_str == "W": + elif freq == "W": return [52] - elif series.freq_str in ["MS", "M"]: - return [12] - elif series.freq_str == ["Q", "BQ", "QS", "BQS"]: - return [4] - elif series.freq_str == ["H"]: - return [24] + elif freq in ["M", "BM", "CBM", "SM"] or freq.startswith( + ("M", "BM", "BS", "CBM", "SM") + ): + return [12] # month + elif freq in ["Q", "BQ", "REQ"] or freq.startswith(("Q", "BQ", "REQ")): + return [4] # quarter + elif freq in ["H", "BH", "CBH"]: + return [24] # hour + elif freq in ["T", "min"]: + return [60] # minute + elif freq == "S": + return [60] # second + return None From bf24d6d1feb7e86ccea463984a30c1bdb4b2c82c Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 15:07:25 +0100 Subject: [PATCH 12/15] Update darts/models/forecasting/tbats.py Co-authored-by: Dennis Bader --- darts/models/forecasting/tbats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index be40cc07ba..51b4d9b33d 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -147,7 +147,7 @@ def __init__( Minimal and maximal Box-Cox parameter values. use_trend Indicates whether to include a trend or not. - When None both cases shall be considered and better is selected by AIC. + When None, both cases shall be considered and the better one is selected by AIC. use_damped_trend Indicates whether to include a damping parameter in the trend or not. Applies only when trend is used. From f325847e64eb7ec37d8858131b900274e25cb9f7 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 15:07:32 +0100 Subject: [PATCH 13/15] Update darts/models/forecasting/tbats.py Co-authored-by: Dennis Bader --- darts/models/forecasting/tbats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 51b4d9b33d..b7d592f54b 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -151,7 +151,7 @@ def __init__( use_damped_trend Indicates whether to include a damping parameter in the trend or not. Applies only when trend is used. - When None both cases shall be considered and better is selected by AIC. + When None, both cases shall be considered and the better one is selected by AIC. seasonal_periods Length of each of the periods (amount of observations in each period). TBATS accepts int and float values here. From 717675d588e488ecaa497d0988dee4d50596c79b Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 15:09:02 +0100 Subject: [PATCH 14/15] some PR comments --- darts/models/forecasting/tbats.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index be40cc07ba..57c3ffb56a 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -21,7 +21,7 @@ """ from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np from scipy.special import inv_boxcox @@ -118,13 +118,12 @@ def __init__( box_cox_bounds: Tuple = (0, 1), use_trend: Optional[bool] = None, use_damped_trend: Optional[bool] = None, - seasonal_periods: Optional[List] = "freq", + seasonal_periods: Optional[Union[str, List]] = "freq", use_arma_errors: Optional[bool] = True, show_warnings: bool = False, n_jobs: Optional[int] = None, multiprocessing_start_method: Optional[str] = "spawn", random_state: int = 0, - **kwargs, ): """ @@ -172,6 +171,8 @@ def __init__( multiprocessing_start_method: str, optional (default='spawn') How threads should be started. See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + random_state + Sets the underlying random seed at model initialization time. """ super().__init__() From 44a5fbf872a15fd10e6c33d1a6a1a6c99eef2d86 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Tue, 8 Mar 2022 15:13:57 +0100 Subject: [PATCH 15/15] address PR comments --- darts/models/forecasting/tbats.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/darts/models/forecasting/tbats.py b/darts/models/forecasting/tbats.py index 06f3156a08..619361acd3 100644 --- a/darts/models/forecasting/tbats.py +++ b/darts/models/forecasting/tbats.py @@ -141,16 +141,16 @@ def __init__( ---------- use_box_cox If Box-Cox transformation of original series should be applied. - When None both cases shall be considered and better is selected by AIC. + When ``None`` both cases shall be considered and better is selected by AIC. box_cox_bounds Minimal and maximal Box-Cox parameter values. use_trend Indicates whether to include a trend or not. - When None, both cases shall be considered and the better one is selected by AIC. + When ``None``, both cases shall be considered and the better one is selected by AIC. use_damped_trend Indicates whether to include a damping parameter in the trend or not. Applies only when trend is used. - When None, both cases shall be considered and the better one is selected by AIC. + When ``None``, both cases shall be considered and the better one is selected by AIC. seasonal_periods Length of each of the periods (amount of observations in each period). TBATS accepts int and float values here. @@ -162,13 +162,13 @@ def __init__( use_arma_errors When True BATS will try to improve the model by modelling residuals with ARMA. Best model will be selected by AIC. - If False, ARMA residuals modeling will not be considered. + If ``False``, ARMA residuals modeling will not be considered. show_warnings If warnings should be shown or not. - n_jobs: int, optional (default=None) + n_jobs How many jobs to run in parallel when fitting BATS model. When not provided BATS shall try to utilize all available cpu cores. - multiprocessing_start_method: str, optional (default='spawn') + multiprocessing_start_method How threads should be started. See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods random_state @@ -199,14 +199,15 @@ def __str__(self): @abstractmethod def _create_model(self): pass - # return tbats_TBATS(**self.kwargs) def fit(self, series: TimeSeries): super().fit(series) series = self.training_series if self.infer_seasonal_periods: - self.kwargs["seasonal_periods"] = _seasonality_from_freq(series) + seasonality = _seasonality_from_freq(series) + self.kwargs["seasonal_periods"] = seasonality + self.seasonal_periods = seasonality model = self._create_model() fitted_model = model.fit(series.values())