
Feature/four theta #123

Merged
merged 49 commits into from
Jul 20, 2020
01e1837
feat(4Theta): naive implementation of 4Theta model
Droxef Jul 3, 2020
f6831be
fix(theta): avoid NaN values in theta, and unnecessary season test
Droxef Jul 3, 2020
53ceeab
feat(gridsearch): add possibility to compare with model.fitted_values
Droxef Jul 3, 2020
7ae0803
feat(4theta): add a method to auto select best model
Droxef Jul 3, 2020
4001009
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 3, 2020
264ddf3
refactor(4Theta): Specify univariate model
Droxef Jul 3, 2020
e5bc790
style(4Theta): Fix linter
Droxef Jul 3, 2020
e951554
style(4Theta): fix docstring
Droxef Jul 3, 2020
f596476
style(4Theta): Fix docstring
Droxef Jul 3, 2020
b712c7d
style(4Theta): Change link
Droxef Jul 3, 2020
dee5b0d
style(4Theta): Correct docstring
Droxef Jul 3, 2020
ba64351
style(4theta): correct ticks in docstring
Droxef Jul 6, 2020
36d99fe
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 7, 2020
1dc5bb3
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 7, 2020
4d24f75
refactor(4Theta): change different modes verification and add Enum
Droxef Jul 7, 2020
240cea4
refactor(theta): replace all string modes by Enum
Droxef Jul 8, 2020
7d5dffa
test(backtesting): Add a test to verify if fitted_values exist
Droxef Jul 8, 2020
b19cf97
Fix(Theta): Correct all Enums
Droxef Jul 8, 2020
bbd58cc
fix(Theta): compare with enum members value instead. Correct some min…
Droxef Jul 8, 2020
71aa8d3
fix(4theta): move the creation of enums in init file
Droxef Jul 8, 2020
8286062
test(4theta): Add 4Theta to autoregressive test. Move Enums to top in…
Droxef Jul 8, 2020
af69e63
test(4theta): Add 4Theta specific test
Droxef Jul 8, 2020
97f6956
style(backtesting): fix lint
Droxef Jul 8, 2020
97cff2a
test(4theta): Add another exception to test
Droxef Jul 9, 2020
f6df487
ref(4Theta): mode.fitted_values is now a TimeSeries to be consistent
Droxef Jul 13, 2020
2bad4f7
style(Theta): rename mode to season_mode to be consistent w/ FourTheta
Droxef Jul 13, 2020
49fafda
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 13, 2020
fe1b320
docs(thetas): correct errors in the different docs
Droxef Jul 13, 2020
fc04e68
refactor(4Theta): Correct
Droxef Jul 13, 2020
cae49f6
refactor(backtesting): add a 'use_fitted_values' parameter
Droxef Jul 13, 2020
27b9143
fix(4theta): correct select_best_model
Droxef Jul 13, 2020
2fa7f4f
test(4Theta): add a test for zero mean and correct others
Droxef Jul 13, 2020
5ed92bd
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 14, 2020
a2b64e4
style(backtesting): linter formatting
Droxef Jul 14, 2020
98d1b7a
refactor(4Theta): change Enums names, correct theta and backtesting docs
Droxef Jul 16, 2020
f405e8b
refactor(4theta): move creation of fitted_values timeseries to backte…
Droxef Jul 16, 2020
9cad7cf
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 16, 2020
1e6be8e
refactor(statistics): include Enums in extract and remove functions
Droxef Jul 16, 2020
2f9bc5b
refactor(4Theta): check earlier if univariate
Droxef Jul 16, 2020
a405e47
test(4Theta): correct backtesting and test best_model
Droxef Jul 16, 2020
25f3de2
test(4Theta): add new modes in test models
Droxef Jul 16, 2020
e15f52b
docs(4theta): Add a disclaimer for 4theta performance
Droxef Jul 16, 2020
61c0f17
refactor(Theta): change theta to have the same behavior as FourTheta
Droxef Jul 16, 2020
e63e215
examples(darts-intro): modify notebook to give the same results
Droxef Jul 16, 2020
be1427d
style(4Theta): correct deprecation warning for logger.warn
Droxef Jul 16, 2020
34a8c41
Merge branch 'develop' into feat/FourTheta
hrzn Jul 17, 2020
d8e890f
style(4theta): move comment to backtesting
Droxef Jul 20, 2020
e02d65d
Merge branch 'feat/FourTheta' of https://github.com/unit8co/darts int…
Droxef Jul 20, 2020
8e92eba
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 20, 2020
20 changes: 20 additions & 0 deletions darts/__init__.py
@@ -5,4 +5,24 @@

 from .timeseries import TimeSeries

+# Enums
+from enum import Enum
+
+
+class SeasonalityMode(Enum):
+    MULTIPLICATIVE = 'multiplicative'
+    ADDITIVE = 'additive'
+    NONE = None
+
+
+class TrendMode(Enum):
+    LINEAR = 'linear'
+    EXPONENTIAL = 'exponential'
+
+
+class ModelMode(Enum):
+    MULTIPLICATIVE = 'multiplicative'
+    ADDITIVE = 'additive'
+
+
 __version__ = 'dev'
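The enums above replace string modes, and a few standard-library behaviours of `Enum` are worth keeping in mind when comparing them. The sketch below copies two of the classes from the diff so that it runs on its own. It is an illustration of how `enum.Enum` behaves, not code from the PR.

```python
from enum import Enum


class SeasonalityMode(Enum):
    MULTIPLICATIVE = 'multiplicative'
    ADDITIVE = 'additive'
    NONE = None


class ModelMode(Enum):
    MULTIPLICATIVE = 'multiplicative'
    ADDITIVE = 'additive'


# Lookup by value returns the member, including for the None value.
additive = SeasonalityMode('additive')
no_season = SeasonalityMode(None)

# Members of different Enum classes never compare equal,
# even when their underlying values are identical.
cross_class_equal = SeasonalityMode.ADDITIVE == ModelMode.ADDITIVE
values_equal = SeasonalityMode.ADDITIVE.value == ModelMode.ADDITIVE.value

# Plain Enum members are always truthy, so `if season_mode:` does not
# detect SeasonalityMode.NONE; compare with `is` instead.
none_is_truthy = bool(SeasonalityMode.NONE)

# A value that is not in the enum raises ValueError.
try:
    SeasonalityMode('seasonal')
    rejected = False
except ValueError:
    rejected = True
```

Because `SeasonalityMode.NONE` is truthy, a check such as `season_mode is SeasonalityMode.NONE` is the safer way to test for "no seasonality".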
43 changes: 32 additions & 11 deletions darts/backtesting/backtesting.py
@@ -341,12 +341,13 @@ def backtest_gridsearch(model_class: type,
                         use_full_output_length: bool = True,
                         val_series: Optional[TimeSeries] = None,
                         num_predictions: int = 10,
+                        use_fitted_values: bool = False,
                         metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
                         verbose=False):
     """ A function for finding the best hyperparameters.

-    This function has 2 modes of operation: Expanding window mode and split mode.
-    Both modes of operation evaluate every possible combination of hyperparameter values
+    This function has 3 modes of operation: Expanding window mode, split mode and comparison with fitted values.
+    The three modes of operation evaluate every possible combination of hyperparameter values
     provided in the `parameters` dictionary by instantiating the `model_class` subclass
     of ForecastingModel with each combination, and returning the best-performing model with regards
     to the `metric` function. The `metric` function is expected to return an error value,
@@ -364,17 +365,24 @@ def backtest_gridsearch(model_class: type,
     For every hyperparameter combination, the model is trained on `train_series` and
     evaluated on `val_series`.

+    Comparison with fitted values (activated when `use_fitted_values` is passed):
+    For every hyperparameter combination, the model is trained on `train_series` and evaluated on the resulting
+    fitted values.
+    Not all models have fitted values, and this method raises an error if `model.fitted_values` doesn't exist.
+    The fitted values are the result of the fit of the model on the training series. Comparing with the fitted values
+    can be a quick way to assess the model, but one cannot see if the model overfits or underfits.
+

     Parameters
     ----------
-    model
+    model_class
[Review comment, Contributor] I let this slip through, thanks!
         The ForecastingModel subclass to be tuned for 'series'.
     parameters
         A dictionary containing as keys hyperparameter names, and as values lists of values for the
         respective hyperparameter.
     train_series
         The univariate TimeSeries instance used for training (and also validation in split mode).
-    test_series
+    val_series
         The univariate TimeSeries instance used for validation in split mode.
     fcast_horizon_n
         The integer value of the forecasting horizon used in expanding window mode.
@@ -389,6 +397,9 @@ def backtest_gridsearch(model_class: type,
         as argument to the predict method of `model`.
     num_predictions:
         The number of train/prediction cycles performed in one iteration of expanding window mode.
+    use_fitted_values
+        If `True`, uses the comparison with the fitted values.
+        Raises an error if `fitted_values` is not an attribute of `model_class`.
     metric:
         A function that takes two TimeSeries instances as inputs and returns a float error value.
     verbose:
@@ -399,18 +410,22 @@ def backtest_gridsearch(model_class: type,
     ForecastingModel
         An untrained 'model_class' instance with the best-performing hyperparameters from the given selection.
     """
+    raise_if_not((fcast_horizon_n is not None) + (val_series is not None) + use_fitted_values == 1,
+                 "Please pass exactly one of the arguments 'forecast_horizon_n', 'val_series' or 'use_fitted_values'.",
+                 logger)

-    if (val_series is not None):
+    if use_fitted_values:
+        model = model_class()
+        raise_if_not(hasattr(model, "fitted_values"), "The model must have a fitted_values attribute"
+                     " to compare with the train TimeSeries", logger)
+    elif val_series is not None:
         raise_if_not(train_series.width == val_series.width, "Training and validation series require the same"
                      " number of components.", logger)

-    raise_if_not((fcast_horizon_n is None) ^ (val_series is None),
-                 "Please pass exactly one of the arguments 'forecast_horizon_n' or 'val_series'.", logger)
-
     fit_kwargs, predict_kwargs = _create_parameter_dicts(model_class(), target_indices, component_index,
                                                          use_full_output_length)

-    if val_series is None:
+    if (val_series is None) and (not use_fitted_values):
         backtest_start_time = train_series.end_time() - (num_predictions + fcast_horizon_n) * train_series.freq()
     min_error = float('inf')
     best_param_combination = {}
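The hunk above swaps the two-way XOR check for a sum of booleans compared with 1. The following minimal sketch shows why the sum generalizes to three options while a chained XOR does not. The function names are hypothetical, and a plain list stands in for a TimeSeries.

```python
from typing import Optional, Sequence


def exactly_one_mode(fcast_horizon_n: Optional[int],
                     val_series: Optional[Sequence[float]],
                     use_fitted_values: bool) -> bool:
    # bool is a subclass of int (True == 1), so the sum counts the selected modes.
    # `+` binds tighter than `==`, so this reads as (a + b + c) == 1.
    return (fcast_horizon_n is not None) + (val_series is not None) + use_fitted_values == 1


def chained_xor(fcast_horizon_n: Optional[int],
                val_series: Optional[Sequence[float]],
                use_fitted_values: bool) -> bool:
    # Chained XOR is a parity check: it is also True when all three are selected.
    return (fcast_horizon_n is not None) ^ (val_series is not None) ^ use_fitted_values


only_horizon = exactly_one_mode(3, None, False)
nothing_selected = exactly_one_mode(None, None, False)
all_selected_sum = exactly_one_mode(3, [1.0, 2.0], True)
all_selected_xor = chained_xor(3, [1.0, 2.0], True)
```

With all three options set, the sum-based check rejects the call while the XOR chain would accept it.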
@@ -423,7 +438,13 @@ def backtest_gridsearch(model_class: type,
     for param_combination in iterator:
         param_combination_dict = dict(list(zip(parameters.keys(), param_combination)))
         model = model_class(**param_combination_dict)
-        if val_series is None:  # expanding window mode
+        if use_fitted_values:
+            model.fit(train_series)
+            # Takes too much time to create a TimeSeries
+            # Overhead: 2-10 ms in average
+            fitted_values = TimeSeries.from_times_and_values(train_series.time_index(), model.fitted_values)
+            error = metric(fitted_values, train_series)
+        elif val_series is None:  # expanding window mode
             backtest_forecast = backtest_forecasting(train_series, model, backtest_start_time, fcast_horizon_n,
                                                      target_indices, component_index, use_full_output_length)
             error = metric(backtest_forecast, train_series)
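The fitted-values branch above can be illustrated outside darts with a small, dependency-free sketch. `ToySmoother`, `mae` and `gridsearch_on_fitted_values` are hypothetical names chosen for this example. Plain lists stand in for TimeSeries objects, and the toy model is only a stand-in for any model that exposes `fitted_values` after `fit`.

```python
from itertools import product


class ToySmoother:
    """Hypothetical model: simple exponential smoothing exposing `fitted_values`."""

    def __init__(self, alpha):
        self.alpha = alpha
        self.fitted_values = None

    def fit(self, series):
        level = series[0]
        fitted = []
        for y in series:
            fitted.append(level)  # one-step-ahead in-sample estimate
            level = self.alpha * y + (1 - self.alpha) * level
        self.fitted_values = fitted


def mae(predicted, actual):
    # Mean absolute error between two equally long sequences.
    return sum(abs(p - a) for p, a in zip(predicted, actual)) / len(actual)


def gridsearch_on_fitted_values(model_class, parameters, train_series, metric=mae):
    min_error, best_params = float('inf'), {}
    for combination in product(*parameters.values()):
        params = dict(zip(parameters.keys(), combination))
        model = model_class(**params)
        model.fit(train_series)
        # In-sample comparison: no held-out data is involved.
        error = metric(model.fitted_values, train_series)
        if error < min_error:
            min_error, best_params = error, params
    # Mirror the docstring: return an untrained instance with the best parameters.
    return model_class(**best_params), min_error


train = [10.0, 12.0, 11.0, 13.0, 12.0, 14.0]
best_model, best_error = gridsearch_on_fitted_values(
    ToySmoother, {'alpha': [0.1, 0.5, 0.9]}, train)
```

Since the error is measured on the training data itself, a low value here says nothing about out-of-sample accuracy, which is the caveat the docstring raises about overfitting.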
@@ -499,7 +520,7 @@ def explore_models(train_series: TimeSeries,
             'trend': [None, 'poly', 'exp']
         }),
         (Theta, {
-            'theta': np.delete(np.linspace(-10, 10, 51), 30)
+            'theta': np.delete(np.linspace(-10, 10, 51), 25)
         }),
         (Prophet, {}),
         (AutoARIMA, {})
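The index change from 30 to 25 is easier to read once the grid values are spelled out. The short check below shows which value each index removes from `np.linspace(-10, 10, 51)`. The diff does not state the reason for the change. One plausible reading, given the commit "refactor(Theta): change theta to have the same behavior as FourTheta", is that the excluded theta moved from 2 to 0, but that is an inference rather than something confirmed here.

```python
import numpy as np

# 51 evenly spaced points from -10 to 10, i.e. a spacing of 0.4.
grid = np.linspace(-10, 10, 51)

value_at_25 = grid[25]  # the value the new code removes
value_at_30 = grid[30]  # the value the old code removed

# After deleting index 25 the grid has 50 candidates left.
thetas = np.delete(grid, 25)
contains_zero = bool(np.any(np.isclose(thetas, 0.0)))
contains_two = bool(np.any(np.isclose(thetas, 2.0)))
```

Index 25 is the midpoint of the grid (0.0) and index 30 corresponds to 2.0, so the new grid keeps 2.0 as a candidate and drops 0.0.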
2 changes: 1 addition & 1 deletion darts/models/__init__.py
@@ -10,7 +10,7 @@
 from .exponential_smoothing import ExponentialSmoothing
 from .rnn_model import RNNModel
 from .tcn_model import TCNModel
-from .theta import Theta
+from .theta import Theta, FourTheta
 from .fft import FFT

 # Regression