Skip to content

Add DeseasonalityTransform #1307

Merged
merged 10 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
-
- `DeseasonalityTransform` ([#1307](https://github.com/tinkoff-ai/etna/pull/1307))
-
-
-
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from etna.transforms.decomposition import ChangePointsLevelTransform
from etna.transforms.decomposition import ChangePointsSegmentationTransform
from etna.transforms.decomposition import ChangePointsTrendTransform
from etna.transforms.decomposition import DeseasonalityTransform
from etna.transforms.decomposition import IrreversibleChangePointsTransform
from etna.transforms.decomposition import LinearTrendTransform
from etna.transforms.decomposition import ReversibleChangePointsTransform
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/decomposition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from etna.transforms.decomposition.change_points_based.change_points_models import RupturesChangePointsModel
from etna.transforms.decomposition.change_points_based.level import ChangePointsLevelTransform
from etna.transforms.decomposition.change_points_based.trend import TrendTransform
from etna.transforms.decomposition.deseasonal import DeseasonalityTransform
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform
from etna.transforms.decomposition.stl import STLTransform
211 changes: 211 additions & 0 deletions etna/transforms/decomposition/deseasonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from enum import Enum
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose

from etna.distributions import BaseDistribution
from etna.distributions import CategoricalDistribution
from etna.models.utils import determine_freq
from etna.models.utils import determine_num_steps
from etna.transforms.base import OneSegmentTransform
from etna.transforms.base import ReversiblePerSegmentWrapper
from etna.transforms.utils import match_target_quantiles


class DeseasonalModel(str, Enum):
    """Supported kinds of seasonal decomposition model.

    Members are plain strings, so they compare equal to their string values
    (``DeseasonalModel.additive == "additive"``).
    """

    additive = "additive"
    multiplicative = "multiplicative"

    @classmethod
    def _missing_(cls, value):
        """Reject a value that matches no member, listing the allowed ones.

        Raises
        ------
        NotImplementedError:
            always; the message names the offending value and the valid options
        """
        allowed = ", ".join(repr(member.value) for member in cls)
        message = f"{value} is not a valid {cls.__name__}. Only {allowed} types allowed."
        raise NotImplementedError(message)


class _OneSegmentDeseasonalityTransform(OneSegmentTransform):
    """Estimate one period of seasonality for a single column and remove/restore it."""

    def __init__(self, in_column: str, period: int, model: str = DeseasonalModel.additive):
        """
        Init _OneSegmentDeseasonalityTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            'additive' (default) or 'multiplicative'

        Raises
        ------
        NotImplementedError:
            if ``model`` is not a valid :py:class:`DeseasonalModel` value
        """
        self.in_column = in_column
        self.period = period
        # Value checking is delegated to the enum.
        self.model = DeseasonalModel(model)
        # One period of the seasonal component, indexed by the first ``period`` fitted timestamps.
        self._seasonal: Optional[pd.Series] = None

    def _roll_seasonal(self, x: pd.Series) -> np.ndarray:
        """
        Roll out seasonal component by x's time index.

        Parameters
        ----------
        x:
            processed column

        Returns
        -------
        result:
            seasonal component

        Raises
        ------
        ValueError:
            if the transform is not fitted
        """
        if self._seasonal is None:
            raise ValueError("Transform is not fitted! Fit the Transform before calling.")
        freq = determine_freq(x.index)
        # The stored period is rotated so that its first element is the seasonal value
        # for ``x.index[0]``; the rotation direction depends on which series starts first.
        if self._seasonal.index[0] <= x.index[0]:
            shift = -determine_num_steps(self._seasonal.index[0], x.index[0], freq) % self.period
        else:
            shift = determine_num_steps(x.index[0], self._seasonal.index[0], freq) % self.period
        # ``np.resize`` repeats the rotated period as many times as needed to cover ``x``.
        return np.resize(np.roll(self._seasonal, shift=shift), x.shape[0])

    def fit(self, df: pd.DataFrame) -> "_OneSegmentDeseasonalityTransform":
        """
        Perform seasonal decomposition.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result:
            instance after processing

        Raises
        ------
        ValueError:
            if the input column has NaNs between its first and last valid values
        """
        # Leading and trailing NaNs are trimmed; NaNs inside the remaining range are an error.
        df = df.loc[df[self.in_column].first_valid_index() : df[self.in_column].last_valid_index()]
        if df[self.in_column].isnull().values.any():
            raise ValueError("The input column contains NaNs in the middle of the series! Try to use the imputer.")
        # ``period`` must be passed explicitly: otherwise the decomposition infers it from the
        # index frequency, which may differ from the ``self.period`` used for slicing and rolling.
        decomposition = seasonal_decompose(
            x=df[self.in_column],
            model=self.model,
            period=self.period,
            filt=None,
            two_sided=False,
            extrapolate_trend=0,
        )
        # The seasonal component is periodic, so only its first period is kept.
        self._seasonal = decomposition.seasonal.iloc[: self.period]
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Subtract seasonal component.

        The passed dataframe is modified in place and returned.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result:
            Dataframe with the seasonal component removed from the input column

        Raises
        ------
        ValueError:
            if the transform is not fitted
        ValueError:
            if the model is multiplicative and the input column has zero or negative values
        """
        result = df
        seasonal = self._roll_seasonal(result[self.in_column])
        if self.model == "additive":
            result[self.in_column] -= seasonal
        else:
            if np.any(result[self.in_column] <= 0):
                raise ValueError(
                    "The input column contains zero or negative values, "
                    "but multiplicative seasonality can not work with such values."
                )
            result[self.in_column] /= seasonal
        return result

    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add seasonal component.

        The passed dataframe is modified in place and returned. When the processed column is
        ``"target"``, the columns returned by ``match_target_quantiles`` are restored too.

        Parameters
        ----------
        df:
            Features dataframe with time

        Returns
        -------
        result:
            Dataframe with the seasonal component added back to the input column

        Raises
        ------
        ValueError:
            if the transform is not fitted
        ValueError:
            if the model is multiplicative and a processed column has zero or negative values
        """
        result = df
        seasonal = self._roll_seasonal(result[self.in_column])
        if self.model == "additive":
            result[self.in_column] += seasonal
        else:
            if np.any(result[self.in_column] <= 0):
                raise ValueError(
                    "The input column contains zero or negative values, "
                    "but multiplicative seasonality can not work with such values."
                )
            result[self.in_column] *= seasonal
        if self.in_column == "target":
            quantiles = match_target_quantiles(set(result.columns))
            for quantile_column_nm in quantiles:
                if self.model == "additive":
                    result.loc[:, quantile_column_nm] += seasonal
                else:
                    # Column selection needs ``.loc[:, name]``; ``.loc[name]`` would look up a row label.
                    if np.any(result.loc[:, quantile_column_nm] <= 0):
                        raise ValueError(
                            f"The {quantile_column_nm} column contains zero or negative values, "
                            "but multiplicative seasonality can not work with such values."
                        )
                    result.loc[:, quantile_column_nm] *= seasonal
        return result


class DeseasonalityTransform(ReversiblePerSegmentWrapper):
    """Transform that uses :py:func:`statsmodels.tsa.seasonal.seasonal_decompose` to subtract season from the data.

    Warning
    -------
    This transform can suffer from look-ahead bias. For transforming data at some timestamp
    it uses information from the whole train part.
    """

    def __init__(self, in_column: str, period: int, model: str = "additive"):
        """
        Init DeseasonalityTransform.

        Parameters
        ----------
        in_column:
            name of processed column
        period:
            size of seasonality
        model:
            'additive' (default) or 'multiplicative'
        """
        self.in_column = in_column
        self.period = period
        self.model = model
        # The per-segment worker receives the same arguments that are stored on the wrapper.
        segment_transform = _OneSegmentDeseasonalityTransform(
            in_column=self.in_column,
            period=self.period,
            model=self.model,
        )
        super().__init__(transform=segment_transform, required_features=[self.in_column])

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform (always empty here)."""
        created_regressors: List[str] = []
        return created_regressors

    def params_to_tune(self) -> Dict[str, BaseDistribution]:
        """Get default grid for tuning hyperparameters.

        This grid tunes parameters: ``model``. Other parameters are expected to be set by the user.

        Returns
        -------
        :
            Grid to tune.
        """
        model_choices = CategoricalDistribution(["additive", "multiplicative"])
        return {"model": model_choices}
Loading