Skip to content

Commit 55e4e42

Browse files
Droxefhrzn
andauthored
Feature/four theta (#123)
* feat(4Theta): naive implementation of 4Theta model * fix(theta): avoid NaN values in theta, and unnecessary season test * feat(gridsearch): add possibility to compare with model.fitted_values * feat(4theta): add a method to auto select best model * refactor(4Theta): Specify univariate model * style(4Theta): Fix linter * style(4Theta): fix docstring * style(4Theta): Fix docstring * style(4Theta): Change link * style(4Theta): Correct docstring * style(4theta): correct ticks in docstring * refactor(4Theta): change different modes verification and add Enum * refactor(theta): replace all string modes by Enum * test(backtesting): Add a test to verify if fitted_values exist * Fix(Theta): Correct all Enums * fix(Theta): compare with enum members value instead. Correct some minor bugs * fix(4theta): move the creation of enums in init file * test(4theta): Add 4Theta to autoregressive test. Move Enums to top init file * test(4theta): Add 4Theta specific test * style(backtesting): fix lint * test(4theta): Add another exception to test * ref(4Theta): mode.fitted_values is now a TimeSeries to be consistent * style(Theta): rename mode to season_mode to be consistent w/ FourTheta * docs(thetas): correct errors in the different docs * refactor(4Theta): Correct * Add normalization choice * Add comment to be clearer * Correct the docs * clean the code and add a check on mean=0 * refactor(backtesting): add a 'use_fitted_values' parameter * fix(4theta): correct select_best_model * test(4Theta): add a test for zero mean and correct others * style(backtesting): linter formatting * refactor(4Theta): change Enums names, correct theta and backtesting docs * refactor(4theta): move creation of fitted_values timeseries to backtesting * refactor(statistics): include Enums in extract and remove functions * refactor(4Theta): check earlier if univariate * test(4Theta): correct backtesting and test best_model * test(4Theta): add new modes in test models * docs(4theta): Add a disclaimer for 4theta performance * refactor(Theta): change theta to have the same behavior as FourTheta * examples(darts-intro): modify notebook to give the same results * style(4Theta): correct deprecation warning for logger.warn * style(4theta): move comment to backtesting Co-authored-by: Julien Herzen <[email protected]>
1 parent 0c7298d commit 55e4e42

File tree

8 files changed

+450
-55
lines changed

8 files changed

+450
-55
lines changed

darts/__init__.py

+20
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,24 @@
55

66
from .timeseries import TimeSeries
77

8+
# Enums
9+
from enum import Enum
10+
11+
12+
class SeasonalityMode(Enum):
13+
MULTIPLICATIVE = 'multiplicative'
14+
ADDITIVE = 'additive'
15+
NONE = None
16+
17+
18+
class TrendMode(Enum):
19+
LINEAR = 'linear'
20+
EXPONENTIAL = 'exponential'
21+
22+
23+
class ModelMode(Enum):
24+
MULTIPLICATIVE = 'multiplicative'
25+
ADDITIVE = 'additive'
26+
27+
828
__version__ = 'dev'

darts/backtesting/backtesting.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,13 @@ def backtest_gridsearch(model_class: type,
341341
use_full_output_length: bool = True,
342342
val_series: Optional[TimeSeries] = None,
343343
num_predictions: int = 10,
344+
use_fitted_values: bool = False,
344345
metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
345346
verbose=False):
346347
""" A function for finding the best hyperparameters.
347348
348-
This function has 2 modes of operation: Expanding window mode and split mode.
349-
Both modes of operation evaluate every possible combination of hyperparameter values
349+
This function has 3 modes of operation: Expanding window mode, split mode and comparison with fitted values.
350+
The three modes of operation evaluate every possible combination of hyperparameter values
350351
provided in the `parameters` dictionary by instantiating the `model_class` subclass
351352
of ForecastingModel with each combination, and returning the best-performing model with regards
352353
to the `metric` function. The `metric` function is expected to return an error value,
@@ -364,17 +365,24 @@ def backtest_gridsearch(model_class: type,
364365
For every hyperparameter combination, the model is trained on `train_series` and
365366
evaluated on `val_series`.
366367
368+
Comparison with fitted values (activated when `use_fitted_values` is passed):
369+
For every hyperparameter combination, the model is trained on `train_series` and evaluated on the resulting
370+
fitted values.
371+
Not all models have fitted values, and this method raises an error if `model.fitted_values` doesn't exist.
372+
The fitted values are the result of the fit of the model on the training series. Comparing with the fitted values
373+
can be a quick way to assess the model, but one cannot see if the model overfits or underfits.
374+
367375
368376
Parameters
369377
----------
370-
model
378+
model_class
371379
The ForecastingModel subclass to be tuned for 'series'.
372380
parameters
373381
A dictionary containing as keys hyperparameter names, and as values lists of values for the
374382
respective hyperparameter.
375383
train_series
376384
The univariate TimeSeries instance used for training (and also validation in split mode).
377-
test_series
385+
val_series
378386
The univariate TimeSeries instance used for validation in split mode.
379387
fcast_horizon_n
380388
The integer value of the forecasting horizon used in expanding window mode.
@@ -389,6 +397,9 @@ def backtest_gridsearch(model_class: type,
389397
as argument to the predict method of `model`.
390398
num_predictions:
391399
The number of train/prediction cycles performed in one iteration of expanding window mode.
400+
use_fitted_values
401+
If `True`, uses the comparison with the fitted values.
402+
Raises an error if `fitted_values` is not an attribute of `model_class`.
392403
metric:
393404
A function that takes two TimeSeries instances as inputs and returns a float error value.
394405
verbose:
@@ -399,18 +410,22 @@ def backtest_gridsearch(model_class: type,
399410
ForecastingModel
400411
An untrained 'model_class' instance with the best-performing hyperparameters from the given selection.
401412
"""
413+
raise_if_not((fcast_horizon_n is not None) + (val_series is not None) + use_fitted_values == 1,
414+
"Please pass exactly one of the arguments 'forecast_horizon_n', 'val_series' or 'use_fitted_values'.",
415+
logger)
402416

403-
if (val_series is not None):
417+
if use_fitted_values:
418+
model = model_class()
419+
raise_if_not(hasattr(model, "fitted_values"), "The model must have a fitted_values attribute"
420+
" to compare with the train TimeSeries", logger)
421+
elif val_series is not None:
404422
raise_if_not(train_series.width == val_series.width, "Training and validation series require the same"
405423
" number of components.", logger)
406424

407-
raise_if_not((fcast_horizon_n is None) ^ (val_series is None),
408-
"Please pass exactly one of the arguments 'forecast_horizon_n' or 'val_series'.", logger)
409-
410425
fit_kwargs, predict_kwargs = _create_parameter_dicts(model_class(), target_indices, component_index,
411426
use_full_output_length)
412427

413-
if val_series is None:
428+
if (val_series is None) and (not use_fitted_values):
414429
backtest_start_time = train_series.end_time() - (num_predictions + fcast_horizon_n) * train_series.freq()
415430
min_error = float('inf')
416431
best_param_combination = {}
@@ -423,7 +438,13 @@ def backtest_gridsearch(model_class: type,
423438
for param_combination in iterator:
424439
param_combination_dict = dict(list(zip(parameters.keys(), param_combination)))
425440
model = model_class(**param_combination_dict)
426-
if val_series is None: # expanding window mode
441+
if use_fitted_values:
442+
model.fit(train_series)
443+
# Takes too much time to create a TimeSeries
444+
# Overhead: 2-10 ms in average
445+
fitted_values = TimeSeries.from_times_and_values(train_series.time_index(), model.fitted_values)
446+
error = metric(fitted_values, train_series)
447+
elif val_series is None: # expanding window mode
427448
backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n,
428449
target_indices, component_index, use_full_output_length)
429450
error = metric(backtest_forecast, train_series)
@@ -499,7 +520,7 @@ def explore_models(train_series: TimeSeries,
499520
'trend': [None, 'poly', 'exp']
500521
}),
501522
(Theta, {
502-
'theta': np.delete(np.linspace(-10, 10, 51), 30)
523+
'theta': np.delete(np.linspace(-10, 10, 51), 25)
503524
}),
504525
(Prophet, {}),
505526
(AutoARIMA, {})

darts/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .exponential_smoothing import ExponentialSmoothing
1111
from .rnn_model import RNNModel
1212
from .tcn_model import TCNModel
13-
from .theta import Theta
13+
from .theta import Theta, FourTheta
1414
from .fft import FFT
1515

1616
# Regression

0 commit comments

Comments
 (0)