Refactor/fit args #161
Co-authored-by: Julien Herzen <[email protected]>
This looks good, I think it's a good move. The only drawback I can see is that it makes the specification of validation time series for fit() of Torch models a bit more complex. @pennfranc what's your opinion?
darts/models/forecasting_model.py
Outdated
covariate_series
The training time series on which to fit the model (can be multivariate or univariate).
target_series
The target values used ad dependent variables when training the model
The target can also be multivariate or univariate. I would unify the comments here.
darts/models/forecasting_model.py
Outdated
self.target_indices = target_indices
super().fit(series)
raise_if_not(len(covariate_series) == len(target_series), "covariate_series and target_series musth have same "
You should probably compare the time indexes here instead of the lengths.
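To illustrate the point, here is a minimal sketch of why a length check alone can pass for misaligned series. `ToySeries` and `have_same_time_index` are hypothetical stand-ins written for this example, not part of the darts API; the actual check would operate on the library's own time index.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List


@dataclass
class ToySeries:
    """Hypothetical stand-in for a time series: timestamps plus values."""
    times: List[datetime]
    values: List[float]


def have_same_time_index(a: ToySeries, b: ToySeries) -> bool:
    """Return True only if both series cover exactly the same timestamps."""
    return a.times == b.times


start = datetime(2020, 1, 1)
covariates = ToySeries([start + timedelta(days=i) for i in range(3)], [1.0, 2.0, 3.0])
# Same length, but shifted by one day: a length check alone would accept this pair.
shifted_target = ToySeries([start + timedelta(days=i + 1) for i in range(3)], [4.0, 5.0, 6.0])

same_length = len(covariates.times) == len(shifted_target.times)
same_index = have_same_time_index(covariates, shifted_target)
```

Here `same_length` is true while `same_index` is false, which is the case the length comparison would miss.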
Optionally, a validation time series, which will be used to compute the validation loss
throughout training and keep track of the best performing models.
Optionally, 2 validation time series (covariate and target), which will be used to compute the validation
loss throughout training and keep track of the best performing models.
verbose
Optionally, whether to print progress.
target_indices
to remove
target_indices: Optional[List[int]] = None) -> None:
covariate_series: TimeSeries,
target_series: Optional[TimeSeries] = None,
val_series: Optional[Tuple[TimeSeries]] = None,
How about splitting in two:
val_covariate_series: Optional[TimeSeries]
val_target_series: Optional[TimeSeries]
That would be more consistent with the first two arguments.
I did it like this at first, but having a Tuple makes it easier to check that the value provided is either None or of length 2, as opposed to having two separate args.
I would say that it's more important to have a good signature here, even if the check is slightly more complex behind the scenes.
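For what it's worth, the check with two separate arguments can stay fairly small. The following is only a sketch: `check_validation_args` is a hypothetical helper name, and the argument names are the ones proposed above, not necessarily what ends up in the final signature.

```python
from typing import Optional


def check_validation_args(val_covariate_series: Optional[object],
                          val_target_series: Optional[object]) -> bool:
    """Return True if validation data was supplied, False if it was omitted.

    Raises ValueError when only one of the two series is given.
    """
    # Exactly one of the two being None means the caller supplied half a pair.
    if (val_covariate_series is None) != (val_target_series is None):
        raise ValueError(
            "provide both val_covariate_series and val_target_series, or neither"
        )
    return val_covariate_series is not None
```

The single comparison of the two `is None` results covers both invalid cases, so the signature can stay explicit without much extra validation code.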
series
the training time series on which to fit the model
Implements behavior that should happen when calling the `fit` method of every forcasting model regardless of
wether they are univariate or multivariate.
whether (sorry)
No problem :) You can directly propose edits for such typos.
"""
raise_if_not(len(series) >= self.min_train_series_length,
As far as I can see this is currently not being tested for the univariate model.
Good catch, I moved this check into the parent class.
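As a rough illustration of the pattern (not the actual darts classes), the minimum-length check can live in the parent `fit` so that every subclass inherits it through `super().fit(...)`. `BaseForecaster`, `UnivariateForecaster`, and the threshold of 3 are assumptions made up for this sketch; only the `min_train_series_length` attribute name comes from the diff above.

```python
class BaseForecaster:
    """Hypothetical parent class; not the actual darts ForecastingModel."""

    # Assumed threshold for the sketch; real models would define their own.
    min_train_series_length = 3

    def fit(self, series):
        # Shared sanity check: runs for every subclass that calls super().fit().
        if len(series) < self.min_train_series_length:
            raise ValueError(
                f"series must have at least {self.min_train_series_length} points"
            )
        self.training_series = series


class UnivariateForecaster(BaseForecaster):
    """Hypothetical subclass that predicts the training mean."""

    def fit(self, series):
        super().fit(series)  # the inherited length check runs first
        self.mean = sum(series) / len(series)


model = UnivariateForecaster()
model.fit([1.0, 2.0, 3.0, 4.0])
```

With the check in one place, a single test against the parent behavior covers the univariate and multivariate models alike.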
series
the training time series on which to fit the model
Implements behavior that should happen when calling the `fit` method of every forcasting model regardless of
wether they are univariate or multivariate.
wether they are univariate or multivariate.
whether they are univariate or multivariate.
I didn't know you could do this!
The time series to be included in the dataset.
target_series
The time series used has target.
as?
Overall I really like this change, just a couple of things I'm not quite sure about:
Edit: OK, looking at your example above, I think I get why you did it this way: this provides the user with all possible combinations, and with an example it's pretty intuitive! I still think @hrzn I agree that it requires a bit more effort to list all the fit arguments now, but I think this is a price worth paying.
I renamed |
* add .DS_Store to .gitignore
* add proposal.md
* add draft version of backtest forcasting
* add backtest to model (simple refactoring)
* extract backtest sanity checks in a method
* extract building fit_kwargs and predict_kwargs in a method
* minor fix import comment and assertion
* refactor all backtest factoring tests
* update progress on proposal.md
* add coverage.sh
* fix permission on coverage.sh
* improve coverage sh script
* add coverage.xml to .igtignore
* improve doc on coverage.sh
* fix doc
* fix doc for real
* univariate fcast model only support univariate ts
* MultivariateFcasModel fits on the whole training ts
* refactor torch forcasting model to use covariate_series
* fix unused imports
* allow to specify only covaraite_series
* enforce covariate_series and target_series inputs for multivariate model
* adapt torch datasets to use covariate / target series
* adapt validation series provided as a Tuple
* fix typo
* adapt create_dataset on tcn model
* remove component index from fit function
* adapt tests to new syntax
* add proposal.md
* add draft version of backtest forcasting
* add backtest to model (simple refactoring)
* extract backtest sanity checks in a method
* extract building fit_kwargs and predict_kwargs in a method
* minor fix import comment and assertion
* refactor all backtest factoring tests
* update progress on proposal.md
* fix doc
* fix doc for real
* fix typos and remove diagram in backtest doc
* WIP add residuals
* add decorator for sanity checks
* clean forecasting_model
* add start multitype parameter support
* fix check on undefined param in sanity checks
* add comments
* fix(backtesting, tests): fixed bugs so that all forecasting backtest tests pass, corrected some typos
* feature(backtesting): changed handling of residuals (re-introduced own function instead of being by-product of backtest)
* fix(test_forecasting_model): deleted old file that was renamed due to type
* feat(backtesting): moved gridsearch to ForecastingModel, removed functions from backtesting module that have been moved to ForecastingModel class, adapted tests
* feat(backtesting): adapted docstring of gridsearch function
* fix(Theta): adapted FourTheta model to use new gridsearch function
* fix(forecasting_model, torch_forecasting_model): fixed docstrings
* feat(backtesting): moved backtest_regression to regression model class
* fix(forecasting_model): renamed covariate_series to training_series
* fix(forecasting_model): fixed residuals function
* fix(style): linter
* feat(backtesting): renamed backtest_gridsearch to gridsearch
* fix(tests): fixed residuals test case
* feat(backtesting): moved residuals plotting function to statistics module
* feat(backtesting): removed backtesting module
* fix(style): linter
* fix(style): linter
* fix(torch_forecasting_model): fixed check in predict function
* fix(forecasting_model): fixed backtest sanity check
* fix(torch_forecasting_model): removed unnecessary (and bug-causing) sanity check method
* feat(examples): refactored notebooks to support new function signatures
* fix(style): linter
* updated PROPOSAL.md
* feat(forecasting_model): improved documentation
* fix(torch_forecasting_model): removed redundant function
* style(torch_forecasting_model): linter
* fix(torch_forecasting_model): fixed docstring typo
* fix(torch_forecasting_model, tests): clean up old comments
* fix(statistics): improved docstrings
* fix(forecasting_model, regression_model): improved variable names, fixed documentation
* fix(tests): fixed old variable name in backtesting tests
* removed PROPOSAL.md
* feat(regression_model): added stride functionality to backtest method
* fix(forecasting_model, regression_model): improved documentation
* fix(forecasting_model): improved documentation
* fix(forecasting_model): improved start parameter documentation
* fix(forecasting_model, regression_model): cleaned up code, improved docstrings, added missing checks
* feat(forecasting_model): improved backtest docstring
* fix(forecasting_model, tests): improved backtest sanity checks, added corresponding test cases
* feat(backtesting): replaced 'num_predictions' parameter by 'start' parameter in 'ForecastingModel.gridsearch'
* fix(examples): updated notebooks

Co-authored-by: Guillaume Raille <[email protected]>
Co-authored-by: pennfranc <[email protected]>
Fixes #DARTS-164.
Summary
* Remove the `target_indices` parameter from all fit methods on multivariate models
* Remove the `component_index` parameter from all fit methods on univariate models
* Add `covariate_series` and `target_series` parameters to replace the previous syntax

Note: the backtesting tests are not passing, but they will soon be refactored, so I didn't invest time in that.
New vs Old API: