-
Notifications
You must be signed in to change notification settings - Fork 82
Add DeseasonalityTransform
#1307
Merged
Merged
Changes from 7 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a3fc592
Add `DeseasonalityTransform`
ostreech1997 d810bdb
Add tests for `DeseasonalityTransform`
ostreech1997 06b2e82
Merge branch 'master' into issue-680
ostreech1997 3c1c99d
Reformat code
ostreech1997 090f2e7
Update CHANGELOG.md
ostreech1997 4727541
Code refactoring
ostreech1997 b567a3c
Merge with master
ostreech1997 cfa5ab0
Refactoring code v.2
ostreech1997 1a6270e
Merge with master
ostreech1997 bbede11
Code refactoring v.3
ostreech1997 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from enum import Enum | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from statsmodels.tsa.seasonal import seasonal_decompose | ||
|
||
from etna.distributions import BaseDistribution | ||
from etna.distributions import CategoricalDistribution | ||
from etna.models.utils import determine_freq | ||
from etna.models.utils import determine_num_steps | ||
from etna.transforms.base import OneSegmentTransform | ||
from etna.transforms.base import ReversiblePerSegmentWrapper | ||
from etna.transforms.utils import match_target_quantiles | ||
|
||
|
||
class DeseasonalModel(str, Enum): | ||
"""Enum for different types of deseasonality model.""" | ||
|
||
additive = "additive" | ||
multiplicative = "multiplicative" | ||
|
||
@classmethod | ||
def _missing_(cls, value): | ||
raise NotImplementedError( | ||
f"{value} is not a valid {cls.__name__}. Only {', '.join([repr(m.value) for m in cls])} types allowed." | ||
) | ||
|
||
|
||
class _OneSegmentDeseasonalityTransform(OneSegmentTransform): | ||
def __init__(self, in_column: str, period: int, model: str = DeseasonalModel.additive): | ||
""" | ||
Init _OneSegmentDeseasonalityTransform. | ||
|
||
Parameters | ||
---------- | ||
in_column: | ||
name of processed column | ||
period: | ||
size of seasonality | ||
model: | ||
'additive' (default) or 'multiplicative' | ||
""" | ||
self.in_column = in_column | ||
self.period = period | ||
self.model = DeseasonalModel(model) | ||
self._seasonal: Optional[pd.Series] = None | ||
|
||
def _roll_seasonal(self, x: pd.Series) -> np.ndarray: | ||
""" | ||
Roll out seasonal component by x's time index. | ||
|
||
Parameters | ||
---------- | ||
x: | ||
processed column | ||
|
||
Returns | ||
------- | ||
result: | ||
seasonal component | ||
""" | ||
if self._seasonal is None: | ||
raise ValueError("Transform is not fitted! Fit the Transform before calling.") | ||
freq = determine_freq(x.index) | ||
if self._seasonal.index[0] <= x.index[0]: | ||
shift = -determine_num_steps(self._seasonal.index[0], x.index[0], freq) % self.period | ||
else: | ||
shift = determine_num_steps(x.index[0], self._seasonal.index[0], freq) % self.period | ||
return np.resize(np.roll(self._seasonal, shift=shift), x.shape[0]) | ||
|
||
def fit(self, df: pd.DataFrame) -> "_OneSegmentDeseasonalityTransform": | ||
""" | ||
Perform seasonal decomposition. | ||
|
||
Parameters | ||
---------- | ||
df: | ||
Features dataframe with time | ||
|
||
Returns | ||
------- | ||
result: | ||
instance after processing | ||
""" | ||
df = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index()] | ||
if df[self.in_column].isnull().values.any(): | ||
raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.") | ||
self._seasonal = seasonal_decompose( | ||
x=df[self.in_column], model=self.model, filt=None, two_sided=False, extrapolate_trend=0 | ||
).seasonal[: self.period] | ||
return self | ||
|
||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Subtract seasonal component. | ||
|
||
Parameters | ||
---------- | ||
df: | ||
Features dataframe with time | ||
|
||
Returns | ||
------- | ||
result: pd.DataFrame | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Dataframe with extracted features | ||
""" | ||
result = df | ||
seasonal = self._roll_seasonal(result[self.in_column]) | ||
if self.model == "additive": | ||
result[self.in_column] -= seasonal | ||
else: | ||
if np.any(result[self.in_column] <= 0): | ||
raise ValueError( | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"The input column contains zero or negative values," | ||
"but multiplicative seasonality can not work with such values." | ||
) | ||
result[self.in_column] /= seasonal | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return result | ||
|
||
def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
""" | ||
Add seasonal component. | ||
|
||
Parameters | ||
---------- | ||
df: | ||
Features dataframe with time | ||
|
||
Returns | ||
------- | ||
result: pd.DataFrame | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Dataframe with extracted features | ||
""" | ||
result = df | ||
seasonal = self._roll_seasonal(result[self.in_column]) | ||
if self.model == "additive": | ||
result[self.in_column] += seasonal | ||
else: | ||
if np.any(result[self.in_column] <= 0): | ||
raise ValueError( | ||
"The input column contains zero or negative values," | ||
"but multiplicative seasonality can not work with such values." | ||
) | ||
result[self.in_column] *= seasonal | ||
if self.in_column == "target": | ||
quantiles = match_target_quantiles(set(result.columns)) | ||
for quantile_column_nm in quantiles: | ||
if self.model == "additive": | ||
result.loc[:, quantile_column_nm] += seasonal | ||
else: | ||
if np.any(result.loc[quantile_column_nm] <= 0): | ||
raise ValueError( | ||
f"The {quantile_column_nm} column contains zero or negative values," | ||
"but multiplicative seasonality can not work with such values." | ||
) | ||
result.loc[:, quantile_column_nm] *= seasonal | ||
return result | ||
|
||
|
||
class DeseasonalityTransform(ReversiblePerSegmentWrapper): | ||
"""Transform that uses :py:class:`statsmodels.tsa.seasonal.seasonal_decompose` to subtract season from the data. | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Warning | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
------- | ||
This transform can suffer from look-ahead bias. For transforming data at some timestamp | ||
it uses information from the whole train part. | ||
""" | ||
|
||
def __init__(self, in_column: str, period: int, model: str = "additive"): | ||
""" | ||
Init DeseasonalityTransform. | ||
|
||
Parameters | ||
---------- | ||
in_column: | ||
name of processed column | ||
period: | ||
size of seasonality | ||
model: | ||
'additive' (default) or 'multiplicative' | ||
Mr-Geekman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
self.in_column = in_column | ||
self.period = period | ||
self.model = model | ||
super().__init__( | ||
transform=_OneSegmentDeseasonalityTransform( | ||
in_column=self.in_column, | ||
period=self.period, | ||
model=self.model, | ||
), | ||
required_features=[self.in_column], | ||
) | ||
|
||
def get_regressors_info(self) -> List[str]: | ||
"""Return the list with regressors created by the transform.""" | ||
return [] | ||
|
||
def params_to_tune(self) -> Dict[str, BaseDistribution]: | ||
"""Get default grid for tuning hyperparameters. | ||
|
||
This grid tunes parameters: ``model``. Other parameters are expected to be set by the user. | ||
|
||
Returns | ||
------- | ||
: | ||
Grid to tune. | ||
""" | ||
return {"model": CategoricalDistribution(["additive", "multiplicative"])} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do this by using
enum
. Value checking will be made there.