-
Notifications
You must be signed in to change notification settings - Fork 918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/lagged features names #1679
Changes from 7 commits
e7efa59
cc77936
dbc942b
10060b7
69a3819
b798c50
f48ea8a
9fa93cb
89dbb47
47e4214
b496881
8d2a03e
1b624f5
76ee1c8
38c10a0
1325983
dd3798b
557dbfe
05428e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -35,7 +35,10 @@ | |||||
from darts.logging import get_logger, raise_if, raise_if_not, raise_log | ||||||
from darts.models.forecasting.forecasting_model import GlobalForecastingModel | ||||||
from darts.timeseries import TimeSeries | ||||||
from darts.utils.data.tabularization import create_lagged_training_data | ||||||
from darts.utils.data.tabularization import ( | ||||||
create_lagged_components_names, | ||||||
create_lagged_training_data, | ||||||
) | ||||||
from darts.utils.multioutput import MultiOutputRegressor | ||||||
from darts.utils.utils import ( | ||||||
_check_quantiles, | ||||||
|
@@ -358,6 +361,32 @@ def _create_lagged_data( | |||||
|
||||||
return training_samples, training_labels | ||||||
|
||||||
def _create_lagged_components_name( | ||||||
self, target_series, past_covariates, future_covariates | ||||||
): | ||||||
lags = self.lags.get("target") | ||||||
lags_past_covariates = self.lags.get("past") | ||||||
lags_future_covariates = self.lags.get("future") | ||||||
|
||||||
features_cols_name, labels_cols_name = create_lagged_components_names( | ||||||
target_series=target_series, | ||||||
past_covariates=past_covariates, | ||||||
future_covariates=future_covariates, | ||||||
lags=lags, | ||||||
lags_past_covariates=lags_past_covariates, | ||||||
lags_future_covariates=lags_future_covariates, | ||||||
output_chunk_length=self.output_chunk_length, | ||||||
concatenate=False, | ||||||
) | ||||||
|
||||||
# adding the static covariates on the right of each features_cols_name | ||||||
features_cols_name = self._add_static_covariates_name( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can move this into the helper function |
||||||
features_cols_name, | ||||||
target_series, | ||||||
) | ||||||
|
||||||
return features_cols_name, labels_cols_name | ||||||
|
||||||
def _add_static_covariates( | ||||||
self, | ||||||
features: Union[np.array, Sequence[np.array]], | ||||||
|
@@ -445,6 +474,41 @@ def _add_static_covariates( | |||||
features = features[0] | ||||||
return features | ||||||
|
||||||
def _add_static_covariates_name( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be part of |
||||||
self, | ||||||
features_cols_name: List[List[str]], | ||||||
target_series: Union[TimeSeries, Sequence[TimeSeries]], | ||||||
) -> Union[np.array, Sequence[np.array]]: | ||||||
""" | ||||||
Add static covariates names to the features name for RegressionModels. | ||||||
Accounts for series with potentially different static covariates to accomodate for the maximum | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't the number of static covariates guaranteed to be identical? The models should throw an error when using series with different static covariate numbers, no? |
||||||
number of available static_covariates in any of the given series in the sequence. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
features_cols_name | ||||||
The name of the features of the numpy array(s) to which the static covariates will be added, generated with | ||||||
`create_lagged_components_names()` | ||||||
target_series | ||||||
The target series from which to read the static covariates. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
features_cols_name | ||||||
The features' name list with appended static covariates names on the right. | ||||||
""" | ||||||
target_series = series2seq(target_series) | ||||||
|
||||||
# collect static covariates info, preserve the order | ||||||
static_covs_names = [] | ||||||
for ts in target_series: | ||||||
if ts.has_static_covariates: | ||||||
for static_cov_name in ts.static_covariates.keys(): | ||||||
if static_cov_name not in static_covs_names: | ||||||
static_covs_names.append(static_cov_name) | ||||||
|
||||||
return features_cols_name + static_covs_names | ||||||
|
||||||
def _fit_model( | ||||||
self, | ||||||
target_series, | ||||||
|
@@ -470,6 +534,14 @@ def _fit_model( | |||||
training_labels = training_labels.ravel() | ||||||
self.model.fit(training_samples, training_labels, **kwargs) | ||||||
|
||||||
# generate and store the lagged components names (for feature importance analysis) | ||||||
lagged_features_names, _ = self._create_lagged_components_name( | ||||||
target_series=target_series, | ||||||
past_covariates=past_covariates, | ||||||
future_covariates=future_covariates, | ||||||
) | ||||||
self.model.lagged_features_name_ = lagged_features_names | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use a naming convention Also shouldn't we store this in the Darts model, rather than the actual one?
Suggested change
|
||||||
|
||||||
def fit( | ||||||
self, | ||||||
series: Union[TimeSeries, Sequence[TimeSeries]], | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this method and put everything into the helper function
create_lagged_component_names