From e5ce2802c0c9097a4fc96fe159fc24771637d474 Mon Sep 17 00:00:00 2001 From: alex-hse-repository <55380696+alex-hse-repository@users.noreply.github.com> Date: Fri, 27 Jan 2023 14:03:46 +0300 Subject: [PATCH 01/13] Add hierarchical pipeline (#1083) --- CHANGELOG.md | 12 +- README.md | 1 + docs/source/tutorials.rst | 1 + etna/datasets/__init__.py | 2 + etna/datasets/datasets_generation.py | 100 + etna/datasets/hierarchical_structure.py | 192 + etna/datasets/tsdataset.py | 195 +- etna/datasets/utils.py | 91 + etna/pipeline/__init__.py | 1 + etna/pipeline/base.py | 27 +- etna/pipeline/hierarchical_pipeline.py | 156 + etna/reconciliation/__init__.py | 3 + etna/reconciliation/base.py | 100 + etna/reconciliation/bottom_up.py | 50 + etna/reconciliation/top_down.py | 139 + etna/transforms/utils.py | 9 +- examples/README.md | 6 + .../hierarchical_pipeline/hierarchy.png | Bin 0 -> 21275 bytes examples/hierarchical_pipeline.ipynb | 6578 +++++++++++++++++ tests/conftest.py | 227 + tests/test_datasets/conftest.py | 21 + .../test_datasets/test_datasets_generation.py | 57 + .../test_hierarchical_dataset.py | 483 ++ .../test_hierarchical_structure.py | 278 + tests/test_datasets/test_utils.py | 68 + .../test_hierarchical_pipeline.py | 262 + tests/test_reconciliation/__init__.py | 0 tests/test_reconciliation/conftest.py | 29 + tests/test_reconciliation/test_base.py | 107 + tests/test_reconciliation/test_bottom_up.py | 107 + tests/test_reconciliation/test_top_down.py | 300 + 31 files changed, 9576 insertions(+), 26 deletions(-) create mode 100644 etna/datasets/hierarchical_structure.py create mode 100644 etna/pipeline/hierarchical_pipeline.py create mode 100644 etna/reconciliation/__init__.py create mode 100644 etna/reconciliation/base.py create mode 100644 etna/reconciliation/bottom_up.py create mode 100644 etna/reconciliation/top_down.py create mode 100644 examples/assets/hierarchical_pipeline/hierarchy.png create mode 100644 examples/hierarchical_pipeline.ipynb create mode 100644 tests/test_datasets/conftest.py create mode 100644 tests/test_datasets/test_hierarchical_dataset.py create mode 100644 tests/test_datasets/test_hierarchical_structure.py create mode 100644 tests/test_pipeline/test_hierarchical_pipeline.py create mode 100644 tests/test_reconciliation/__init__.py create mode 100644 tests/test_reconciliation/conftest.py create mode 100644 tests/test_reconciliation/test_base.py create mode 100644 tests/test_reconciliation/test_bottom_up.py create mode 100644 tests/test_reconciliation/test_top_down.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 30584c4ef..dbca6be42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,12 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051)) - `MaxDeviation` metric & `max_deviation` functional metric ([#1061](https://github.com/tinkoff-ai/etna/pull/1061)) - Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068)) -- -- -- -- -- -- +- Add hierarchical time series support([#1083](https://github.com/tinkoff-ai/etna/pull/1083)) +- +- +- +- +- ### Changed - - diff --git a/README.md b/README.md index a4a0c04dd..1b21cf31c 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,7 @@ We have also prepared a set of tutorials for an easy introduction: | [Exogenous data](https://github.com/tinkoff-ai/etna/tree/master/examples/exogenous_data.ipynb) | 
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tinkoff-ai/etna/master?filepath=examples/exogenous_data.ipynb) | | [Forecasting strategies](https://github.com/tinkoff-ai/etna/blob/master/examples/forecasting_strategies.ipynb) | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tinkoff-ai/etna/master?filepath=examples/forecasting_strategies.ipynb) | | [Classification](https://github.com/tinkoff-ai/etna/blob/master/examples/classification.ipynb) | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tinkoff-ai/etna/master?filepath=examples/classification.ipynb) | +| [Hierarchical time series](https://github.com/tinkoff-ai/etna/blob/master/examples/hierarchical_pipeline.ipynb) | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tinkoff-ai/etna/master?filepath=examples/hierarchical_pipeline.ipynb) | ## Documentation diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index c941c4000..a1d1baf57 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -16,3 +16,4 @@ Tutorials tutorials/ensembles tutorials/NN_examples tutorials/classification + tutorials/hierarchical_pipeline diff --git a/etna/datasets/__init__.py b/etna/datasets/__init__.py index 9ef938d77..686ddc466 100644 --- a/etna/datasets/__init__.py +++ b/etna/datasets/__init__.py @@ -1,7 +1,9 @@ from etna.datasets.datasets_generation import generate_ar_df from etna.datasets.datasets_generation import generate_const_df from etna.datasets.datasets_generation import generate_from_patterns_df +from etna.datasets.datasets_generation import generate_hierarchical_df from etna.datasets.datasets_generation import generate_periodic_df +from etna.datasets.hierarchical_structure import HierarchicalStructure from etna.datasets.tsdataset import TSDataset from etna.datasets.utils import duplicate_data from etna.datasets.utils import set_columns_wide diff --git a/etna/datasets/datasets_generation.py b/etna/datasets/datasets_generation.py index 80fa2cbb0..7e6a4a65d 100644 --- a/etna/datasets/datasets_generation.py +++ b/etna/datasets/datasets_generation.py @@ -186,3 +186,103 @@ def generate_from_patterns_df( df["timestamp"] = pd.date_range(start=start_time, freq=freq, periods=periods) df = df.melt(id_vars=["timestamp"], value_name="target", var_name="segment") return df + + +def generate_hierarchical_df( + periods: int, + n_segments: List[int], + freq: str = "D", + start_time: str = "2000-01-01", + ar_coef: Optional[list] = None, + sigma: float = 1, + random_seed: int = 1, +) -> pd.DataFrame: + """ + Create DataFrame with hierarchical structure and AR process data. + + The hierarchical structure is generated as follows: + 1. Number of levels in the structure is the same as length of ``n_segments`` parameter + 2. Each level contains the number of segments set in ``n_segments`` + 3. Connections from parent to child level are generated randomly. + + Parameters + ---------- + periods: + number of timestamps + n_segments: + number of segments on each level. 
+ freq: + pandas frequency string for :py:func:`pandas.date_range` that is used to generate timestamps + start_time: + start timestamp + ar_coef: + AR coefficients + sigma: + scale of AR noise + random_seed: + random seed + + Returns + ------- + : + DataFrame at the bottom level of the hierarchy + + Raises + ------ + ValueError: + ``n_segments`` is empty + ValueError: + ``n_segments`` contains non-positive integers + ValueError: + ``n_segments`` is not a non-decreasing sequence + """ + if len(n_segments) == 0: + raise ValueError("`n_segments` should contain at least one positive integer!") + + if (np.less_equal(n_segments, 0)).any(): + raise ValueError("All `n_segments` elements should be positive!") + + if (np.diff(n_segments) < 0).any(): + raise ValueError("`n_segments` should represent non-decreasing sequence!") + + rnd = RandomState(seed=random_seed) + + bottom_df = generate_ar_df( + periods=periods, + start_time=start_time, + ar_coef=ar_coef, + sigma=sigma, + n_segments=n_segments[-1], + freq=freq, + random_seed=random_seed, + ) + + bottom_segments = np.unique(bottom_df["segment"]) + + n_levels = len(n_segments) + child_to_parent = dict() + for level_id in range(1, n_levels): + prev_level_n_segments = n_segments[level_id - 1] + cur_level_n_segments = n_segments[level_id] + + # ensure all parents have at least one child + seen_ids = set() + child_ids = rnd.choice(cur_level_n_segments, prev_level_n_segments, replace=False) + for parent_id, child_id in enumerate(child_ids): + seen_ids.add(child_id) + child_to_parent[f"l{level_id}s{child_id}"] = f"l{level_id - 1}s{parent_id}" + + for child_id in range(cur_level_n_segments): + if child_id not in seen_ids: + parent_id = rnd.choice(prev_level_n_segments, 1).item() + child_to_parent[f"l{level_id}s{child_id}"] = f"l{level_id - 1}s{parent_id}" + + bottom_segments_map = {segment: f"l{n_levels - 1}s{idx}" for idx, segment in enumerate(bottom_segments)} + bottom_df[f"level_{n_levels - 1}"] = bottom_df["segment"].map(lambda x: bottom_segments_map[x]) + + for level_id in range(n_levels - 2, -1, -1): + bottom_df[f"level_{level_id}"] = bottom_df[f"level_{level_id + 1}"].map(lambda x: child_to_parent[x]) + + bottom_df.drop(columns=["segment"], inplace=True) + + return bottom_df diff --git a/etna/datasets/hierarchical_structure.py b/etna/datasets/hierarchical_structure.py new file mode 100644 index 000000000..4dd664c07 --- /dev/null +++ b/etna/datasets/hierarchical_structure.py @@ -0,0 +1,192 @@ +from collections import defaultdict +from itertools import chain +from queue import Queue +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from scipy.sparse import csr_matrix +from scipy.sparse import lil_matrix + +from etna.core import BaseMixin + + +class HierarchicalStructure(BaseMixin): + """Represents hierarchical structure of TSDataset.""" + + def __init__(self, level_structure: Dict[str, List[str]], level_names: Optional[List[str]] = None): + """Init HierarchicalStructure. + + Parameters + ---------- + level_structure: + Adjacency list describing the structure of the hierarchy tree (e.g. {"total":["X", "Y"], "X":["a", "b"], "Y":["c", "d"]}). + level_names: + Names of levels in the hierarchy in the order from top to bottom (e.g. ["total", "category", "product"]). + If None is passed, level names are generated automatically with structure "level_<i>".
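+ + For example, the tree from the parameter descriptions above can be built and queried like this (a minimal sketch; the segment and level names are illustrative): + + >>> from etna.datasets import HierarchicalStructure + >>> hierarchical_structure = HierarchicalStructure( + ...     level_structure={"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + ...     level_names=["total", "category", "product"], + ... ) + >>> hierarchical_structure.get_level_segments("category") + ['X', 'Y']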
+ """ + self.level_structure = level_structure + self._hierarchy_root = self._find_tree_root(self.level_structure) + self._num_nodes = self._find_num_nodes(self.level_structure) + + hierarchy_levels = self._find_hierarchy_levels() + tree_depth = len(hierarchy_levels) + + self.level_names = self._get_level_names(level_names, tree_depth) + self._level_series: Dict[str, Tuple[str, ...]] = { + self.level_names[i]: tuple(hierarchy_levels[i]) for i in range(tree_depth) + } + self._level_to_index: Dict[str, int] = {self.level_names[i]: i for i in range(tree_depth)} + + self._segment_num_reachable_leafs: Dict[str, int] = self._get_num_reachable_leafs(hierarchy_levels) + + self._segment_to_level: Dict[str, str] = { + segment: level for level in self._level_series for segment in self._level_series[level] + } + + @staticmethod + def _get_level_names(level_names: Optional[List[str]], tree_depth: int) -> List[str]: + """Assign level names if not provided.""" + if level_names is None: + level_names = [f"level_{i}" for i in range(tree_depth)] + + if len(level_names) != tree_depth: + raise ValueError("Length of `level_names` must be equal to hierarchy tree depth!") + + return level_names + + @staticmethod + def _find_tree_root(hierarchy_structure: Dict[str, List[str]]) -> str: + """Find hierarchy top level (root of tree).""" + children = set(chain(*hierarchy_structure.values())) + parents = set(hierarchy_structure.keys()) + + tree_roots = parents.difference(children) + if len(tree_roots) != 1: + raise ValueError("Invalid tree definition: unable to find root!") + + return tree_roots.pop() + + @staticmethod + def _find_num_nodes(hierarchy_structure: Dict[str, List[str]]) -> int: + """Count number of nodes in tree.""" + children = set(chain(*hierarchy_structure.values())) + parents = set(hierarchy_structure.keys()) + + num_nodes = len(children | parents) + + num_edges = sum(map(len, hierarchy_structure.values())) + if num_edges != num_nodes - 1: + raise ValueError("Invalid tree definition: invalid number of nodes and edges!") + + return num_nodes + + def _find_hierarchy_levels(self) -> Dict[int, List[str]]: + """Traverse hierarchy tree to group segments into levels.""" + leaves_levels = set() + levels = defaultdict(list) + seen_nodes = {self._hierarchy_root} + queue: Queue = Queue() + queue.put((self._hierarchy_root, 0)) + while not queue.empty(): + node, level = queue.get() + levels[level].append(node) + child_nodes = self.level_structure.get(node, []) + + if len(child_nodes) == 0: + leaves_levels.add(level) + + for adj_node in child_nodes: + queue.put((adj_node, level + 1)) + seen_nodes.add(adj_node) + + if len(seen_nodes) != self._num_nodes: + raise ValueError("Invalid tree definition: disconnected graph!") + + if len(leaves_levels) != 1: + raise ValueError("All hierarchy tree leaves must be on the same level!") + + return levels + + def _get_num_reachable_leafs(self, hierarchy_levels: Dict[int, List[str]]) -> Dict[str, int]: + """Compute subtree size for each node.""" + num_reachable_leafs: Dict[str, int] = dict() + for level in sorted(hierarchy_levels.keys(), reverse=True): + for node in hierarchy_levels[level]: + if node in self.level_structure: + num_reachable_leafs[node] = sum( + num_reachable_leafs[child_node] for child_node in self.level_structure[node] + ) + + else: + num_reachable_leafs[node] = 1 + + return num_reachable_leafs + + def get_summing_matrix(self, target_level: str, source_level: str) -> csr_matrix: + """Get summing matrix for transition from source level to target level. 
+ + The generation algorithm is based on the structure of summing matrices: the number of ones in such a matrix equals the + number of nodes on the source level, and each row has ones only for the source level nodes that + belong to the subtree rooted at the corresponding target level node. Since nodes on each level are stored in BFS order, + the algorithm only has to calculate the necessary offsets for each row. + + Parameters + ---------- + target_level: + Name of target level. + source_level: + Name of source level. + + Returns + ------- + : + Summing matrix from source level to target level + + """ + try: + target_idx = self._level_to_index[target_level] + source_idx = self._level_to_index[source_level] + except KeyError as e: + raise ValueError(f"Invalid level name: {e.args[0]}") + + if target_idx > source_idx: + raise ValueError("Target level must be higher or equal in hierarchy than source level!") + + target_level_segment = self.get_level_segments(target_level) + source_level_segment = self.get_level_segments(source_level) + summing_matrix = lil_matrix((len(target_level_segment), len(source_level_segment)), dtype="int32") + + current_source_segment_id = 0 + for current_target_segment_id, segment in enumerate(target_level_segment): + num_reachable_leafs_left = self._segment_num_reachable_leafs[segment] + + while num_reachable_leafs_left > 0: + source_segment = source_level_segment[current_source_segment_id] + num_reachable_leafs_left -= self._segment_num_reachable_leafs[source_segment] + summing_matrix[current_target_segment_id, current_source_segment_id] = 1 + current_source_segment_id += 1 + + return summing_matrix.tocsr() + + def get_level_segments(self, level_name: str) -> List[str]: + """Get all segments from a particular level.""" + try: + return list(self._level_series[level_name]) + except KeyError: + raise ValueError(f"Invalid level name: {level_name}") + + def get_segment_level(self, segment: str) -> str: + """Get level name for the provided segment.""" + try: + return self._segment_to_level[segment] + except KeyError: + raise ValueError(f"Segment {segment} is out of the hierarchy") + + def get_level_depth(self, level_name: str) -> int: + """Get level depth in a hierarchy tree.""" + try: + return self._level_to_index[level_name] + except KeyError: + raise ValueError(f"Invalid level name: {level_name}") diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index 74f15336e..09c89edc6 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -21,7 +21,10 @@ from typing_extensions import Literal from etna import SETTINGS +from etna.datasets.hierarchical_structure import HierarchicalStructure from etna.datasets.utils import _TorchDataset +from etna.datasets.utils import get_level_dataframe +from etna.datasets.utils import get_target_with_quantiles from etna.loggers import tslogger if TYPE_CHECKING: @@ -82,6 +85,21 @@ class TSDataset: 2021-01-03 0.48 0.47 -0.81 -1.56 -1.37 0.48 2021-01-04 -0.59 2.44 -2.21 -1.21 -0.69 -0.59 2021-01-05 0.28 0.58 -3.07 -1.45 0.77 0.28 + + >>> from etna.datasets import generate_hierarchical_df + >>> pd.options.display.width = 0 + >>> df = generate_hierarchical_df(periods=100, n_segments=[2, 4], start_time="2021-01-01",) + >>> df, hierarchical_structure = TSDataset.to_hierarchical_dataset(df=df, level_columns=["level_0", "level_1"]) + >>> tsdataset = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + >>> tsdataset.df.head(5) + segment l0s0_l1s3 l0s1_l1s0 l0s1_l1s1 l0s1_l1s2 + feature target target target target + timestamp
2021-01-01 2.07 1.62 -0.45 -0.40 + 2021-01-02 0.59 1.01 0.78 0.42 + 2021-01-03 -0.24 0.48 1.18 -0.14 + 2021-01-04 -1.12 -0.59 1.77 1.82 + 2021-01-05 -1.40 0.28 0.68 0.48 """ idx = pd.IndexSlice @@ -92,6 +110,7 @@ def __init__( freq: str, df_exog: Optional[pd.DataFrame] = None, known_future: Union[Literal["all"], Sequence] = (), + hierarchical_structure: Optional[HierarchicalStructure] = None, ): """Init TSDataset. @@ -106,6 +125,8 @@ def __init__( known_future: columns in ``df_exog[known_future]`` that are regressors, if "all" value is given, all columns are meant to be regressors + hierarchical_structure: + Structure of the levels in the hierarchy. If None, there is no hierarchical structure in the dataset. """ self.raw_df = self._prepare_df(df) self.raw_df.index = pd.to_datetime(self.raw_df.index) @@ -132,13 +153,36 @@ def __init__( self.known_future = self._check_known_future(known_future, df_exog) self._regressors = copy(self.known_future) + self.hierarchical_structure = hierarchical_structure + self.current_df_level: Optional[str] = self._get_dataframe_level(df=self.df) + self.current_df_exog_level: Optional[str] = None + if df_exog is not None: self.df_exog = df_exog.copy(deep=True) self.df_exog.index = pd.to_datetime(self.df_exog.index) - self.df = self._merge_exog(self.df) + self.current_df_exog_level = self._get_dataframe_level(df=self.df_exog) + if self.current_df_level == self.current_df_exog_level: + self.df = self._merge_exog(self.df) self.transforms: Optional[Sequence["Transform"]] = None + def _get_dataframe_level(self, df: pd.DataFrame) -> Optional[str]: + """Return the level of the passed dataframe in hierarchical structure.""" + if self.hierarchical_structure is None: + return None + + df_segments = df.columns.get_level_values("segment").unique() + segment_levels = {self.hierarchical_structure.get_segment_level(segment=segment) for segment in df_segments} + if len(segment_levels) != 1: + raise ValueError("Segments in dataframe are from more than 1 hierarchical levels!") + + df_level = segment_levels.pop() + level_segments = self.hierarchical_structure.get_level_segments(level_name=df_level) + if len(df_segments) != len(level_segments): + raise ValueError("Some segments of hierarchical level are missing in dataframe!") + + return df_level + def transform(self, transforms: Sequence["Transform"]): """Apply given transform to the data.""" self._check_endings(warning=True) @@ -294,7 +338,7 @@ def make_future(self, future_steps: int, tail_steps: int = 0) -> "TSDataset": df = self.raw_df.reindex(new_index) df.index.name = "timestamp" - if self.df_exog is not None: + if self.df_exog is not None and self.current_df_level == self.current_df_exog_level: df = self._merge_exog(df) # check if we have enough values in regressors @@ -708,6 +752,78 @@ def to_dataset(df: pd.DataFrame) -> pd.DataFrame: df_copy = df_copy.sort_index(axis=1, level=(0, 1)) return df_copy + @staticmethod + def _hierarchical_structure_from_level_columns( + df: pd.DataFrame, level_columns: List[str], sep: str + ) -> HierarchicalStructure: + """Create hierarchical structure from dataframe columns.""" + df_level_columns = df[level_columns].astype("string") + + prev_level_name = level_columns[0] + for cur_level_name in level_columns[1:]: + df_level_columns[cur_level_name] = ( + df_level_columns[prev_level_name] + sep + df_level_columns[cur_level_name] + ) + prev_level_name = cur_level_name + + level_structure = {"total": list(df_level_columns[level_columns[0]].unique())} + cur_level_name = level_columns[0] + for 
next_level_name in level_columns[1:]: + cur_level_to_next_level_edges = df_level_columns[[cur_level_name, next_level_name]].drop_duplicates() + cur_level_to_next_level_adjacency_list = cur_level_to_next_level_edges.groupby(cur_level_name).agg(list) + level_structure.update(cur_level_to_next_level_adjacency_list.to_records()) + cur_level_name = next_level_name + + hierarchical_structure = HierarchicalStructure( + level_structure=level_structure, level_names=["total"] + level_columns + ) + return hierarchical_structure + + @staticmethod + def to_hierarchical_dataset( + df: pd.DataFrame, + level_columns: List[str], + keep_level_columns: bool = False, + sep: str = "_", + return_hierarchy: bool = True, + ) -> Tuple[pd.DataFrame, Optional[HierarchicalStructure]]: + """Convert pandas dataframe from long hierarchical to ETNA Dataset format. + + Parameters + ---------- + df: + Dataframe in long hierarchical format with columns [timestamp, target] + [level_columns] + [other_columns] + level_columns: + Columns of the dataframe define the levels in the hierarchy in order + from top to bottom, e.g. [level_name_1, level_name_2, ...]. Names of the columns will be used as + names of the levels in the hierarchy. + keep_level_columns: + If true, leave the level columns in the result dataframe. + By default, level columns are concatenated into the "segment" column and dropped + sep: + String to concatenate the level names with + return_hierarchy: + If true, returns the hierarchical structure + + Returns + ------- + : + Dataframe in wide format and optionally hierarchical structure + """ + df_copy = df.copy(deep=True) + df_copy["segment"] = df_copy[level_columns].astype("string").agg(sep.join, axis=1) + if not keep_level_columns: + df_copy.drop(columns=level_columns, inplace=True) + df_copy = TSDataset.to_dataset(df_copy) + + hierarchical_structure = None + if return_hierarchy: + hierarchical_structure = TSDataset._hierarchical_structure_from_level_columns( + df=df, level_columns=level_columns, sep=sep + ) + + return df_copy, hierarchical_structure + def _find_all_borders( self, train_start: Optional[TTimestamp], @@ -848,13 +964,25 @@ def train_test_split( train_df = self.df[train_start_defined:train_end_defined][self.raw_df.columns] # type: ignore train_raw_df = self.raw_df[train_start_defined:train_end_defined] # type: ignore - train = TSDataset(df=train_df, df_exog=self.df_exog, freq=self.freq, known_future=self.known_future) + train = TSDataset( + df=train_df, + df_exog=self.df_exog, + freq=self.freq, + known_future=self.known_future, + hierarchical_structure=self.hierarchical_structure, + ) train.raw_df = train_raw_df train._regressors = self.regressors test_df = self.df[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore test_raw_df = self.raw_df[train_start_defined:test_end_defined] # type: ignore - test = TSDataset(df=test_df, df_exog=self.df_exog, freq=self.freq, known_future=self.known_future) + test = TSDataset( + df=test_df, + df_exog=self.df_exog, + freq=self.freq, + known_future=self.known_future, + hierarchical_structure=self.hierarchical_structure, + ) test.raw_df = test_raw_df test._regressors = self.regressors @@ -871,6 +999,65 @@ def index(self) -> pd.core.indexes.datetimes.DatetimeIndex: """ return self.df.index + def level_names(self) -> Optional[List[str]]: + """Return names of the levels in the hierarchical structure.""" + if self.hierarchical_structure is None: + return None + return self.hierarchical_structure.level_names + + def has_hierarchy(self) -> bool: + """Check whether 
dataset has hierarchical structure.""" + return self.hierarchical_structure is not None + + def get_level_dataset(self, target_level: str) -> "TSDataset": + """Generate new TSDataset on target level. + + Parameters + ---------- + target_level: + target level name + + Returns + ------- + TSDataset + generated dataset + """ + if self.hierarchical_structure is None or self.current_df_level is None: + raise ValueError("Method could be applied only to instances with a hierarchy!") + + current_level_segments = self.hierarchical_structure.get_level_segments(level_name=self.current_df_level) + target_level_segments = self.hierarchical_structure.get_level_segments(level_name=target_level) + + current_level_index = self.hierarchical_structure.get_level_depth(self.current_df_level) + target_level_index = self.hierarchical_structure.get_level_depth(target_level) + + if target_level_index > current_level_index: + raise ValueError("Target level should be higher in the hierarchy than the current level of dataframe!") + + if target_level_index < current_level_index: + summing_matrix = self.hierarchical_structure.get_summing_matrix( + target_level=target_level, source_level=self.current_df_level + ) + + target_level_df = get_level_dataframe( + df=self.df, + mapping_matrix=summing_matrix, + source_level_segments=current_level_segments, + target_level_segments=target_level_segments, + ) + + else: + target_names = tuple(get_target_with_quantiles(columns=self.columns)) + target_level_df = self[:, current_level_segments, target_names] + + return TSDataset( + df=target_level_df, + freq=self.freq, + df_exog=self.df_exog, + known_future=self.known_future, + hierarchical_structure=self.hierarchical_structure, + ) + @property def columns(self) -> pd.core.indexes.multi.MultiIndex: """Return columns of ``self.df``. diff --git a/etna/datasets/utils.py b/etna/datasets/utils.py index 0beef91ba..a2e56f073 100644 --- a/etna/datasets/utils.py +++ b/etna/datasets/utils.py @@ -1,9 +1,13 @@ +import re from enum import Enum from typing import List from typing import Optional from typing import Sequence +from typing import Set +import numpy as np import pandas as pd +from scipy.sparse import csr_matrix from etna import SETTINGS @@ -175,3 +179,90 @@ def set_columns_wide( df_left.loc[timestamps_left_index, (segments_left_index, features_left_index)] = right_value.values return df_left + + +def match_target_quantiles(features: Set[str]) -> Set[str]: + """Find quantiles in dataframe columns.""" + pattern = re.compile("target_\d+\.\d+$") + return {i for i in list(features) if pattern.match(i) is not None} + + +def get_target_with_quantiles(columns: pd.Index) -> Set[str]: + """Find "target" column and target quantiles among dataframe columns.""" + column_names = set(columns.get_level_values(level="feature")) + target_columns = match_target_quantiles(column_names) + if "target" in column_names: + target_columns.add("target") + return target_columns + + +def get_level_dataframe( + df: pd.DataFrame, + mapping_matrix: csr_matrix, + source_level_segments: List[str], + target_level_segments: List[str], +): + """Perform mapping to dataframe at the target level. 
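+ + The mapping is a (sparse) matrix multiplication over the segment axis, applied to "target" and every target quantile column independently. A minimal sketch (segment names and values are illustrative): + + >>> import numpy as np + >>> import pandas as pd + >>> from scipy.sparse import csr_matrix + >>> from etna.datasets.utils import get_level_dataframe + >>> columns = pd.MultiIndex.from_product([["a", "b"], ["target"]], names=["segment", "feature"]) + >>> df = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], index=pd.date_range("2021-01-01", periods=2, freq="D"), columns=columns) + >>> matrix = csr_matrix(np.array([[1, 1]]))  # one target segment "total" that sums segments "a" and "b" + >>> get_level_dataframe(df=df, mapping_matrix=matrix, source_level_segments=["a", "b"], target_level_segments=["total"]).values + array([[3.], +        [7.]])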
+ + Parameters + ---------- + df: + dataframe at the source level + mapping_matrix: + mapping matrix between levels + source_level_segments: + list of segments at the source level + target_level_segments: + list of segments at the target level + + Returns + ------- + : + dataframe at the target level + """ + target_names = tuple(get_target_with_quantiles(columns=df.columns)) + + num_target_names = len(target_names) + num_source_level_segments = len(source_level_segments) + num_target_level_segments = len(target_level_segments) + + if len(target_names) == 0: + raise ValueError("Provided dataframe has no columns with the target variable!") + + if set(df.columns.get_level_values(level="segment")) != set(source_level_segments): + raise ValueError("Segments mismatch for provided dataframe and `source_level_segments`!") + + if num_source_level_segments != mapping_matrix.shape[1]: + raise ValueError("Number of source level segments does not match the number of columns in the mapping matrix!") + + if num_target_level_segments != mapping_matrix.shape[0]: + raise ValueError("Number of target level segments does not match the number of rows in the mapping matrix!") + + df = df.loc[:, pd.IndexSlice[source_level_segments, target_names]] + + source_level_data = df.values # shape: (t, num_source_level_segments * num_target_names) + + source_level_data = source_level_data.reshape( + (-1, num_source_level_segments, num_target_names) + ) # shape: (t, num_source_level_segments, num_target_names) + source_level_data = np.swapaxes(source_level_data, 1, 2) # shape: (t, num_target_names, num_source_level_segments) + source_level_data = source_level_data.reshape( + (-1, num_source_level_segments) + ) # shape: (t * num_target_names, num_source_level_segments) + + target_level_data = source_level_data @ mapping_matrix.T + + target_level_data = target_level_data.reshape( + (-1, num_target_names, num_target_level_segments) + ) # shape: (t, num_target_names, num_target_level_segments) + target_level_data = np.swapaxes(target_level_data, 1, 2) # shape: (t, num_target_level_segments, num_target_names) + target_level_data = target_level_data.reshape( + (-1, num_target_names * num_target_level_segments) + ) # shape: (t, num_target_level_segments * num_target_names) + + target_level_segments = pd.MultiIndex.from_product( + [target_level_segments, target_names], names=["segment", "feature"] + ) + target_level_df = pd.DataFrame(data=target_level_data, index=df.index, columns=target_level_segments) + + return target_level_df diff --git a/etna/pipeline/__init__.py b/etna/pipeline/__init__.py index ceb9fbb34..e49e4df2f 100644 --- a/etna/pipeline/__init__.py +++ b/etna/pipeline/__init__.py @@ -1,4 +1,5 @@ from etna.pipeline.assembling_pipelines import assemble_pipelines from etna.pipeline.autoregressive_pipeline import AutoRegressivePipeline from etna.pipeline.base import FoldMask +from etna.pipeline.hierarchical_pipeline import HierarchicalPipeline from etna.pipeline.pipeline import Pipeline diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py index 2c6939783..1bc5b7c5f 100644 --- a/etna/pipeline/base.py +++ b/etna/pipeline/base.py @@ -270,10 +270,22 @@ def _forecast_prediction_interval( raise ValueError("Pipeline is not fitted! 
Fit the Pipeline before calling forecast method.") with tslogger.disable(): _, forecasts, _ = self.backtest(ts=self.ts, metrics=[MAE()], n_folds=n_folds) - forecasts = TSDataset(df=forecasts, freq=self.ts.freq) + + self._add_forecast_borders(backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions) + + return predictions + + def _add_forecast_borders( + self, backtest_forecasts: pd.DataFrame, quantiles: Sequence[float], predictions: TSDataset + ) -> None: + """Estimate prediction intervals and add to the forecasts.""" + if self.ts is None: + raise ValueError("Pipeline is not fitted!") + + backtest_forecasts = TSDataset(df=backtest_forecasts, freq=self.ts.freq) residuals = ( - forecasts.loc[:, pd.IndexSlice[:, "target"]] - - self.ts[forecasts.index.min() : forecasts.index.max(), :, "target"] + backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] + - self.ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"] ) sigma = np.std(residuals.values, axis=0) @@ -286,8 +298,6 @@ def _forecast_prediction_interval( predictions.df = pd.concat([predictions.df] + borders, axis=1).sort_index(axis=1, level=(0, 1)) - return predictions - def forecast( self, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975), n_folds: int = 3 ) -> TSDataset: @@ -502,8 +512,9 @@ def _generate_folds_datasets( ) yield train, test - @staticmethod - def _compute_metrics(metrics: List[Metric], y_true: TSDataset, y_pred: TSDataset) -> Dict[str, Dict[str, float]]: + def _compute_metrics( + self, metrics: List[Metric], y_true: TSDataset, y_pred: TSDataset + ) -> Dict[str, Dict[str, float]]: """Compute metrics for given y_true, y_pred.""" metrics_values: Dict[str, Dict[str, float]] = {} for metric in metrics: @@ -535,7 +546,7 @@ def _run_fold( test.df = test.df.loc[mask.target_timestamps] fold["forecast"] = forecast - fold["metrics"] = deepcopy(self._compute_metrics(metrics=metrics, y_true=test, y_pred=forecast)) + fold["metrics"] = deepcopy(pipeline._compute_metrics(metrics=metrics, y_true=test, y_pred=forecast)) tslogger.log_backtest_run(pd.DataFrame(fold["metrics"]), forecast.to_pandas(), test.to_pandas()) tslogger.finish_experiment() diff --git a/etna/pipeline/hierarchical_pipeline.py b/etna/pipeline/hierarchical_pipeline.py new file mode 100644 index 000000000..c7ca590ed --- /dev/null +++ b/etna/pipeline/hierarchical_pipeline.py @@ -0,0 +1,156 @@ +from copy import deepcopy +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence + +from etna.datasets.tsdataset import TSDataset +from etna.datasets.utils import get_target_with_quantiles +from etna.loggers import tslogger +from etna.metrics import MAE +from etna.metrics import Metric +from etna.models.base import ModelType +from etna.pipeline.pipeline import Pipeline +from etna.reconciliation.base import BaseReconciliator +from etna.transforms.base import Transform + + +class HierarchicalPipeline(Pipeline): + """Pipeline of transforms with a final estimator for hierarchical time series data.""" + + def __init__( + self, reconciliator: BaseReconciliator, model: ModelType, transforms: Sequence[Transform] = (), horizon: int = 1 + ): + """Create instance of HierarchicalPipeline with given parameters. 
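+ + A minimal construction sketch (the level names "total" and "market" are illustrative and must exist in the dataset's ``HierarchicalStructure``): + + >>> from etna.models import NaiveModel + >>> from etna.pipeline import HierarchicalPipeline + >>> from etna.reconciliation import BottomUpReconciliator + >>> pipeline = HierarchicalPipeline( + ...     reconciliator=BottomUpReconciliator(target_level="total", source_level="market"), + ...     model=NaiveModel(), + ...     horizon=7, + ... )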
+ + Parameters + ---------- + reconciliator: + Instance of reconciliation method + model: + Instance of the etna Model + transforms: + Sequence of the transforms + horizon: + Number of timestamps in the future for forecasting + + Warnings + -------- + Estimation of forecast intervals with `forecast(prediction_interval=True)` method and + `BottomUpReconciliator` may not be reliable. + """ + super().__init__(model=model, transforms=transforms, horizon=horizon) + self.reconciliator = reconciliator + self._fit_ts: Optional[TSDataset] = None + + def fit(self, ts: TSDataset) -> "HierarchicalPipeline": + """Fit the HierarchicalPipeline. + + Fit and apply given transforms to the data, then fit the model on the transformed data. + Provided hierarchical dataset will be aggregated to the source level before fitting the pipeline. + + Parameters + ---------- + ts: + Dataset with hierarchical timeseries data + + Returns + ------- + : + Fitted HierarchicalPipeline instance + """ + self._fit_ts = deepcopy(ts) + + self.reconciliator.fit(ts=ts) + ts = self.reconciliator.aggregate(ts=ts) + super().fit(ts=ts) + return self + + def raw_forecast( + self, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975), n_folds: int = 3 + ) -> TSDataset: + """Make a prediction for target at the source level of hierarchy. + + Parameters + ---------- + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval + n_folds: + Number of folds to use in the backtest for prediction interval estimation + + Returns + ------- + : + Dataset with predictions at the source level + """ + forecast = super().forecast(prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds) + target_columns = tuple(get_target_with_quantiles(columns=forecast.columns)) + + hierarchical_forecast = TSDataset( + df=forecast[..., target_columns], + freq=forecast.freq, + df_exog=forecast.df_exog, + known_future=forecast.known_future, + hierarchical_structure=self.ts.hierarchical_structure, # type: ignore + ) + return hierarchical_forecast + + def forecast( + self, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975), n_folds: int = 3 + ) -> TSDataset: + """Make a prediction for target at the source level of hierarchy and reconcile it to the target level. + + Parameters + ---------- + prediction_interval: + If True returns prediction interval for forecast + quantiles: + Levels of prediction distribution. By default 2.5% and 97.5% are taken to form a 95% prediction interval + n_folds: + Number of folds to use in the backtest for prediction interval estimation + + Returns + ------- + : + Dataset with predictions at the target level of hierarchy. 
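+ + For example (a sketch; ``hierarchical_ts`` stands for any hierarchical TSDataset on a suitable level): + + >>> pipeline.fit(ts=hierarchical_ts)  # doctest: +SKIP + >>> forecast_ts = pipeline.forecast()  # doctest: +SKIP + >>> forecast_ts.current_df_level == pipeline.reconciliator.target_level  # doctest: +SKIP + True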
+ """ + forecast = self.raw_forecast(prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds) + forecast_reconciled = self.reconciliator.reconcile(forecast) + return forecast_reconciled + + def _compute_metrics( + self, metrics: List[Metric], y_true: TSDataset, y_pred: TSDataset + ) -> Dict[str, Dict[str, float]]: + """Compute metrics for given y_true, y_pred.""" + if y_true.current_df_level != self.reconciliator.target_level: + y_true = y_true.get_level_dataset(self.reconciliator.target_level) + + if y_pred.current_df_level == self.reconciliator.source_level: + y_pred = self.reconciliator.reconcile(y_pred) + + metrics_values: Dict[str, Dict[str, float]] = {} + for metric in metrics: + metrics_values[metric.name] = metric(y_true=y_true, y_pred=y_pred) # type: ignore + return metrics_values + + def _forecast_prediction_interval( + self, predictions: TSDataset, quantiles: Sequence[float], n_folds: int + ) -> TSDataset: + """Add prediction intervals to the forecasts.""" + self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore + + if self.ts is None or self._fit_ts is None: + raise ValueError("Pipeline is not fitted! Fit the Pipeline before calling forecast method.") + + # TODO: rework intervals estimation for `BottomUpReconciliator` + + with tslogger.disable(): + _, forecasts, _ = self.backtest(ts=self._fit_ts, metrics=[MAE()], n_folds=n_folds) + + self._add_forecast_borders(backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions) + + self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore + + return predictions diff --git a/etna/reconciliation/__init__.py b/etna/reconciliation/__init__.py new file mode 100644 index 000000000..379be5833 --- /dev/null +++ b/etna/reconciliation/__init__.py @@ -0,0 +1,3 @@ +from etna.reconciliation.base import BaseReconciliator +from etna.reconciliation.bottom_up import BottomUpReconciliator +from etna.reconciliation.top_down import TopDownReconciliator diff --git a/etna/reconciliation/base.py b/etna/reconciliation/base.py new file mode 100644 index 000000000..8cca0cba7 --- /dev/null +++ b/etna/reconciliation/base.py @@ -0,0 +1,100 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional + +from scipy.sparse import csr_matrix + +from etna.core import BaseMixin +from etna.datasets import TSDataset +from etna.datasets.utils import get_level_dataframe + + +class BaseReconciliator(ABC, BaseMixin): + """Base class to hold reconciliation methods.""" + + def __init__(self, target_level: str, source_level: str): + """Init BaseReconciliator. + + Parameters + ---------- + target_level: + Level to be reconciled from the forecasts. + source_level: + Level to be forecasted. + """ + self.target_level = target_level + self.source_level = source_level + self.mapping_matrix: Optional[csr_matrix] = None + + @abstractmethod + def fit(self, ts: TSDataset) -> "BaseReconciliator": + """Fit the reconciliator parameters. + + Parameters + ---------- + ts: + TSDataset on the level which is lower or equal to ``target_level``, ``source_level``. + + Returns + ------- + : + Fitted instance of reconciliator. + """ + pass + + def aggregate(self, ts: TSDataset) -> TSDataset: + """Aggregate the dataset to the ``source_level``. + + Parameters + ---------- + ts: + TSDataset on the level which is lower or equal to ``source_level``. + + Returns + ------- + : + TSDataset on the ``source_level``. 
+ """ + ts_aggregated = ts.get_level_dataset(target_level=self.source_level) + return ts_aggregated + + def reconcile(self, ts: TSDataset) -> TSDataset: + """Reconcile the forecasts in the dataset. + + Parameters + ---------- + ts: + TSDataset on the ``source_level``. + + Returns + ------- + : + TSDataset on the ``target_level``. + """ + if self.mapping_matrix is None: + raise ValueError(f"Reconciliator is not fitted!") + + if ts.hierarchical_structure is None: + raise ValueError(f"Passed dataset has no hierarchical structure!") + + if ts.current_df_level != self.source_level: + raise ValueError(f"Dataset should be on the {self.source_level} level!") + + current_level_segments = ts.hierarchical_structure.get_level_segments(level_name=self.source_level) + target_level_segments = ts.hierarchical_structure.get_level_segments(level_name=self.target_level) + + df_reconciled = get_level_dataframe( + df=ts.to_pandas(), + mapping_matrix=self.mapping_matrix, + source_level_segments=current_level_segments, + target_level_segments=target_level_segments, + ) + + ts_reconciled = TSDataset( + df=df_reconciled, + freq=ts.freq, + df_exog=ts.df_exog, + known_future=ts.known_future, + hierarchical_structure=ts.hierarchical_structure, + ) + return ts_reconciled diff --git a/etna/reconciliation/bottom_up.py b/etna/reconciliation/bottom_up.py new file mode 100644 index 000000000..f45598f65 --- /dev/null +++ b/etna/reconciliation/bottom_up.py @@ -0,0 +1,50 @@ +from etna.datasets import TSDataset +from etna.reconciliation.base import BaseReconciliator + + +class BottomUpReconciliator(BaseReconciliator): + """Bottom-up reconciliation.""" + + def __init__(self, target_level: str, source_level: str): + """Create bottom-up reconciliator from ``source_level`` to ``target_level``. + + Parameters + ---------- + target_level: + Level to be reconciled from the forecasts. + source_level: + Level to be forecasted. + """ + super().__init__(target_level=target_level, source_level=source_level) + + def fit(self, ts: TSDataset) -> "BottomUpReconciliator": + """Fit the reconciliator parameters. + + Parameters + ---------- + ts: + TSDataset on the level which is lower or equal to ``target_level``, ``source_level``. + + Returns + ------- + : + Fitted instance of reconciliator. 
+ """ + if ts.hierarchical_structure is None: + raise ValueError(f"The method can be applied only to instances with a hierarchy!") + + current_level_index = ts.hierarchical_structure.get_level_depth(ts.current_df_level) # type: ignore + source_level_index = ts.hierarchical_structure.get_level_depth(self.source_level) + target_level_index = ts.hierarchical_structure.get_level_depth(self.target_level) + + if source_level_index < target_level_index: + raise ValueError("Source level should be lower or equal in the hierarchy than the target level!") + + if current_level_index < source_level_index: + raise ValueError("Current TSDataset level should be lower or equal in the hierarchy than the source level!") + + self.mapping_matrix = ts.hierarchical_structure.get_summing_matrix( + target_level=self.target_level, source_level=self.source_level + ) + + return self diff --git a/etna/reconciliation/top_down.py b/etna/reconciliation/top_down.py new file mode 100644 index 000000000..3f6fdd80a --- /dev/null +++ b/etna/reconciliation/top_down.py @@ -0,0 +1,139 @@ +from enum import Enum + +import bottleneck as bn +import pandas as pd +from scipy.sparse import lil_matrix + +from etna.datasets import TSDataset +from etna.reconciliation.base import BaseReconciliator + + +class ReconciliationProportionsMethod(str, Enum): + """Enum for different reconciliation proportions methods.""" + + AHP = "AHP" + PHA = "PHA" + + @classmethod + def _missing_(cls, method): + raise ValueError( + f"Unable to recognize reconciliation method '{method}'! " + f"Supported methods: {', '.join(sorted(m for m in cls))}." + ) + + +class TopDownReconciliator(BaseReconciliator): + """Top-down reconciliation methods. + + Notes + ----- + Top-down reconciliation methods support only non-negative data. + """ + + def __init__(self, target_level: str, source_level: str, period: int, method: str): + """Create top-down reconciliator from ``source_level`` to ``target_level``. + + Parameters + ---------- + target_level: + Level to be reconciled from the forecasts. + source_level: + Level to be forecasted. + period: + Period length for calculation reconciliation proportions. + method: + Proportions calculation method. Selects last ``period`` timestamps for estimation. + Currently supported options: + + * AHP - Average historical proportions + + * PHA - Proportions of the historical averages + """ + super().__init__(target_level=target_level, source_level=source_level) + + if period < 1: + raise ValueError("Period length must be positive!") + + self.period = period + self.method = method + + proportions_method = ReconciliationProportionsMethod(method) + if proportions_method == ReconciliationProportionsMethod.AHP: + self._proportions_method_func = self._estimate_ahp_proportion + elif proportions_method == ReconciliationProportionsMethod.PHA: + self._proportions_method_func = self._estimate_pha_proportion + else: + raise ValueError(f"Failed to initialize proportions calculation method with name '{method}'!") + + def fit(self, ts: TSDataset) -> "TopDownReconciliator": + """Fit the reconciliator parameters. + + Parameters + ---------- + ts: + TSDataset on the level which is lower or equal to ``target_level``, ``source_level``. + + Returns + ------- + : + Fitted instance of reconciliator. 
+ """ + if ts.hierarchical_structure is None: + raise ValueError(f"The method can be applied only to instances with a hierarchy!") + + current_level_index = ts.hierarchical_structure.get_level_depth(ts.current_df_level) # type: ignore + source_level_index = ts.hierarchical_structure.get_level_depth(self.source_level) + target_level_index = ts.hierarchical_structure.get_level_depth(self.target_level) + + if target_level_index < source_level_index: + raise ValueError("Target level should be lower or equal in the hierarchy than the source level!") + + if current_level_index < target_level_index: + raise ValueError("Current TSDataset level should be lower or equal in the hierarchy than the target level!") + + if (ts[..., "target"] < 0).values.any(): + raise ValueError("Provided dataset should not contain any negative numbers!") + + source_level_ts = ts.get_level_dataset(self.source_level) + target_level_ts = ts.get_level_dataset(self.target_level) + + if source_level_index < target_level_index: + + summing_matrix = target_level_ts.hierarchical_structure.get_summing_matrix( # type: ignore + target_level=self.source_level, source_level=self.target_level + ) + + source_level_segments = source_level_ts.hierarchical_structure.get_level_segments(self.source_level) # type: ignore + target_level_segments = target_level_ts.hierarchical_structure.get_level_segments(self.target_level) # type: ignore + + self.mapping_matrix = lil_matrix((len(target_level_segments), len(source_level_segments))) + + for source_index, target_index in zip(*summing_matrix.nonzero()): + source_segment = source_level_segments[source_index] + target_segment = target_level_segments[target_index] + + self.mapping_matrix[target_index, source_index] = self._proportions_method_func( # type: ignore + target_series=target_level_ts[:, target_segment, "target"], + source_series=source_level_ts[:, source_segment, "target"], + ) + + self.mapping_matrix = self.mapping_matrix.tocsr() + + else: + self.mapping_matrix = target_level_ts.hierarchical_structure.get_summing_matrix( # type: ignore + target_level=self.target_level, source_level=self.source_level + ) + + return self + + def _estimate_ahp_proportion(self, target_series: pd.Series, source_series: pd.Series) -> float: + """Calculate reconciliation proportion with Average historical proportions method.""" + data = pd.concat((target_series, source_series), axis=1).values + data = data[-self.period :] + return bn.nanmean(data[..., 0] / data[..., 1]) + + def _estimate_pha_proportion(self, target_series: pd.Series, source_series: pd.Series) -> float: + """Calculate reconciliation proportion with Proportions of the historical averages method.""" + target_data = target_series.values + source_data = source_series.values + return bn.nanmean(target_data[-self.period :]) / bn.nanmean(source_data[-self.period :]) diff --git a/etna/transforms/utils.py b/etna/transforms/utils.py index 7b3c7f971..3e6615257 100644 --- a/etna/transforms/utils.py +++ b/etna/transforms/utils.py @@ -1,8 +1 @@ -import re -from typing import Set - - -def match_target_quantiles(features: Set[str]) -> Set[str]: - """Find quantiles in dataframe columns.""" - pattern = re.compile("target_\d+\.\d+$") - return {i for i in list(features) if pattern.match(i) is not None} +from etna.datasets.utils import match_target_quantiles # noqa: F401 diff --git a/examples/README.md b/examples/README.md index 67662d41d..06666eb24 100644 --- a/examples/README.md +++ b/examples/README.md @@ -64,3 +64,9 @@ We have prepared a set of tutorials for an 
easy introduction: #### 10. [Inference: using saved pipeline on a new data](https://github.com/tinkoff-ai/etna/tree/master/examples/inference.ipynb) - Fitting and saving pipeline - Using saved pipeline on a new data + +#### 11. [Hierarchical time series](https://github.com/tinkoff-ai/etna/tree/master/examples/hierarchical_pipeline.ipynb) +- Hierarchical time series +- Hierarchical structure +- Reconciliation methods +- Exogenous variables for hierarchical forecasts diff --git a/examples/assets/hierarchical_pipeline/hierarchy.png b/examples/assets/hierarchical_pipeline/hierarchy.png new file mode 100644 index 0000000000000000000000000000000000000000..08e7244397ea5312c14a3539864c079b368ad5f8 GIT binary patch literal 21275 [binary payload of the hierarchy diagram omitted]
ztcdTe4LBG8hkg8s^Q7qaqbtb?r$VnsC;3zs^}Z%s>90m7IAi}*m(W;Rkn>1hTZ+;> z0j19Wb+fWObbk9B5-9|QM^+#Zn3X*cZCw*ITJWUwG4}ARD|k2b<^$$X9Ecqr?dl>A zwGcD)b#y5GEp2y#SCy0LRW5sruPxR>o;t1n+0`8FxOek6D;q+q(O|)i+65?8Z<4#* z#&k3_?^je*7!%VjOsG7TW~(p#Mlbf=j6)DiYAX~?!Hu<-q|d+WwK@Gm`GA6j;b(Pd z5Iqc}j4 zaOn3NCNWzYw|ygb34Q&EO5$79mf|kQPxi^qcWqf{rIG~+nbjgqNb~!btGMw$g$Jd} z8p1W!SW!KHFp}3Q=(e!aHY&^4oFvyfwqf zl`9N{LM+=984N}g`N0EWSXh`B@PqWUwjS-T7W@zEP=EE%7s9M`0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In many applications time series have a natural level structure. Time series with such properties can be disaggregated by attributes\n", + "from lower levels. On the other hand, this time series can be aggregated to higher levels to represent more general relations.\n", + "The set of possible levels forms the hierarchy of time series." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Hierarchy example](assets/hierarchical_pipeline/hierarchy.png)\n", + "*Two level hierarchical structure*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Image above represents relations between members of the hierarchy. Middle and top levels can be disaggregated using members from\n", + "lower levels. For example\n", + "\n", + "$$\n", + "y_{A,t} = y_{AA,t} + y_{AB,t}\n", + "$$\n", + "\n", + "$$\n", + "y_{t} = y_{A,t} + y_{B,t}\n", + "$$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In matrix notation level aggregation could be written as\n", + "\n", + "\\begin{equation*}\n", + " \\begin{bmatrix}\n", + " y_{A,t} \\\\\n", + " y_t\n", + " \\end{bmatrix}\n", + " =\n", + " \\begin{bmatrix}\n", + " 1 & 1 & 0 \\\\\n", + " 1 & 1 & 1\n", + " \\end{bmatrix}\n", + " \\begin{bmatrix}\n", + " y_{AA,t} \\\\ y_{AB,t} \\\\ y_{B,t}\n", + " \\end{bmatrix}\n", + " =\n", + " S\n", + " \\begin{bmatrix}\n", + " y_{AA,t} \\\\ y_{AB,t} \\\\ y_{B,t}\n", + " \\end{bmatrix}\n", + "\\end{equation*}\n", + "where $S$ - summing matrix." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.Preparing dataset " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Consider the Australian tourism dataset.\n", + "\n", + "This dataset consists of the following components:\n", + "\n", + "* `Total` - total domestic tourism demand,\n", + "* Tourism reasons components (`Hol` for holiday, `Bus` for business, etc)\n", + "* Components representing the \"region-reason\" division (`NSW - hol`, `NSW - bus`, etc)\n", + "* Components representing \"region - reason - city\" division (`NSW - hol - city`, `NSW - hol - noncity`, etc)\n", + "\n", + "We can see that these components form a hierarchy with the following levels::\n", + "1. Total\n", + "2. Tourism reason\n", + "3. Region\n", + "4. 
City" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "pd.options.display.max_columns = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "\n", + " 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n", + " 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0\n", + "100 15664 100 15664 0 0 17366 0 --:--:-- --:--:-- --:--:-- 17385\n" + ] + } + ], + "source": [ + "!curl \"https://robjhyndman.com/data/hier1_with_names.csv\" --ssl-no-revoke -o \"hier1_with_names.csv\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TotalHolVFRBusOthNSW - holVIC - holQLD - holSA - holWA - holTAS - holNT - holNSW - vfrVIC - vfrQLD - vfrSA - vfrWA - vfrTAS - vfrNT - vfrNSW - busVIC - busQLD - busSA - busWA - busTAS - busNT - busNSW - othVIC - othQLD - othSA - othWA - othTAS - othNT - othNSW - hol - cityNSW - hol - noncityVIC - hol - cityVIC - hol - noncityQLD - hol - cityQLD - hol - noncitySA - hol - citySA - hol - noncityWA - hol - cityWA - hol - noncityTAS - hol - cityTAS - hol - noncityNT - hol - cityNT - hol - noncityNSW - vfr - cityNSW - vfr - noncityVIC - vfr - cityVIC - vfr - noncityQLD - vfr - cityQLD - vfr - noncitySA - vfr - citySA - vfr - noncityWA - vfr - cityWA - vfr - noncityTAS - vfr - cityTAS - vfr - noncityNT - vfr - cityNT - vfr - noncityNSW - bus - cityNSW - bus - noncityVIC - bus - cityVIC - bus - noncityQLD - bus - cityQLD - bus - noncitySA - bus - citySA - bus - noncityWA - bus - cityWA - bus - noncityTAS - bus - cityTAS - bus - noncityNT - bus - cityNT - bus - noncityNSW - oth - cityNSW - oth - noncityVIC - oth - cityVIC - oth - noncityQLD - oth - cityQLD - oth - noncitySA - oth - citySA - oth - noncityWA - oth - cityWA - oth - noncityTAS - oth - cityTAS - oth - noncityNT - oth - cityNT - oth - noncity
timestamp
2006-01-0184503459062604298152740175891041290783089344921021879398599352902193178113503728852148209384414062232169064677023172051004330961449325317881468843908882201138320666191483101862709668925653428300322871324869101976260274828912011684116498411119823884565328741161071368039651018128643127124473168377624358
2006-02-01653122934720676118233466110276025631019352454109849878294107490214451353523517430118252224749204337330812385528393632699710814799548143945862320399052114141059139540968920129721845645185222551957294580663975060325726616834920202281101481177614483464033561687832901381706575812293236691701422211709936616939
2006-03-0172753324922058213565611489105060117331569339845813647277381154891453168739147440931944337975015603031536143344614347121546554881609730114883572475869754761093110122971273316197452225505218821929261928701078375953734130261390841975211811537911079230039036044011201961074521084540893128318270116439731538011663223150338
2006-04-017088031813216131147859761065854818109227035611320414830350904441120917143944623463175328808901791298403190260674945415499162515209138190635753328478157116991128243337194916425029185385220828822097234456864199971513725724421815001963124550811281752255635539125270228243160745115727033621453519426041011394843172453
2006-05-018689346793269471002731261615210958100473023428721132131038661525636168520267842783347152227516661023335383984558101518019013762195814194251784414930511787321501560272752315906215131547232298831642703293388779813966303474371531251196215195057211921559386280582441130205194189426558265293458557147331622877601547
\n", + "
" + ], + "text/plain": [ + " Total Hol VFR Bus Oth NSW - hol VIC - hol QLD - hol \\\n", + "timestamp \n", + "2006-01-01 84503 45906 26042 9815 2740 17589 10412 9078 \n", + "2006-02-01 65312 29347 20676 11823 3466 11027 6025 6310 \n", + "2006-03-01 72753 32492 20582 13565 6114 8910 5060 11733 \n", + "2006-04-01 70880 31813 21613 11478 5976 10658 5481 8109 \n", + "2006-05-01 86893 46793 26947 10027 3126 16152 10958 10047 \n", + "\n", + " SA - hol WA - hol TAS - hol NT - hol NSW - vfr VIC - vfr \\\n", + "timestamp \n", + "2006-01-01 3089 3449 2102 187 9398 5993 \n", + "2006-02-01 1935 2454 1098 498 7829 4107 \n", + "2006-03-01 1569 3398 458 1364 7277 3811 \n", + "2006-04-01 2270 3561 1320 414 8303 5090 \n", + "2006-05-01 3023 4287 2113 213 10386 6152 \n", + "\n", + " QLD - vfr SA - vfr WA - vfr TAS - vfr NT - vfr NSW - bus \\\n", + "timestamp \n", + "2006-01-01 5290 2193 1781 1350 37 2885 \n", + "2006-02-01 4902 1445 1353 523 517 4301 \n", + "2006-03-01 5489 1453 1687 391 474 4093 \n", + "2006-04-01 4441 1209 1714 394 462 3463 \n", + "2006-05-01 5636 1685 2026 784 278 3347 \n", + "\n", + " VIC - bus QLD - bus SA - bus WA - bus TAS - bus NT - bus \\\n", + "timestamp \n", + "2006-01-01 2148 2093 844 1406 223 216 \n", + "2006-02-01 1825 2224 749 2043 373 308 \n", + "2006-03-01 1944 3379 750 1560 303 1536 \n", + "2006-04-01 1753 2880 890 1791 298 403 \n", + "2006-05-01 1522 2751 666 1023 335 383 \n", + "\n", + " NSW - oth VIC - oth QLD - oth SA - oth WA - oth TAS - oth \\\n", + "timestamp \n", + "2006-01-01 906 467 702 317 205 100 \n", + "2006-02-01 1238 552 839 363 269 97 \n", + "2006-03-01 1433 446 1434 712 1546 55 \n", + "2006-04-01 1902 606 749 454 1549 91 \n", + "2006-05-01 984 558 1015 180 190 137 \n", + "\n", + " NT - oth NSW - hol - city NSW - hol - noncity VIC - hol - city \\\n", + "timestamp \n", + "2006-01-01 43 3096 14493 2531 \n", + "2006-02-01 108 1479 9548 1439 \n", + "2006-03-01 488 1609 7301 1488 \n", + "2006-04-01 625 1520 9138 1906 \n", + "2006-05-01 62 1958 14194 2517 \n", + "\n", + " VIC - hol - noncity QLD - hol - city QLD - hol - noncity \\\n", + "timestamp \n", + "2006-01-01 7881 4688 4390 \n", + "2006-02-01 4586 2320 3990 \n", + "2006-03-01 3572 4758 6975 \n", + "2006-04-01 3575 3328 4781 \n", + "2006-05-01 8441 4930 5117 \n", + "\n", + " SA - hol - city SA - hol - noncity WA - hol - city \\\n", + "timestamp \n", + "2006-01-01 888 2201 1383 \n", + "2006-02-01 521 1414 1059 \n", + "2006-03-01 476 1093 1101 \n", + "2006-04-01 571 1699 1128 \n", + "2006-05-01 873 2150 1560 \n", + "\n", + " WA - hol - noncity TAS - hol - city TAS - hol - noncity \\\n", + "timestamp \n", + "2006-01-01 2066 619 1483 \n", + "2006-02-01 1395 409 689 \n", + "2006-03-01 2297 127 331 \n", + "2006-04-01 2433 371 949 \n", + "2006-05-01 2727 523 1590 \n", + "\n", + " NT - hol - city NT - hol - noncity NSW - vfr - city \\\n", + "timestamp \n", + "2006-01-01 101 86 2709 \n", + "2006-02-01 201 297 2184 \n", + "2006-03-01 619 745 2225 \n", + "2006-04-01 164 250 2918 \n", + "2006-05-01 62 151 3154 \n", + "\n", + " NSW - vfr - noncity VIC - vfr - city VIC - vfr - noncity \\\n", + "timestamp \n", + "2006-01-01 6689 2565 3428 \n", + "2006-02-01 5645 1852 2255 \n", + "2006-03-01 5052 1882 1929 \n", + "2006-04-01 5385 2208 2882 \n", + "2006-05-01 7232 2988 3164 \n", + "\n", + " QLD - vfr - city QLD - vfr - noncity SA - vfr - city \\\n", + "timestamp \n", + "2006-01-01 3003 2287 1324 \n", + "2006-02-01 1957 2945 806 \n", + "2006-03-01 2619 2870 1078 \n", + "2006-04-01 2097 2344 568 \n", + "2006-05-01 2703 2933 
887 \n", + "\n", + " SA - vfr - noncity WA - vfr - city WA - vfr - noncity \\\n", + "timestamp \n", + "2006-01-01 869 1019 762 \n", + "2006-02-01 639 750 603 \n", + "2006-03-01 375 953 734 \n", + "2006-04-01 641 999 715 \n", + "2006-05-01 798 1396 630 \n", + "\n", + " TAS - vfr - city TAS - vfr - noncity NT - vfr - city \\\n", + "timestamp \n", + "2006-01-01 602 748 28 \n", + "2006-02-01 257 266 168 \n", + "2006-03-01 130 261 390 \n", + "2006-04-01 137 257 244 \n", + "2006-05-01 347 437 153 \n", + "\n", + " NT - vfr - noncity NSW - bus - city NSW - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 9 1201 1684 \n", + "2006-02-01 349 2020 2281 \n", + "2006-03-01 84 1975 2118 \n", + "2006-04-01 218 1500 1963 \n", + "2006-05-01 125 1196 2151 \n", + "\n", + " VIC - bus - city VIC - bus - noncity QLD - bus - city \\\n", + "timestamp \n", + "2006-01-01 1164 984 1111 \n", + "2006-02-01 1014 811 776 \n", + "2006-03-01 1153 791 1079 \n", + "2006-04-01 1245 508 1128 \n", + "2006-05-01 950 572 1192 \n", + "\n", + " QLD - bus - noncity SA - bus - city SA - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 982 388 456 \n", + "2006-02-01 1448 346 403 \n", + "2006-03-01 2300 390 360 \n", + "2006-04-01 1752 255 635 \n", + "2006-05-01 1559 386 280 \n", + "\n", + " WA - bus - city WA - bus - noncity TAS - bus - city \\\n", + "timestamp \n", + "2006-01-01 532 874 116 \n", + "2006-02-01 356 1687 83 \n", + "2006-03-01 440 1120 196 \n", + "2006-04-01 539 1252 70 \n", + "2006-05-01 582 441 130 \n", + "\n", + " TAS - bus - noncity NT - bus - city NT - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 107 136 80 \n", + "2006-02-01 290 138 170 \n", + "2006-03-01 107 452 1084 \n", + "2006-04-01 228 243 160 \n", + "2006-05-01 205 194 189 \n", + "\n", + " NSW - oth - city NSW - oth - noncity VIC - oth - city \\\n", + "timestamp \n", + "2006-01-01 396 510 181 \n", + "2006-02-01 657 581 229 \n", + "2006-03-01 540 893 128 \n", + "2006-04-01 745 1157 270 \n", + "2006-05-01 426 558 265 \n", + "\n", + " VIC - oth - noncity QLD - oth - city QLD - oth - noncity \\\n", + "timestamp \n", + "2006-01-01 286 431 271 \n", + "2006-02-01 323 669 170 \n", + "2006-03-01 318 270 1164 \n", + "2006-04-01 336 214 535 \n", + "2006-05-01 293 458 557 \n", + "\n", + " SA - oth - city SA - oth - noncity WA - oth - city \\\n", + "timestamp \n", + "2006-01-01 244 73 168 \n", + "2006-02-01 142 221 170 \n", + "2006-03-01 397 315 380 \n", + "2006-04-01 194 260 410 \n", + "2006-05-01 147 33 162 \n", + "\n", + " WA - oth - noncity TAS - oth - city TAS - oth - noncity \\\n", + "timestamp \n", + "2006-01-01 37 76 24 \n", + "2006-02-01 99 36 61 \n", + "2006-03-01 1166 32 23 \n", + "2006-04-01 1139 48 43 \n", + "2006-05-01 28 77 60 \n", + "\n", + " NT - oth - city NT - oth - noncity \n", + "timestamp \n", + "2006-01-01 35 8 \n", + "2006-02-01 69 39 \n", + "2006-03-01 150 338 \n", + "2006-04-01 172 453 \n", + "2006-05-01 15 47 " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv(\"hier1_with_names.csv\")\n", + "\n", + "periods = len(df)\n", + "df[\"timestamp\"] = pd.date_range(\"2006-01-01\", periods=periods, freq=\"MS\")\n", + "df.set_index(\"timestamp\", inplace=True)\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1 Manually setting hierarchical structure \n", + "This section presents how to set hierarchical structure and prepare data. 
We are going to create a hierarchical\n", + "dataset with two levels: total demand and demand per tourism reason." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from etna.datasets import TSDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Consider the **Reason** level of the hierarchy." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
HolVFRBusOth
timestamp
2006-01-01459062604298152740
2006-02-012934720676118233466
2006-03-013249220582135656114
2006-04-013181321613114785976
2006-05-014679326947100273126
\n", + "
" + ], + "text/plain": [ + " Hol VFR Bus Oth\n", + "timestamp \n", + "2006-01-01 45906 26042 9815 2740\n", + "2006-02-01 29347 20676 11823 3466\n", + "2006-03-01 32492 20582 13565 6114\n", + "2006-04-01 31813 21613 11478 5976\n", + "2006-05-01 46793 26947 10027 3126" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reason_segments = [\"Hol\", \"VFR\", \"Bus\", \"Oth\"]\n", + "\n", + "df[reason_segments].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1.1 Convert dataset to ETNA wide format \n", + "First, convert dataframe to ETNA long format." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestamptargetsegment
02006-01-0145906Hol
12006-02-0129347Hol
22006-03-0132492Hol
32006-04-0131813Hol
42006-05-0146793Hol
\n", + "
" + ], + "text/plain": [ + " timestamp target segment\n", + "0 2006-01-01 45906 Hol\n", + "1 2006-02-01 29347 Hol\n", + "2 2006-03-01 32492 Hol\n", + "3 2006-04-01 31813 Hol\n", + "4 2006-05-01 46793 Hol" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_df = []\n", + "for segment_name in reason_segments:\n", + " segment = df[segment_name]\n", + "\n", + " segment_slice = pd.DataFrame(\n", + " {\"timestamp\": segment.index, \"target\": segment.values, \"segment\": [segment_name] * periods}\n", + " )\n", + " hierarchical_df.append(segment_slice)\n", + "\n", + "hierarchical_df = pd.concat(hierarchical_df, axis=0)\n", + "\n", + "hierarchical_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the dataframe could be converted to ETNA wide format." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "hierarchical_df = TSDataset.to_dataset(df=hierarchical_df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1.2 Creat HierarchicalStructure \n", + "For handling information about hierarchical structure, there is a dedicated object in the ETNA library: `HierarchicalStructure`.\n", + "\n", + "To create `HierarchicalStructure` define relationships between segments at different levels. This relation should be\n", + "described as mapping between levels members, where keys are parent segments and values are lists of child segments\n", + "from the lower level. Also provide a list of level names, where ordering corresponds to hierarchical relationships\n", + "between levels." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from etna.datasets import HierarchicalStructure" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "HierarchicalStructure(level_structure = {'total': ['Hol', 'VFR', 'Bus', 'Oth']}, level_names = ['total', 'reason'], )" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_structure = HierarchicalStructure(\n", + " level_structure={\"total\": [\"Hol\", \"VFR\", \"Bus\", \"Oth\"]}, level_names=[\"total\", \"reason\"]\n", + ")\n", + "\n", + "hierarchical_structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1.3 Create hierarchical dataset \n", + "When all the data is prepared, call the `TSDataset` constructor to create a hierarchical dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBusHolOthVFR
featuretargettargettargettarget
timestamp
2006-01-01981545906274026042
2006-02-011182329347346620676
2006-03-011356532492611420582
2006-04-011147831813597621613
2006-05-011002746793312626947
\n", + "
" + ], + "text/plain": [ + "segment Bus Hol Oth VFR\n", + "feature target target target target\n", + "timestamp \n", + "2006-01-01 9815 45906 2740 26042\n", + "2006-02-01 11823 29347 3466 20676\n", + "2006-03-01 13565 32492 6114 20582\n", + "2006-04-01 11478 31813 5976 21613\n", + "2006-05-01 10027 46793 3126 26947" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts = TSDataset(df=hierarchical_df, freq=\"MS\", hierarchical_structure=hierarchical_structure)\n", + "\n", + "hierarchical_ts.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ensure that the dataset is at the desired level." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'reason'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts.current_df_level" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 Hierarchical structure detection \n", + "\n", + "This section presents how to prepare data and detect hierarchical structure.\n", + "The main advantage of this approach for creating hierarchical structures is that you don't need to define an adjacency list.\n", + "All hierarchical relationships would be detected from the dataframe columns.\n", + "\n", + "The main applications for this approach are when defining the adjacency list is not desirable or when some columns of\n", + "the dataframe already have information about hierarchy (e.g. related categorical columns).\n", + "\n", + "A data frame must be prepared in a specific format for detection to work. The following sections show how to do so.\n", + "\n", + "Consider the City level of the hierarchy." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
NSW - hol - cityNSW - hol - noncityVIC - hol - cityVIC - hol - noncityQLD - hol - cityQLD - hol - noncitySA - hol - citySA - hol - noncityWA - hol - cityWA - hol - noncityTAS - hol - cityTAS - hol - noncityNT - hol - cityNT - hol - noncityNSW - vfr - cityNSW - vfr - noncityVIC - vfr - cityVIC - vfr - noncityQLD - vfr - cityQLD - vfr - noncitySA - vfr - citySA - vfr - noncityWA - vfr - cityWA - vfr - noncityTAS - vfr - cityTAS - vfr - noncityNT - vfr - cityNT - vfr - noncityNSW - bus - cityNSW - bus - noncityVIC - bus - cityVIC - bus - noncityQLD - bus - cityQLD - bus - noncitySA - bus - citySA - bus - noncityWA - bus - cityWA - bus - noncityTAS - bus - cityTAS - bus - noncityNT - bus - cityNT - bus - noncityNSW - oth - cityNSW - oth - noncityVIC - oth - cityVIC - oth - noncityQLD - oth - cityQLD - oth - noncitySA - oth - citySA - oth - noncityWA - oth - cityWA - oth - noncityTAS - oth - cityTAS - oth - noncityNT - oth - cityNT - oth - noncity
timestamp
2006-01-0130961449325317881468843908882201138320666191483101862709668925653428300322871324869101976260274828912011684116498411119823884565328741161071368039651018128643127124473168377624358
2006-02-0114799548143945862320399052114141059139540968920129721845645185222551957294580663975060325726616834920202281101481177614483464033561687832901381706575812293236691701422211709936616939
2006-03-011609730114883572475869754761093110122971273316197452225505218821929261928701078375953734130261390841975211811537911079230039036044011201961074521084540893128318270116439731538011663223150338
2006-04-0115209138190635753328478157116991128243337194916425029185385220828822097234456864199971513725724421815001963124550811281752255635539125270228243160745115727033621453519426041011394843172453
2006-05-01195814194251784414930511787321501560272752315906215131547232298831642703293388779813966303474371531251196215195057211921559386280582441130205194189426558265293458557147331622877601547
\n", + "
" + ], + "text/plain": [ + " NSW - hol - city NSW - hol - noncity VIC - hol - city \\\n", + "timestamp \n", + "2006-01-01 3096 14493 2531 \n", + "2006-02-01 1479 9548 1439 \n", + "2006-03-01 1609 7301 1488 \n", + "2006-04-01 1520 9138 1906 \n", + "2006-05-01 1958 14194 2517 \n", + "\n", + " VIC - hol - noncity QLD - hol - city QLD - hol - noncity \\\n", + "timestamp \n", + "2006-01-01 7881 4688 4390 \n", + "2006-02-01 4586 2320 3990 \n", + "2006-03-01 3572 4758 6975 \n", + "2006-04-01 3575 3328 4781 \n", + "2006-05-01 8441 4930 5117 \n", + "\n", + " SA - hol - city SA - hol - noncity WA - hol - city \\\n", + "timestamp \n", + "2006-01-01 888 2201 1383 \n", + "2006-02-01 521 1414 1059 \n", + "2006-03-01 476 1093 1101 \n", + "2006-04-01 571 1699 1128 \n", + "2006-05-01 873 2150 1560 \n", + "\n", + " WA - hol - noncity TAS - hol - city TAS - hol - noncity \\\n", + "timestamp \n", + "2006-01-01 2066 619 1483 \n", + "2006-02-01 1395 409 689 \n", + "2006-03-01 2297 127 331 \n", + "2006-04-01 2433 371 949 \n", + "2006-05-01 2727 523 1590 \n", + "\n", + " NT - hol - city NT - hol - noncity NSW - vfr - city \\\n", + "timestamp \n", + "2006-01-01 101 86 2709 \n", + "2006-02-01 201 297 2184 \n", + "2006-03-01 619 745 2225 \n", + "2006-04-01 164 250 2918 \n", + "2006-05-01 62 151 3154 \n", + "\n", + " NSW - vfr - noncity VIC - vfr - city VIC - vfr - noncity \\\n", + "timestamp \n", + "2006-01-01 6689 2565 3428 \n", + "2006-02-01 5645 1852 2255 \n", + "2006-03-01 5052 1882 1929 \n", + "2006-04-01 5385 2208 2882 \n", + "2006-05-01 7232 2988 3164 \n", + "\n", + " QLD - vfr - city QLD - vfr - noncity SA - vfr - city \\\n", + "timestamp \n", + "2006-01-01 3003 2287 1324 \n", + "2006-02-01 1957 2945 806 \n", + "2006-03-01 2619 2870 1078 \n", + "2006-04-01 2097 2344 568 \n", + "2006-05-01 2703 2933 887 \n", + "\n", + " SA - vfr - noncity WA - vfr - city WA - vfr - noncity \\\n", + "timestamp \n", + "2006-01-01 869 1019 762 \n", + "2006-02-01 639 750 603 \n", + "2006-03-01 375 953 734 \n", + "2006-04-01 641 999 715 \n", + "2006-05-01 798 1396 630 \n", + "\n", + " TAS - vfr - city TAS - vfr - noncity NT - vfr - city \\\n", + "timestamp \n", + "2006-01-01 602 748 28 \n", + "2006-02-01 257 266 168 \n", + "2006-03-01 130 261 390 \n", + "2006-04-01 137 257 244 \n", + "2006-05-01 347 437 153 \n", + "\n", + " NT - vfr - noncity NSW - bus - city NSW - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 9 1201 1684 \n", + "2006-02-01 349 2020 2281 \n", + "2006-03-01 84 1975 2118 \n", + "2006-04-01 218 1500 1963 \n", + "2006-05-01 125 1196 2151 \n", + "\n", + " VIC - bus - city VIC - bus - noncity QLD - bus - city \\\n", + "timestamp \n", + "2006-01-01 1164 984 1111 \n", + "2006-02-01 1014 811 776 \n", + "2006-03-01 1153 791 1079 \n", + "2006-04-01 1245 508 1128 \n", + "2006-05-01 950 572 1192 \n", + "\n", + " QLD - bus - noncity SA - bus - city SA - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 982 388 456 \n", + "2006-02-01 1448 346 403 \n", + "2006-03-01 2300 390 360 \n", + "2006-04-01 1752 255 635 \n", + "2006-05-01 1559 386 280 \n", + "\n", + " WA - bus - city WA - bus - noncity TAS - bus - city \\\n", + "timestamp \n", + "2006-01-01 532 874 116 \n", + "2006-02-01 356 1687 83 \n", + "2006-03-01 440 1120 196 \n", + "2006-04-01 539 1252 70 \n", + "2006-05-01 582 441 130 \n", + "\n", + " TAS - bus - noncity NT - bus - city NT - bus - noncity \\\n", + "timestamp \n", + "2006-01-01 107 136 80 \n", + "2006-02-01 290 138 170 \n", + "2006-03-01 107 452 1084 \n", + "2006-04-01 228 243 160 \n", + "2006-05-01 205 194 
189 \n", + "\n", + " NSW - oth - city NSW - oth - noncity VIC - oth - city \\\n", + "timestamp \n", + "2006-01-01 396 510 181 \n", + "2006-02-01 657 581 229 \n", + "2006-03-01 540 893 128 \n", + "2006-04-01 745 1157 270 \n", + "2006-05-01 426 558 265 \n", + "\n", + " VIC - oth - noncity QLD - oth - city QLD - oth - noncity \\\n", + "timestamp \n", + "2006-01-01 286 431 271 \n", + "2006-02-01 323 669 170 \n", + "2006-03-01 318 270 1164 \n", + "2006-04-01 336 214 535 \n", + "2006-05-01 293 458 557 \n", + "\n", + " SA - oth - city SA - oth - noncity WA - oth - city \\\n", + "timestamp \n", + "2006-01-01 244 73 168 \n", + "2006-02-01 142 221 170 \n", + "2006-03-01 397 315 380 \n", + "2006-04-01 194 260 410 \n", + "2006-05-01 147 33 162 \n", + "\n", + " WA - oth - noncity TAS - oth - city TAS - oth - noncity \\\n", + "timestamp \n", + "2006-01-01 37 76 24 \n", + "2006-02-01 99 36 61 \n", + "2006-03-01 1166 32 23 \n", + "2006-04-01 1139 48 43 \n", + "2006-05-01 28 77 60 \n", + "\n", + " NT - oth - city NT - oth - noncity \n", + "timestamp \n", + "2006-01-01 35 8 \n", + "2006-02-01 69 39 \n", + "2006-03-01 150 338 \n", + "2006-04-01 172 453 \n", + "2006-05-01 15 47 " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "city_segments = list(filter(lambda name: name.count(\"-\") == 2, df.columns))\n", + "\n", + "df[city_segments].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2.1 Prepare data in ETNA hierarchical long format \n", + "Before trying to detect a hierarchical structure, data must be transformed to hierarchical long format. In this format,\n", + "your `DataFrame` must contain `timestamp`, `target` and level columns. Each level column represents membership of the\n", + "observation at higher levels of the hierarchy." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestamptargetcity_levelregion_levelreason_level
02006-01-013096cityNSWHol
12006-02-011479cityNSWHol
22006-03-011609cityNSWHol
32006-04-011520cityNSWHol
42006-05-011958cityNSWHol
\n", + "
" + ], + "text/plain": [ + " timestamp target city_level region_level reason_level\n", + "0 2006-01-01 3096 city NSW Hol\n", + "1 2006-02-01 1479 city NSW Hol\n", + "2 2006-03-01 1609 city NSW Hol\n", + "3 2006-04-01 1520 city NSW Hol\n", + "4 2006-05-01 1958 city NSW Hol" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_df = []\n", + "for segment_name in city_segments:\n", + " segment = df[segment_name]\n", + " region, reason, city = segment_name.split(\" - \")\n", + "\n", + " seg_df = pd.DataFrame(\n", + " data={\n", + " \"timestamp\": segment.index,\n", + " \"target\": segment.values,\n", + " \"city_level\": [city] * periods,\n", + " \"region_level\": [region] * periods,\n", + " \"reason_level\": [reason] * periods,\n", + " },\n", + " )\n", + " hierarchical_df.append(seg_df)\n", + "\n", + "hierarchical_df = pd.concat(hierarchical_df, axis=0)\n", + "\n", + "hierarchical_df[\"reason_level\"].replace({\"hol\": \"Hol\", \"vfr\": \"VFR\", \"bus\": \"Bus\", \"oth\": \"Oth\"}, inplace=True)\n", + "hierarchical_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we can omit total level as it will be added automatically in hierarchy detection." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2.2 Convert data to etna wide format with `to_hierarchical_dataset`\n", + "To detect hierarchical structure and convert `DataFrame` to ETNA wide format, call `TSDataset.to_hierarchical_dataset`,\n", + "provided with prepared data and level columns names in order from top to bottom." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBus_NSW_cityBus_NSW_noncityBus_NT_cityBus_NT_noncityBus_QLD_cityBus_QLD_noncityBus_SA_cityBus_SA_noncityBus_TAS_cityBus_TAS_noncityBus_VIC_cityBus_VIC_noncityBus_WA_cityBus_WA_noncityHol_NSW_cityHol_NSW_noncityHol_NT_cityHol_NT_noncityHol_QLD_cityHol_QLD_noncityHol_SA_cityHol_SA_noncityHol_TAS_cityHol_TAS_noncityHol_VIC_cityHol_VIC_noncityHol_WA_cityHol_WA_noncityOth_NSW_cityOth_NSW_noncityOth_NT_cityOth_NT_noncityOth_QLD_cityOth_QLD_noncityOth_SA_cityOth_SA_noncityOth_TAS_cityOth_TAS_noncityOth_VIC_cityOth_VIC_noncityOth_WA_cityOth_WA_noncityVFR_NSW_cityVFR_NSW_noncityVFR_NT_cityVFR_NT_noncityVFR_QLD_cityVFR_QLD_noncityVFR_SA_cityVFR_SA_noncityVFR_TAS_cityVFR_TAS_noncityVFR_VIC_cityVFR_VIC_noncityVFR_WA_cityVFR_WA_noncity
featuretargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettarget
timestamp
2006-01-0112011684136801111982388456116107116498453287430961449310186468843908882201619148325317881138320663965103584312712447376241812861683727096689289300322871324869602748256534281019762
2006-02-0120202281138170776144834640383290101481135616871479954820129723203990521141440968914394586105913956575816939669170142221366122932317099218456451683491957294580663925726618522255750603
2006-03-011975211845210841079230039036019610711537914401120160973016197454758697547610931273311488357211012297540893150338270116439731532231283183801166222550523908426192870107837513026118821929953734
2006-04-0115001963243160112817522556357022812455085391252152091381642503328478157116993719491906357511282433745115717245321453519426048432703364101139291853852442182097234456864113725722082882999715
2006-05-01119621511941891192155938628013020595057258244119581419462151493051178732150523159025178441156027274265581547458557147337760265293162283154723215312527032933887798347437298831641396630
\n", + "
" + ], + "text/plain": [ + "segment Bus_NSW_city Bus_NSW_noncity Bus_NT_city Bus_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1201 1684 136 80 \n", + "2006-02-01 2020 2281 138 170 \n", + "2006-03-01 1975 2118 452 1084 \n", + "2006-04-01 1500 1963 243 160 \n", + "2006-05-01 1196 2151 194 189 \n", + "\n", + "segment Bus_QLD_city Bus_QLD_noncity Bus_SA_city Bus_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1111 982 388 456 \n", + "2006-02-01 776 1448 346 403 \n", + "2006-03-01 1079 2300 390 360 \n", + "2006-04-01 1128 1752 255 635 \n", + "2006-05-01 1192 1559 386 280 \n", + "\n", + "segment Bus_TAS_city Bus_TAS_noncity Bus_VIC_city Bus_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 116 107 1164 984 \n", + "2006-02-01 83 290 1014 811 \n", + "2006-03-01 196 107 1153 791 \n", + "2006-04-01 70 228 1245 508 \n", + "2006-05-01 130 205 950 572 \n", + "\n", + "segment Bus_WA_city Bus_WA_noncity Hol_NSW_city Hol_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 532 874 3096 14493 \n", + "2006-02-01 356 1687 1479 9548 \n", + "2006-03-01 440 1120 1609 7301 \n", + "2006-04-01 539 1252 1520 9138 \n", + "2006-05-01 582 441 1958 14194 \n", + "\n", + "segment Hol_NT_city Hol_NT_noncity Hol_QLD_city Hol_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 101 86 4688 4390 \n", + "2006-02-01 201 297 2320 3990 \n", + "2006-03-01 619 745 4758 6975 \n", + "2006-04-01 164 250 3328 4781 \n", + "2006-05-01 62 151 4930 5117 \n", + "\n", + "segment Hol_SA_city Hol_SA_noncity Hol_TAS_city Hol_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 888 2201 619 1483 \n", + "2006-02-01 521 1414 409 689 \n", + "2006-03-01 476 1093 127 331 \n", + "2006-04-01 571 1699 371 949 \n", + "2006-05-01 873 2150 523 1590 \n", + "\n", + "segment Hol_VIC_city Hol_VIC_noncity Hol_WA_city Hol_WA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2531 7881 1383 2066 \n", + "2006-02-01 1439 4586 1059 1395 \n", + "2006-03-01 1488 3572 1101 2297 \n", + "2006-04-01 1906 3575 1128 2433 \n", + "2006-05-01 2517 8441 1560 2727 \n", + "\n", + "segment Oth_NSW_city Oth_NSW_noncity Oth_NT_city Oth_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 396 510 35 8 \n", + "2006-02-01 657 581 69 39 \n", + "2006-03-01 540 893 150 338 \n", + "2006-04-01 745 1157 172 453 \n", + "2006-05-01 426 558 15 47 \n", + "\n", + "segment Oth_QLD_city Oth_QLD_noncity Oth_SA_city Oth_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 431 271 244 73 \n", + "2006-02-01 669 170 142 221 \n", + "2006-03-01 270 1164 397 315 \n", + "2006-04-01 214 535 194 260 \n", + "2006-05-01 458 557 147 33 \n", + "\n", + "segment Oth_TAS_city Oth_TAS_noncity Oth_VIC_city Oth_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 76 24 181 286 \n", + "2006-02-01 36 61 229 323 \n", + "2006-03-01 32 23 128 318 \n", + "2006-04-01 48 43 270 336 \n", + "2006-05-01 77 60 265 293 \n", + "\n", + "segment Oth_WA_city Oth_WA_noncity VFR_NSW_city VFR_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 168 37 2709 6689 \n", + "2006-02-01 170 99 2184 5645 \n", + "2006-03-01 380 1166 2225 5052 \n", + "2006-04-01 410 1139 2918 5385 
\n", + "2006-05-01 162 28 3154 7232 \n", + "\n", + "segment VFR_NT_city VFR_NT_noncity VFR_QLD_city VFR_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 28 9 3003 2287 \n", + "2006-02-01 168 349 1957 2945 \n", + "2006-03-01 390 84 2619 2870 \n", + "2006-04-01 244 218 2097 2344 \n", + "2006-05-01 153 125 2703 2933 \n", + "\n", + "segment VFR_SA_city VFR_SA_noncity VFR_TAS_city VFR_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1324 869 602 748 \n", + "2006-02-01 806 639 257 266 \n", + "2006-03-01 1078 375 130 261 \n", + "2006-04-01 568 641 137 257 \n", + "2006-05-01 887 798 347 437 \n", + "\n", + "segment VFR_VIC_city VFR_VIC_noncity VFR_WA_city VFR_WA_noncity \n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2565 3428 1019 762 \n", + "2006-02-01 1852 2255 750 603 \n", + "2006-03-01 1882 1929 953 734 \n", + "2006-04-01 2208 2882 999 715 \n", + "2006-05-01 2988 3164 1396 630 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_df, hierarchical_structure = TSDataset.to_hierarchical_dataset(\n", + " df=hierarchical_df, level_columns=[\"reason_level\", \"region_level\", \"city_level\"]\n", + ")\n", + "\n", + "hierarchical_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "HierarchicalStructure(level_structure = {'total': ['Hol', 'VFR', 'Bus', 'Oth'], 'Bus': ['Bus_NSW', 'Bus_VIC', 'Bus_QLD', 'Bus_SA', 'Bus_WA', 'Bus_TAS', 'Bus_NT'], 'Hol': ['Hol_NSW', 'Hol_VIC', 'Hol_QLD', 'Hol_SA', 'Hol_WA', 'Hol_TAS', 'Hol_NT'], 'Oth': ['Oth_NSW', 'Oth_VIC', 'Oth_QLD', 'Oth_SA', 'Oth_WA', 'Oth_TAS', 'Oth_NT'], 'VFR': ['VFR_NSW', 'VFR_VIC', 'VFR_QLD', 'VFR_SA', 'VFR_WA', 'VFR_TAS', 'VFR_NT'], 'Bus_NSW': ['Bus_NSW_city', 'Bus_NSW_noncity'], 'Bus_NT': ['Bus_NT_city', 'Bus_NT_noncity'], 'Bus_QLD': ['Bus_QLD_city', 'Bus_QLD_noncity'], 'Bus_SA': ['Bus_SA_city', 'Bus_SA_noncity'], 'Bus_TAS': ['Bus_TAS_city', 'Bus_TAS_noncity'], 'Bus_VIC': ['Bus_VIC_city', 'Bus_VIC_noncity'], 'Bus_WA': ['Bus_WA_city', 'Bus_WA_noncity'], 'Hol_NSW': ['Hol_NSW_city', 'Hol_NSW_noncity'], 'Hol_NT': ['Hol_NT_city', 'Hol_NT_noncity'], 'Hol_QLD': ['Hol_QLD_city', 'Hol_QLD_noncity'], 'Hol_SA': ['Hol_SA_city', 'Hol_SA_noncity'], 'Hol_TAS': ['Hol_TAS_city', 'Hol_TAS_noncity'], 'Hol_VIC': ['Hol_VIC_city', 'Hol_VIC_noncity'], 'Hol_WA': ['Hol_WA_city', 'Hol_WA_noncity'], 'Oth_NSW': ['Oth_NSW_city', 'Oth_NSW_noncity'], 'Oth_NT': ['Oth_NT_city', 'Oth_NT_noncity'], 'Oth_QLD': ['Oth_QLD_city', 'Oth_QLD_noncity'], 'Oth_SA': ['Oth_SA_city', 'Oth_SA_noncity'], 'Oth_TAS': ['Oth_TAS_city', 'Oth_TAS_noncity'], 'Oth_VIC': ['Oth_VIC_city', 'Oth_VIC_noncity'], 'Oth_WA': ['Oth_WA_city', 'Oth_WA_noncity'], 'VFR_NSW': ['VFR_NSW_city', 'VFR_NSW_noncity'], 'VFR_NT': ['VFR_NT_city', 'VFR_NT_noncity'], 'VFR_QLD': ['VFR_QLD_city', 'VFR_QLD_noncity'], 'VFR_SA': ['VFR_SA_city', 'VFR_SA_noncity'], 'VFR_TAS': ['VFR_TAS_city', 'VFR_TAS_noncity'], 'VFR_VIC': ['VFR_VIC_city', 'VFR_VIC_noncity'], 'VFR_WA': ['VFR_WA_city', 'VFR_WA_noncity']}, level_names = ['total', 'reason_level', 'region_level', 'city_level'], )" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we see that `hierarchical_structure` has a mapping between higher 
level segments and adjacent lower level segments."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2.2.3 Create the hierarchical dataset\n",
+    "To convert the data to a `TSDataset`, call the constructor and pass the detected `hierarchical_structure`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBus_NSW_cityBus_NSW_noncityBus_NT_cityBus_NT_noncityBus_QLD_cityBus_QLD_noncityBus_SA_cityBus_SA_noncityBus_TAS_cityBus_TAS_noncityBus_VIC_cityBus_VIC_noncityBus_WA_cityBus_WA_noncityHol_NSW_cityHol_NSW_noncityHol_NT_cityHol_NT_noncityHol_QLD_cityHol_QLD_noncityHol_SA_cityHol_SA_noncityHol_TAS_cityHol_TAS_noncityHol_VIC_cityHol_VIC_noncityHol_WA_cityHol_WA_noncityOth_NSW_cityOth_NSW_noncityOth_NT_cityOth_NT_noncityOth_QLD_cityOth_QLD_noncityOth_SA_cityOth_SA_noncityOth_TAS_cityOth_TAS_noncityOth_VIC_cityOth_VIC_noncityOth_WA_cityOth_WA_noncityVFR_NSW_cityVFR_NSW_noncityVFR_NT_cityVFR_NT_noncityVFR_QLD_cityVFR_QLD_noncityVFR_SA_cityVFR_SA_noncityVFR_TAS_cityVFR_TAS_noncityVFR_VIC_cityVFR_VIC_noncityVFR_WA_cityVFR_WA_noncity
featuretargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettarget
timestamp
2006-01-0112011684136801111982388456116107116498453287430961449310186468843908882201619148325317881138320663965103584312712447376241812861683727096689289300322871324869602748256534281019762
2006-02-0120202281138170776144834640383290101481135616871479954820129723203990521141440968914394586105913956575816939669170142221366122932317099218456451683491957294580663925726618522255750603
2006-03-011975211845210841079230039036019610711537914401120160973016197454758697547610931273311488357211012297540893150338270116439731532231283183801166222550523908426192870107837513026118821929953734
2006-04-0115001963243160112817522556357022812455085391252152091381642503328478157116993719491906357511282433745115717245321453519426048432703364101139291853852442182097234456864113725722082882999715
2006-05-01119621511941891192155938628013020595057258244119581419462151493051178732150523159025178441156027274265581547458557147337760265293162283154723215312527032933887798347437298831641396630
\n", + "
" + ], + "text/plain": [ + "segment Bus_NSW_city Bus_NSW_noncity Bus_NT_city Bus_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1201 1684 136 80 \n", + "2006-02-01 2020 2281 138 170 \n", + "2006-03-01 1975 2118 452 1084 \n", + "2006-04-01 1500 1963 243 160 \n", + "2006-05-01 1196 2151 194 189 \n", + "\n", + "segment Bus_QLD_city Bus_QLD_noncity Bus_SA_city Bus_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1111 982 388 456 \n", + "2006-02-01 776 1448 346 403 \n", + "2006-03-01 1079 2300 390 360 \n", + "2006-04-01 1128 1752 255 635 \n", + "2006-05-01 1192 1559 386 280 \n", + "\n", + "segment Bus_TAS_city Bus_TAS_noncity Bus_VIC_city Bus_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 116 107 1164 984 \n", + "2006-02-01 83 290 1014 811 \n", + "2006-03-01 196 107 1153 791 \n", + "2006-04-01 70 228 1245 508 \n", + "2006-05-01 130 205 950 572 \n", + "\n", + "segment Bus_WA_city Bus_WA_noncity Hol_NSW_city Hol_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 532 874 3096 14493 \n", + "2006-02-01 356 1687 1479 9548 \n", + "2006-03-01 440 1120 1609 7301 \n", + "2006-04-01 539 1252 1520 9138 \n", + "2006-05-01 582 441 1958 14194 \n", + "\n", + "segment Hol_NT_city Hol_NT_noncity Hol_QLD_city Hol_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 101 86 4688 4390 \n", + "2006-02-01 201 297 2320 3990 \n", + "2006-03-01 619 745 4758 6975 \n", + "2006-04-01 164 250 3328 4781 \n", + "2006-05-01 62 151 4930 5117 \n", + "\n", + "segment Hol_SA_city Hol_SA_noncity Hol_TAS_city Hol_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 888 2201 619 1483 \n", + "2006-02-01 521 1414 409 689 \n", + "2006-03-01 476 1093 127 331 \n", + "2006-04-01 571 1699 371 949 \n", + "2006-05-01 873 2150 523 1590 \n", + "\n", + "segment Hol_VIC_city Hol_VIC_noncity Hol_WA_city Hol_WA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2531 7881 1383 2066 \n", + "2006-02-01 1439 4586 1059 1395 \n", + "2006-03-01 1488 3572 1101 2297 \n", + "2006-04-01 1906 3575 1128 2433 \n", + "2006-05-01 2517 8441 1560 2727 \n", + "\n", + "segment Oth_NSW_city Oth_NSW_noncity Oth_NT_city Oth_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 396 510 35 8 \n", + "2006-02-01 657 581 69 39 \n", + "2006-03-01 540 893 150 338 \n", + "2006-04-01 745 1157 172 453 \n", + "2006-05-01 426 558 15 47 \n", + "\n", + "segment Oth_QLD_city Oth_QLD_noncity Oth_SA_city Oth_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 431 271 244 73 \n", + "2006-02-01 669 170 142 221 \n", + "2006-03-01 270 1164 397 315 \n", + "2006-04-01 214 535 194 260 \n", + "2006-05-01 458 557 147 33 \n", + "\n", + "segment Oth_TAS_city Oth_TAS_noncity Oth_VIC_city Oth_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 76 24 181 286 \n", + "2006-02-01 36 61 229 323 \n", + "2006-03-01 32 23 128 318 \n", + "2006-04-01 48 43 270 336 \n", + "2006-05-01 77 60 265 293 \n", + "\n", + "segment Oth_WA_city Oth_WA_noncity VFR_NSW_city VFR_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 168 37 2709 6689 \n", + "2006-02-01 170 99 2184 5645 \n", + "2006-03-01 380 1166 2225 5052 \n", + "2006-04-01 410 1139 2918 5385 
\n", + "2006-05-01 162 28 3154 7232 \n", + "\n", + "segment VFR_NT_city VFR_NT_noncity VFR_QLD_city VFR_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 28 9 3003 2287 \n", + "2006-02-01 168 349 1957 2945 \n", + "2006-03-01 390 84 2619 2870 \n", + "2006-04-01 244 218 2097 2344 \n", + "2006-05-01 153 125 2703 2933 \n", + "\n", + "segment VFR_SA_city VFR_SA_noncity VFR_TAS_city VFR_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1324 869 602 748 \n", + "2006-02-01 806 639 257 266 \n", + "2006-03-01 1078 375 130 261 \n", + "2006-04-01 568 641 137 257 \n", + "2006-05-01 887 798 347 437 \n", + "\n", + "segment VFR_VIC_city VFR_VIC_noncity VFR_WA_city VFR_WA_noncity \n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2565 3428 1019 762 \n", + "2006-02-01 1852 2255 750 603 \n", + "2006-03-01 1882 1929 953 734 \n", + "2006-04-01 2208 2882 999 715 \n", + "2006-05-01 2988 3164 1396 630 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts = TSDataset(df=hierarchical_df, freq=\"MS\", hierarchical_structure=hierarchical_structure)\n", + "hierarchical_ts.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the dataset converted to hierarchical. We can examine what hierarchical levels were detected with the following code." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['total', 'reason_level', 'region_level', 'city_level']" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts.hierarchical_structure.level_names" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ensure that the dataset is at the desired level." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'city_level'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts.current_df_level" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The hierarchical dataset could be aggregated to higher levels with the `get_level_dataset` method." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentHolVFRBusOth
featuretargettargettargettarget
timestamp
2006-01-01459062604298152740
2006-02-012934720676118233466
2006-03-013249220582135656114
2006-04-013181321613114785976
2006-05-014679326947100273126
\n", + "
" + ], + "text/plain": [ + "segment Hol VFR Bus Oth\n", + "feature target target target target\n", + "timestamp \n", + "2006-01-01 45906 26042 9815 2740\n", + "2006-02-01 29347 20676 11823 3466\n", + "2006-03-01 32492 20582 13565 6114\n", + "2006-04-01 31813 21613 11478 5976\n", + "2006-05-01 46793 26947 10027 3126" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts.get_level_dataset(target_level=\"reason_level\").head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Reconciliation methods " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section we will examine the reconciliation methods available in ETNA.\n", + "Hierarchical time series reconciliation allows for the readjustment of predictions produced by separate models on\n", + "a set of hierarchically linked time series in order to make the forecasts more accurate, and ensure that they are coherent.\n", + "\n", + "There are several reconciliation methods in the ETNA library:\n", + "* Bottom-up approach\n", + "* Top-down approach\n", + "\n", + "Middle-out reconciliation approach could be viewed as a composition of bottom-up and top-down approaches. This method could\n", + "be implemented using functionality from the library. For aggregation to higher levels, one could use provided bottom-up methods,\n", + "and for disaggregation to lower levels -- top-down methods.\n", + "\n", + "Reconciliation methods estimate mapping matrices to perform transitions between levels. These matrices are sparse.\n", + "ETNA uses `scipy.sparse.csr_matrix` to efficiently store them and perform computation.\n", + "\n", + "More detailed information about this and other reconciliation methods can be found [here](https://otexts.com/fpp2/hierarchical.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "from etna.transforms import LagTransform\n", + "from etna.transforms import OneHotEncoderTransform\n", + "from etna.models import LinearPerSegmentModel\n", + "from etna.metrics import SMAPE\n", + "from etna.pipeline import HierarchicalPipeline\n", + "from etna.pipeline import Pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some important definitions:\n", + "* **source level** - level of the hierarchy for model estimation, reconciliation applied to this level\n", + "* **target level** - desired level of the hierarchy, after reconciliation we want to have series at this level." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1. Bottom-up approach " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The main idea of this approach is to produce forecasts for time series at lower hierarchical levels and then perform\n", + "aggregation to higher levels.\n", + "\n", + "$$\n", + "\\hat y_{A,h} = \\hat y_{AA,h} + \\hat y_{AB,h}\n", + "$$\n", + "\n", + "$$\n", + "\\hat y_{B,h} = \\hat y_{BA,h} + \\hat y_{BB,h} + \\hat y_{BC,h}\n", + "$$\n", + "\n", + "where $h$ - forecast horizon." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In matrix notation:\n",
+    "\n",
+    "\\begin{equation*}\n",
+    "    \\begin{bmatrix}\n",
+    "    \\hat y_{A,h} \\\\ \\hat y_{B,h}\n",
+    "    \\end{bmatrix}\n",
+    "    =\n",
+    "    \\begin{bmatrix}\n",
+    "    1 & 1 & 0 & 0 & 0 \\\\\n",
+    "    0 & 0 & 1 & 1 & 1\n",
+    "    \\end{bmatrix}\n",
+    "    \\begin{bmatrix}\n",
+    "    \\hat y_{AA,h} \\\\ \\hat y_{AB,h} \\\\ \\hat y_{BA,h} \\\\ \\hat y_{BB,h} \\\\ \\hat y_{BC,h}\n",
+    "    \\end{bmatrix}\n",
+    "\\end{equation*}\n",
+    "\n",
+    "An advantage of this approach is that we are forecasting at the bottom level of the structure and are able to capture\n",
+    "all the patterns of the individual series. On the other hand, bottom-level data can be quite noisy and more challenging\n",
+    "to model and forecast."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from etna.reconciliation import BottomUpReconciliator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To create a `BottomUpReconciliator`, specify the \"source\" and \"target\" levels for aggregation. Make sure that the source\n",
+    "level is lower in the hierarchy than the target level."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "reconciliator = BottomUpReconciliator(target_level=\"region_level\", source_level=\"city_level\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The next step is mapping matrix estimation. To do so, pass the hierarchical dataset to the `fit` method. The current dataset level\n",
+    "should be lower than or equal to the source level."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[1, 1, 0, ..., 0, 0, 0],\n",
+       " [0, 0, 1, ..., 0, 0, 0],\n",
+       " [0, 0, 0, ..., 0, 0, 0],\n",
+       " ...,\n",
+       " [0, 0, 0, ..., 0, 0, 0],\n",
+       " [0, 0, 0, ..., 1, 0, 0],\n",
+       " [0, 0, 0, ..., 0, 1, 1]])"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "reconciliator.fit(ts=hierarchical_ts)\n",
+    "reconciliator.mapping_matrix.toarray()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now `reconciliator` is ready to perform aggregation to the target level."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBus_NSW_cityBus_NSW_noncityBus_NT_cityBus_NT_noncityBus_QLD_cityBus_QLD_noncityBus_SA_cityBus_SA_noncityBus_TAS_cityBus_TAS_noncityBus_VIC_cityBus_VIC_noncityBus_WA_cityBus_WA_noncityHol_NSW_cityHol_NSW_noncityHol_NT_cityHol_NT_noncityHol_QLD_cityHol_QLD_noncityHol_SA_cityHol_SA_noncityHol_TAS_cityHol_TAS_noncityHol_VIC_cityHol_VIC_noncityHol_WA_cityHol_WA_noncityOth_NSW_cityOth_NSW_noncityOth_NT_cityOth_NT_noncityOth_QLD_cityOth_QLD_noncityOth_SA_cityOth_SA_noncityOth_TAS_cityOth_TAS_noncityOth_VIC_cityOth_VIC_noncityOth_WA_cityOth_WA_noncityVFR_NSW_cityVFR_NSW_noncityVFR_NT_cityVFR_NT_noncityVFR_QLD_cityVFR_QLD_noncityVFR_SA_cityVFR_SA_noncityVFR_TAS_cityVFR_TAS_noncityVFR_VIC_cityVFR_VIC_noncityVFR_WA_cityVFR_WA_noncity
featuretargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettarget
timestamp
2006-01-0112011684136801111982388456116107116498453287430961449310186468843908882201619148325317881138320663965103584312712447376241812861683727096689289300322871324869602748256534281019762
2006-02-0120202281138170776144834640383290101481135616871479954820129723203990521141440968914394586105913956575816939669170142221366122932317099218456451683491957294580663925726618522255750603
2006-03-011975211845210841079230039036019610711537914401120160973016197454758697547610931273311488357211012297540893150338270116439731532231283183801166222550523908426192870107837513026118821929953734
2006-04-0115001963243160112817522556357022812455085391252152091381642503328478157116993719491906357511282433745115717245321453519426048432703364101139291853852442182097234456864113725722082882999715
2006-05-01119621511941891192155938628013020595057258244119581419462151493051178732150523159025178441156027274265581547458557147337760265293162283154723215312527032933887798347437298831641396630
\n", + "
" + ], + "text/plain": [ + "segment Bus_NSW_city Bus_NSW_noncity Bus_NT_city Bus_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1201 1684 136 80 \n", + "2006-02-01 2020 2281 138 170 \n", + "2006-03-01 1975 2118 452 1084 \n", + "2006-04-01 1500 1963 243 160 \n", + "2006-05-01 1196 2151 194 189 \n", + "\n", + "segment Bus_QLD_city Bus_QLD_noncity Bus_SA_city Bus_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1111 982 388 456 \n", + "2006-02-01 776 1448 346 403 \n", + "2006-03-01 1079 2300 390 360 \n", + "2006-04-01 1128 1752 255 635 \n", + "2006-05-01 1192 1559 386 280 \n", + "\n", + "segment Bus_TAS_city Bus_TAS_noncity Bus_VIC_city Bus_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 116 107 1164 984 \n", + "2006-02-01 83 290 1014 811 \n", + "2006-03-01 196 107 1153 791 \n", + "2006-04-01 70 228 1245 508 \n", + "2006-05-01 130 205 950 572 \n", + "\n", + "segment Bus_WA_city Bus_WA_noncity Hol_NSW_city Hol_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 532 874 3096 14493 \n", + "2006-02-01 356 1687 1479 9548 \n", + "2006-03-01 440 1120 1609 7301 \n", + "2006-04-01 539 1252 1520 9138 \n", + "2006-05-01 582 441 1958 14194 \n", + "\n", + "segment Hol_NT_city Hol_NT_noncity Hol_QLD_city Hol_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 101 86 4688 4390 \n", + "2006-02-01 201 297 2320 3990 \n", + "2006-03-01 619 745 4758 6975 \n", + "2006-04-01 164 250 3328 4781 \n", + "2006-05-01 62 151 4930 5117 \n", + "\n", + "segment Hol_SA_city Hol_SA_noncity Hol_TAS_city Hol_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 888 2201 619 1483 \n", + "2006-02-01 521 1414 409 689 \n", + "2006-03-01 476 1093 127 331 \n", + "2006-04-01 571 1699 371 949 \n", + "2006-05-01 873 2150 523 1590 \n", + "\n", + "segment Hol_VIC_city Hol_VIC_noncity Hol_WA_city Hol_WA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2531 7881 1383 2066 \n", + "2006-02-01 1439 4586 1059 1395 \n", + "2006-03-01 1488 3572 1101 2297 \n", + "2006-04-01 1906 3575 1128 2433 \n", + "2006-05-01 2517 8441 1560 2727 \n", + "\n", + "segment Oth_NSW_city Oth_NSW_noncity Oth_NT_city Oth_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 396 510 35 8 \n", + "2006-02-01 657 581 69 39 \n", + "2006-03-01 540 893 150 338 \n", + "2006-04-01 745 1157 172 453 \n", + "2006-05-01 426 558 15 47 \n", + "\n", + "segment Oth_QLD_city Oth_QLD_noncity Oth_SA_city Oth_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 431 271 244 73 \n", + "2006-02-01 669 170 142 221 \n", + "2006-03-01 270 1164 397 315 \n", + "2006-04-01 214 535 194 260 \n", + "2006-05-01 458 557 147 33 \n", + "\n", + "segment Oth_TAS_city Oth_TAS_noncity Oth_VIC_city Oth_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 76 24 181 286 \n", + "2006-02-01 36 61 229 323 \n", + "2006-03-01 32 23 128 318 \n", + "2006-04-01 48 43 270 336 \n", + "2006-05-01 77 60 265 293 \n", + "\n", + "segment Oth_WA_city Oth_WA_noncity VFR_NSW_city VFR_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 168 37 2709 6689 \n", + "2006-02-01 170 99 2184 5645 \n", + "2006-03-01 380 1166 2225 5052 \n", + "2006-04-01 410 1139 2918 5385 
\n", + "2006-05-01 162 28 3154 7232 \n", + "\n", + "segment VFR_NT_city VFR_NT_noncity VFR_QLD_city VFR_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 28 9 3003 2287 \n", + "2006-02-01 168 349 1957 2945 \n", + "2006-03-01 390 84 2619 2870 \n", + "2006-04-01 244 218 2097 2344 \n", + "2006-05-01 153 125 2703 2933 \n", + "\n", + "segment VFR_SA_city VFR_SA_noncity VFR_TAS_city VFR_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1324 869 602 748 \n", + "2006-02-01 806 639 257 266 \n", + "2006-03-01 1078 375 130 261 \n", + "2006-04-01 568 641 137 257 \n", + "2006-05-01 887 798 347 437 \n", + "\n", + "segment VFR_VIC_city VFR_VIC_noncity VFR_WA_city VFR_WA_noncity \n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2565 3428 1019 762 \n", + "2006-02-01 1852 2255 750 603 \n", + "2006-03-01 1882 1929 953 734 \n", + "2006-04-01 2208 2882 999 715 \n", + "2006-05-01 2988 3164 1396 630 " + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reconciliator.aggregate(ts=hierarchical_ts).head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`HierarchicalPipeline` provides the ability to perform forecasts reconciliation in pipeline.\n", + "A couple of points to keep in mind while working with this type of pipeline:\n", + "1. Transforms and model work with the dataset on the **source** level.\n", + "2. Forecasts are automatically reconciliated to the **target** level, metrics reported for **target** level as well." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.3s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.3s finished\n" + ] + } + ], + "source": [ + "pipeline = HierarchicalPipeline(\n", + " transforms=[\n", + " LagTransform(in_column=\"target\", lags=[1, 2, 3, 4, 6, 12]),\n", + " ],\n", + " model=LinearPerSegmentModel(),\n", + " reconciliator=BottomUpReconciliator(target_level=\"region_level\", source_level=\"city_level\"),\n", + ")\n", + "\n", + "bottom_up_metrics, _, _ = pipeline.backtest(ts=hierarchical_ts, metrics=[SMAPE()], n_folds=3, aggregate_metrics=True)\n", + "bottom_up_metrics = bottom_up_metrics.set_index(\"segment\").add_suffix(\"_bottom_up\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2. Top-down approach " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Top-down approach is based on the idea of generating forecasts for time series at higher hierarchy levels and then\n", + "performing disaggregation to lower levels. 
+    "\n",
+    "\\begin{align*}\n",
+    "\\hat y_{AA,h} = p_{AA} \\hat y_{A,h}, &&\n",
+    "\\hat y_{AB,h} = p_{AB} \\hat y_{A,h}, &&\n",
+    "\\hat y_{BA,h} = p_{BA} \\hat y_{B,h}, &&\n",
+    "\\hat y_{BB,h} = p_{BB} \\hat y_{B,h}, &&\n",
+    "\\hat y_{BC,h} = p_{BC} \\hat y_{B,h}\n",
+    "\\end{align*}\n",
+    "\n",
+    "In matrix notation this equation can be rewritten as:\n",
+    "\n",
+    "\\begin{equation}\n",
+    "    \\begin{bmatrix}\n",
+    "    \\hat y_{AA,h} \\\\ \\hat y_{AB,h} \\\\ \\hat y_{BA,h} \\\\ \\hat y_{BB,h} \\\\ \\hat y_{BC,h}\n",
+    "    \\end{bmatrix}\n",
+    "    =\n",
+    "    \\begin{bmatrix}\n",
+    "    p_{AA} & 0 & 0 & 0 & 0 \\\\\n",
+    "    0 & p_{AB} & 0 & 0 & 0 \\\\\n",
+    "    0 & 0 & p_{BA} & 0 & 0 \\\\\n",
+    "    0 & 0 & 0 & p_{BB} & 0 \\\\\n",
+    "    0 & 0 & 0 & 0 & p_{BC} \\\\\n",
+    "    \\end{bmatrix}\n",
+    "    \\begin{bmatrix}\n",
+    "    1 & 0 \\\\\n",
+    "    1 & 0 \\\\\n",
+    "    0 & 1 \\\\\n",
+    "    0 & 1 \\\\\n",
+    "    0 & 1 \\\\\n",
+    "    \\end{bmatrix}\n",
+    "    \\begin{bmatrix}\n",
+    "    \\hat y_{A,h} \\\\ \\hat y_{B,h}\n",
+    "    \\end{bmatrix}\n",
+    "    =\n",
+    "    P S^T\n",
+    "    \\begin{bmatrix}\n",
+    "    \\hat y_{A,h} \\\\ \\hat y_{B,h}\n",
+    "    \\end{bmatrix}\n",
+    "\\end{equation}\n",
+    "\n",
+    "The main challenge of this approach is estimating the proportions.\n",
+    "In the ETNA library, there are two methods available:\n",
+    "* Average historical proportions (AHP)\n",
+    "* Proportions of the historical averages (PHA)\n",
+    "\n",
+    "**Average historical proportions**\n",
+    "\n",
+    "This method is based on averaging historical proportions:\n",
+    "\\begin{equation}\n",
+    "\\large p_i = \\frac{1}{n} \\sum_{t = T - n}^{T} \\frac{y_{i, t}}{y_t}\n",
+    "\\end{equation}\n",
+    "\n",
+    "where $n$ is the window size, $T$ is the latest timestamp, $y_{i, t}$ is the time series at the lower level and $y_t$ is the time series at\n",
+    "the higher level. Both $y_{i, t}$ and $y_t$ are linked by the hierarchical relationship.\n",
+    "\n",
+    "**Proportions of the historical averages**\n",
+    "\n",
+    "This approach uses a proportion of the averages for estimation:\n",
+    "\\begin{equation}\n",
+    "\\large p_i = \\sum_{t = T - n}^{T} \\frac{y_{i, t}}{n} \\Bigg / \\sum_{t = T - n}^{T} \\frac{y_t}{n}\n",
+    "\\end{equation}\n",
+    "\n",
+    "where $n$ is the window size, $T$ is the latest timestamp, $y_{i, t}$ is the time series at the lower level and $y_t$ is the time series at\n",
+    "the higher level. Both $y_{i, t}$ and $y_t$ are linked by the hierarchical relationship.\n",
+    "\n",
+    "The described methods require only the series at the higher level for forecasting. The advantages of this approach are\n",
+    "simplicity and reliability; its main disadvantage is the loss of information.\n",
+    "\n",
+    "This approach can be useful when we need to forecast lower-level series, but some of them are rather noisy.\n",
+    "Aggregating to a higher level and reconciling back helps to use more meaningful information while modeling.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from etna.reconciliation import TopDownReconciliator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`TopDownReconciliator` accepts four arguments in its constructor. You need to specify the reconciliation levels,\n",
+    "a method and a window size. First, let's look at the AHP top-down reconciliation method.\n",
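+    "\n",
+    "Before moving to the code, here is a toy sketch of how the two proportion formulas above differ on made-up numbers (plain `numpy`, not the library's internal code):\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "# made-up history over an n-step window: a child series y_i and its parent series y\n",
+    "y_child = np.array([2.0, 3.0, 5.0])\n",
+    "y_parent = np.array([10.0, 20.0, 40.0])\n",
+    "\n",
+    "p_ahp = np.mean(y_child / y_parent)           # AHP: average of the per-step shares, ~0.158\n",
+    "p_pha = np.mean(y_child) / np.mean(y_parent)  # PHA: ratio of the averages, ~0.143\n",
+    "```"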
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ahp_reconciliator = TopDownReconciliator(\n",
+    "    target_level=\"region_level\", source_level=\"reason_level\", method=\"AHP\", period=6\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The top-down approach has slightly different dataset level requirements compared to the bottom-up method.\n",
+    "Here the source level should be higher than the target level, and the current dataset level should not be higher\n",
+    "than the target level.\n",
+    "\n",
+    "After all level requirements are met and the reconciliator is fitted, we can obtain the mapping matrix. Note that the\n",
+    "mapping matrix now contains reconciliation proportions, and not only zeros and ones.\n",
+    "\n",
+    "Columns of the top-down mapping matrix correspond to segments at the source level of the hierarchy, and rows to\n",
+    "the segments at the target level."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[0.29517217, 0. , 0. , 0. ],\n",
+       " [0.17522331, 0. , 0. , 0. ],\n",
+       " [0.29906179, 0. , 0. , 0. ],\n",
+       " [0.06509802, 0. , 0. , 0. ],\n",
+       " [0.10138424, 0. , 0. , 0. ],\n",
+       " [0.0348691 , 0. , 0. , 0. ],\n",
+       " [0.02919136, 0. , 0. , 0. ],\n",
+       " [0. , 0.35663824, 0. , 0. ],\n",
+       " [0. , 0.19596791, 0. , 0. ],\n",
+       " [0. , 0.25065754, 0. , 0. ],\n",
+       " [0. , 0.06313639, 0. , 0. ],\n",
+       " [0. , 0.09261382, 0. , 0. ],\n",
+       " [0. , 0.02383924, 0. , 0. ],\n",
+       " [0. , 0.01714686, 0. , 0. ],\n",
+       " [0. , 0. , 0.29766462, 0. ],\n",
+       " [0. , 0. , 0.16667059, 0. ],\n",
+       " [0. , 0. , 0.27550314, 0. ],\n",
+       " [0. , 0. , 0.0654707 , 0. ],\n",
+       " [0. , 0. , 0.13979554, 0. ],\n",
+       " [0. , 0. , 0.0245672 , 0. ],\n",
+       " [0. , 0. , 0.03032821, 0. ],\n",
+       " [0. , 0. , 0. , 0.29191277],\n",
+       " [0. , 0. , 0. , 0.15036933],\n",
+       " [0. , 0. , 0. , 0.25667986],\n",
+       " [0. , 0. , 0. , 0.09445469],\n",
+       " [0. , 0. , 0. , 0.1319362 ],\n",
+       " [0. , 0. , 0. , 0.03209989],\n",
+       " [0. , 0. , 0. , 0.04254726]])"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ahp_reconciliator.fit(ts=hierarchical_ts)\n",
+    "ahp_reconciliator.mapping_matrix.toarray()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let’s fit a `HierarchicalPipeline` with the **AHP** method."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
+      "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.3s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.5s finished\n"
+     ]
+    }
+   ],
+   "source": [
+    "reconciliator = TopDownReconciliator(target_level=\"region_level\", source_level=\"reason_level\", method=\"AHP\", period=9)\n",
+    "\n",
+    "pipeline = HierarchicalPipeline(\n",
+    "    transforms=[\n",
+    "        LagTransform(in_column=\"target\", lags=[1, 2, 3, 4, 6, 12]),\n",
+    "    ],\n",
+    "    model=LinearPerSegmentModel(),\n",
+    "    reconciliator=reconciliator,\n",
+    ")\n",
+    "\n",
+    "ahp_metrics, _, _ = pipeline.backtest(ts=hierarchical_ts, metrics=[SMAPE()], n_folds=3, aggregate_metrics=True)\n",
+    "ahp_metrics = ahp_metrics.set_index(\"segment\").add_suffix(\"_ahp\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To use another top-down proportion estimation method, pass a different method name to the `TopDownReconciliator` constructor.\n",
+    "Let's take a look at the PHA method."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pha_reconciliator = TopDownReconciliator(\n",
+    "    target_level=\"region_level\", source_level=\"reason_level\", method=\"PHA\", period=6\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "It should be noted that the fitted mapping matrix has the same structure as in the previous method, but with slightly\n",
+    "different non-zero values."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([[0.29761574, 0. , 0. , 0. ],\n",
+       " [0.17910202, 0. , 0. , 0. ],\n",
+       " [0.29400697, 0. , 0. , 0. ],\n",
+       " [0.0651224 , 0. , 0. , 0. ],\n",
+       " [0.10000206, 0. , 0. , 0. ],\n",
+       " [0.03596948, 0. , 0. , 0. ],\n",
+       " [0.02818132, 0. , 0. , 0. ],\n",
+       " [0. , 0.35710317, 0. , 0. ],\n",
+       " [0. , 0.19744442, 0. , 0. ],\n",
+       " [0. , 0.24879185, 0. , 0. ],\n",
+       " [0. , 0.06362301, 0. , 0. ],\n",
+       " [0. , 0.09206311, 0. , 0. ],\n",
+       " [0. , 0.02404128, 0. , 0. ],\n",
+       " [0. , 0.01693316, 0. , 0. ],\n",
+       " [0. , 0. , 0.29730368, 0. ],\n",
+       " [0. , 0. , 0.16779538, 0. ],\n",
+       " [0. , 0. , 0.27544335, 0. ],\n",
+       " [0. , 0. , 0.06506127, 0. ],\n",
+       " [0. , 0. , 0.139399 , 0. ],\n",
+       " [0. , 0. , 0.02441176, 0. ],\n",
+       " [0. , 0. , 0.03058557, 0. ],\n",
+       " [0. , 0. , 0. , 0.28940705],\n",
+       " [0. , 0. , 0. , 0.14772684],\n",
+       " [0. , 0. , 0. , 0.26106345],\n",
+       " [0. , 0. , 0. , 0.09481879],\n",
+       " [0. , 0. , 0. , 0.13193001],\n",
+       " [0. , 0. , 0. , 0.03034655],\n",
+       " [0. , 0. , 0. , 0.04470731]])"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pha_reconciliator.fit(ts=hierarchical_ts)\n",
+    "pha_reconciliator.mapping_matrix.toarray()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let’s fit a `HierarchicalPipeline` with the **PHA** method."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
+      "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.2s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.4s finished\n"
+     ]
+    }
+   ],
+   "source": [
+    "reconciliator = TopDownReconciliator(target_level=\"region_level\", source_level=\"reason_level\", method=\"PHA\", period=9)\n",
+    "\n",
+    "pipeline = HierarchicalPipeline(\n",
+    "    transforms=[\n",
+    "        LagTransform(in_column=\"target\", lags=[1, 2, 3, 4, 6, 12]),\n",
+    "    ],\n",
+    "    model=LinearPerSegmentModel(),\n",
+    "    reconciliator=reconciliator,\n",
+    ")\n",
+    "\n",
+    "pha_metrics, _, _ = pipeline.backtest(ts=hierarchical_ts, metrics=[SMAPE()], n_folds=3, aggregate_metrics=True)\n",
+    "pha_metrics = pha_metrics.set_index(\"segment\").add_suffix(\"_pha\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, let's forecast the middle-level series directly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
+      "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.5s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s remaining: 0.0s\n",
+      "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s finished\n"
+     ]
+    }
+   ],
+   "source": [
+    "region_level_ts = hierarchical_ts.get_level_dataset(target_level=\"region_level\")\n",
+    "\n",
+    "pipeline = Pipeline(\n",
+    "    transforms=[\n",
+    "        LagTransform(in_column=\"target\", lags=[1, 2, 3, 4, 6, 12]),\n",
+    "    ],\n",
+    "    model=LinearPerSegmentModel(),\n",
+    ")\n",
+    "\n",
+    "region_level_metric, _, _ = pipeline.backtest(ts=region_level_ts, metrics=[SMAPE()], n_folds=3, aggregate_metrics=True)\n",
+    "\n",
+    "region_level_metric = region_level_metric.set_index(\"segment\").add_suffix(\"_region_level\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we can take a look at the metrics and compare the methods."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SMAPE_bottom_upSMAPE_ahpSMAPE_phaSMAPE_region_level
segment
Bus_NSW5.2704226.5193906.3180208.002023
Bus_NT25.76501815.15447314.73489435.648559
Bus_QLD18.2541623.7272783.8438375.920212
Bus_SA15.28232218.19676618.44347717.586339
Bus_TAS30.69501325.93255525.1451208.810328
Bus_VIC15.1162124.7556574.25207810.312053
Bus_WA10.00930418.51430718.41531610.715275
Hol_NSW14.4541657.7056298.0112449.115648
Hol_NT53.25068744.94929446.82134917.153756
Hol_QLD9.6241668.6479207.72220511.364234
Hol_SA8.20226920.08590019.78693111.244287
Hol_TAS51.59238650.64441451.20585455.117682
Hol_VIC7.12526917.98048420.27013221.994822
Hol_WA16.41513813.30313213.70301925.802063
Oth_NSW29.98723835.33528335.11397922.802959
Oth_NT98.03249350.69476355.75584248.984850
Oth_QLD31.46430326.66885227.80464414.136124
Oth_SA24.09880641.84852341.91169822.057562
Oth_TAS55.18720846.45779244.25270423.528327
Oth_VIC31.36579537.31090636.37275325.495443
Oth_WA26.89459223.56125226.07198125.078132
VFR_NSW4.9775857.0881597.0675668.696804
VFR_NT46.56588828.79628629.00183535.465418
VFR_QLD12.6750374.3129794.3707224.169244
VFR_SA15.61337619.78045920.27812224.620504
VFR_TAS33.18277326.50568529.20635928.587697
VFR_VIC9.23716410.54998110.22606121.911153
VFR_WA17.41611512.32912611.7021463.941069
\n", + "
" + ], + "text/plain": [ + " SMAPE_bottom_up SMAPE_ahp SMAPE_pha SMAPE_region_level\n", + "segment \n", + "Bus_NSW 5.270422 6.519390 6.318020 8.002023\n", + "Bus_NT 25.765018 15.154473 14.734894 35.648559\n", + "Bus_QLD 18.254162 3.727278 3.843837 5.920212\n", + "Bus_SA 15.282322 18.196766 18.443477 17.586339\n", + "Bus_TAS 30.695013 25.932555 25.145120 8.810328\n", + "Bus_VIC 15.116212 4.755657 4.252078 10.312053\n", + "Bus_WA 10.009304 18.514307 18.415316 10.715275\n", + "Hol_NSW 14.454165 7.705629 8.011244 9.115648\n", + "Hol_NT 53.250687 44.949294 46.821349 17.153756\n", + "Hol_QLD 9.624166 8.647920 7.722205 11.364234\n", + "Hol_SA 8.202269 20.085900 19.786931 11.244287\n", + "Hol_TAS 51.592386 50.644414 51.205854 55.117682\n", + "Hol_VIC 7.125269 17.980484 20.270132 21.994822\n", + "Hol_WA 16.415138 13.303132 13.703019 25.802063\n", + "Oth_NSW 29.987238 35.335283 35.113979 22.802959\n", + "Oth_NT 98.032493 50.694763 55.755842 48.984850\n", + "Oth_QLD 31.464303 26.668852 27.804644 14.136124\n", + "Oth_SA 24.098806 41.848523 41.911698 22.057562\n", + "Oth_TAS 55.187208 46.457792 44.252704 23.528327\n", + "Oth_VIC 31.365795 37.310906 36.372753 25.495443\n", + "Oth_WA 26.894592 23.561252 26.071981 25.078132\n", + "VFR_NSW 4.977585 7.088159 7.067566 8.696804\n", + "VFR_NT 46.565888 28.796286 29.001835 35.465418\n", + "VFR_QLD 12.675037 4.312979 4.370722 4.169244\n", + "VFR_SA 15.613376 19.780459 20.278122 24.620504\n", + "VFR_TAS 33.182773 26.505685 29.206359 28.587697\n", + "VFR_VIC 9.237164 10.549981 10.226061 21.911153\n", + "VFR_WA 17.416115 12.329126 11.702146 3.941069" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_metrics = pd.concat([bottom_up_metrics, ahp_metrics, pha_metrics, region_level_metric], axis=1)\n", + "\n", + "all_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SMAPE_bottom_up 25.634104\n", + "SMAPE_ahp 22.405616\n", + "SMAPE_pha 22.778925\n", + "SMAPE_region_level 19.937949\n", + "dtype: float64" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_metrics.mean()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The results presented above show that using reconciliation methods can improve forecasting quality\n", + "for some segments. In this particular case, the direct forecast for segments at the Reason level is slightly better\n", + "on average than the reconciliation forecasts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Exogenous variables for hierarchical forecasts " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This section shows how exogenous variables can be added to a hierarchical `TSDataset`.\n", + "\n", + "Before adding exogenous variables to the dataset, we should decide at which level we should place them. Model fitting and\n", + "initial forecasting in the `HierarchicalPipeline` are made at the **source level**. So exogenous variables should be at the\n", + "**source level** as well.\n", + "\n", + "Let's try to add monthly indicators to our model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "from etna.datasets.utils import duplicate_data\n", + "\n", + "horizon = 3\n", + "exog_index = pd.date_range(\"2006-01-01\", periods=periods + horizon, freq=\"MS\")\n", + "\n", + "months_df = pd.DataFrame({\"timestamp\": exog_index.values, \"month\": exog_index.month.astype(\"category\")})\n", + "\n", + "reason_level_segments = hierarchical_ts.hierarchical_structure.get_level_segments(level_name=\"reason_level\")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBusHolOthVFR
featuremonthmonthmonthmonth
timestamp
2006-01-011111
2006-02-012222
2006-03-013333
2006-04-014444
2006-05-015555
\n", + "
" + ], + "text/plain": [ + "segment Bus Hol Oth VFR\n", + "feature month month month month\n", + "timestamp \n", + "2006-01-01 1 1 1 1\n", + "2006-02-01 2 2 2 2\n", + "2006-03-01 3 3 3 3\n", + "2006-04-01 4 4 4 4\n", + "2006-05-01 5 5 5 5" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "months_ts = duplicate_data(df=months_df, segments=reason_level_segments)\n", + "months_ts.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Previous block showed how to create exogenous variables and convert to a hierarchical format manually.\n", + "Another way to convert exogenous variables to a hierarchical dataset is to use `TSDataset.to_hierarchical_dataset`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's convert the dataframe to hierarchical long format." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
timestampmonthreason
02006-01-011Hol
12006-02-012Hol
22006-03-013Hol
32006-04-014Hol
42006-05-015Hol
\n", + "
" + ], + "text/plain": [ + " timestamp month reason\n", + "0 2006-01-01 1 Hol\n", + "1 2006-02-01 2 Hol\n", + "2 2006-03-01 3 Hol\n", + "3 2006-04-01 4 Hol\n", + "4 2006-05-01 5 Hol" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "months_ts = duplicate_data(df=months_df, segments=reason_level_segments, format=\"long\")\n", + "months_ts.rename(columns={\"segment\": \"reason\"}, inplace=True)\n", + "months_ts.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we are ready to use `to_hierarchical_dataset` method. When using this method with exogenous data\n", + "pass `return_hierarchy=False`, because we want to use hierarchical structure from target variables.\n", + "Setting `keep_level_columns=True` will add level columns to the result dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBusHolOthVFR
featuremonthmonthmonthmonth
timestamp
2006-01-011111
2006-02-012222
2006-03-013333
2006-04-014444
2006-05-015555
\n", + "
" + ], + "text/plain": [ + "segment Bus Hol Oth VFR\n", + "feature month month month month\n", + "timestamp \n", + "2006-01-01 1 1 1 1\n", + "2006-02-01 2 2 2 2\n", + "2006-03-01 3 3 3 3\n", + "2006-04-01 4 4 4 4\n", + "2006-05-01 5 5 5 5" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "months_ts, _ = TSDataset.to_hierarchical_dataset(df=months_ts, level_columns=[\"reason\"], return_hierarchy=False)\n", + "months_ts.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When dataframe with exogenous variables is prepared, create new hierarchical dataset with exogenous variables." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "hierarchical_ts_w_exog = TSDataset(\n", + " df=hierarchical_df,\n", + " df_exog=months_ts,\n", + " hierarchical_structure=hierarchical_structure,\n", + " freq=\"MS\",\n", + " known_future=\"all\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'df_level=city_level, df_exog_level=reason_level'" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f\"df_level={hierarchical_ts_w_exog.current_df_level}, df_exog_level={hierarchical_ts_w_exog.current_df_exog_level}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we can see different levels for the dataframes inside the dataset. In such case exogenous variables wouldn't be merged to target\n", + "variables." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBus_NSW_cityBus_NSW_noncityBus_NT_cityBus_NT_noncityBus_QLD_cityBus_QLD_noncityBus_SA_cityBus_SA_noncityBus_TAS_cityBus_TAS_noncityBus_VIC_cityBus_VIC_noncityBus_WA_cityBus_WA_noncityHol_NSW_cityHol_NSW_noncityHol_NT_cityHol_NT_noncityHol_QLD_cityHol_QLD_noncityHol_SA_cityHol_SA_noncityHol_TAS_cityHol_TAS_noncityHol_VIC_cityHol_VIC_noncityHol_WA_cityHol_WA_noncityOth_NSW_cityOth_NSW_noncityOth_NT_cityOth_NT_noncityOth_QLD_cityOth_QLD_noncityOth_SA_cityOth_SA_noncityOth_TAS_cityOth_TAS_noncityOth_VIC_cityOth_VIC_noncityOth_WA_cityOth_WA_noncityVFR_NSW_cityVFR_NSW_noncityVFR_NT_cityVFR_NT_noncityVFR_QLD_cityVFR_QLD_noncityVFR_SA_cityVFR_SA_noncityVFR_TAS_cityVFR_TAS_noncityVFR_VIC_cityVFR_VIC_noncityVFR_WA_cityVFR_WA_noncity
featuretargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettargettarget
timestamp
2006-01-0112011684136801111982388456116107116498453287430961449310186468843908882201619148325317881138320663965103584312712447376241812861683727096689289300322871324869602748256534281019762
2006-02-0120202281138170776144834640383290101481135616871479954820129723203990521141440968914394586105913956575816939669170142221366122932317099218456451683491957294580663925726618522255750603
2006-03-011975211845210841079230039036019610711537914401120160973016197454758697547610931273311488357211012297540893150338270116439731532231283183801166222550523908426192870107837513026118821929953734
2006-04-0115001963243160112817522556357022812455085391252152091381642503328478157116993719491906357511282433745115717245321453519426048432703364101139291853852442182097234456864113725722082882999715
2006-05-01119621511941891192155938628013020595057258244119581419462151493051178732150523159025178441156027274265581547458557147337760265293162283154723215312527032933887798347437298831641396630
\n", + "
" + ], + "text/plain": [ + "segment Bus_NSW_city Bus_NSW_noncity Bus_NT_city Bus_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1201 1684 136 80 \n", + "2006-02-01 2020 2281 138 170 \n", + "2006-03-01 1975 2118 452 1084 \n", + "2006-04-01 1500 1963 243 160 \n", + "2006-05-01 1196 2151 194 189 \n", + "\n", + "segment Bus_QLD_city Bus_QLD_noncity Bus_SA_city Bus_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1111 982 388 456 \n", + "2006-02-01 776 1448 346 403 \n", + "2006-03-01 1079 2300 390 360 \n", + "2006-04-01 1128 1752 255 635 \n", + "2006-05-01 1192 1559 386 280 \n", + "\n", + "segment Bus_TAS_city Bus_TAS_noncity Bus_VIC_city Bus_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 116 107 1164 984 \n", + "2006-02-01 83 290 1014 811 \n", + "2006-03-01 196 107 1153 791 \n", + "2006-04-01 70 228 1245 508 \n", + "2006-05-01 130 205 950 572 \n", + "\n", + "segment Bus_WA_city Bus_WA_noncity Hol_NSW_city Hol_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 532 874 3096 14493 \n", + "2006-02-01 356 1687 1479 9548 \n", + "2006-03-01 440 1120 1609 7301 \n", + "2006-04-01 539 1252 1520 9138 \n", + "2006-05-01 582 441 1958 14194 \n", + "\n", + "segment Hol_NT_city Hol_NT_noncity Hol_QLD_city Hol_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 101 86 4688 4390 \n", + "2006-02-01 201 297 2320 3990 \n", + "2006-03-01 619 745 4758 6975 \n", + "2006-04-01 164 250 3328 4781 \n", + "2006-05-01 62 151 4930 5117 \n", + "\n", + "segment Hol_SA_city Hol_SA_noncity Hol_TAS_city Hol_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 888 2201 619 1483 \n", + "2006-02-01 521 1414 409 689 \n", + "2006-03-01 476 1093 127 331 \n", + "2006-04-01 571 1699 371 949 \n", + "2006-05-01 873 2150 523 1590 \n", + "\n", + "segment Hol_VIC_city Hol_VIC_noncity Hol_WA_city Hol_WA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2531 7881 1383 2066 \n", + "2006-02-01 1439 4586 1059 1395 \n", + "2006-03-01 1488 3572 1101 2297 \n", + "2006-04-01 1906 3575 1128 2433 \n", + "2006-05-01 2517 8441 1560 2727 \n", + "\n", + "segment Oth_NSW_city Oth_NSW_noncity Oth_NT_city Oth_NT_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 396 510 35 8 \n", + "2006-02-01 657 581 69 39 \n", + "2006-03-01 540 893 150 338 \n", + "2006-04-01 745 1157 172 453 \n", + "2006-05-01 426 558 15 47 \n", + "\n", + "segment Oth_QLD_city Oth_QLD_noncity Oth_SA_city Oth_SA_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 431 271 244 73 \n", + "2006-02-01 669 170 142 221 \n", + "2006-03-01 270 1164 397 315 \n", + "2006-04-01 214 535 194 260 \n", + "2006-05-01 458 557 147 33 \n", + "\n", + "segment Oth_TAS_city Oth_TAS_noncity Oth_VIC_city Oth_VIC_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 76 24 181 286 \n", + "2006-02-01 36 61 229 323 \n", + "2006-03-01 32 23 128 318 \n", + "2006-04-01 48 43 270 336 \n", + "2006-05-01 77 60 265 293 \n", + "\n", + "segment Oth_WA_city Oth_WA_noncity VFR_NSW_city VFR_NSW_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 168 37 2709 6689 \n", + "2006-02-01 170 99 2184 5645 \n", + "2006-03-01 380 1166 2225 5052 \n", + "2006-04-01 410 1139 2918 5385 
\n", + "2006-05-01 162 28 3154 7232 \n", + "\n", + "segment VFR_NT_city VFR_NT_noncity VFR_QLD_city VFR_QLD_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 28 9 3003 2287 \n", + "2006-02-01 168 349 1957 2945 \n", + "2006-03-01 390 84 2619 2870 \n", + "2006-04-01 244 218 2097 2344 \n", + "2006-05-01 153 125 2703 2933 \n", + "\n", + "segment VFR_SA_city VFR_SA_noncity VFR_TAS_city VFR_TAS_noncity \\\n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 1324 869 602 748 \n", + "2006-02-01 806 639 257 266 \n", + "2006-03-01 1078 375 130 261 \n", + "2006-04-01 568 641 137 257 \n", + "2006-05-01 887 798 347 437 \n", + "\n", + "segment VFR_VIC_city VFR_VIC_noncity VFR_WA_city VFR_WA_noncity \n", + "feature target target target target \n", + "timestamp \n", + "2006-01-01 2565 3428 1019 762 \n", + "2006-02-01 1852 2255 750 603 \n", + "2006-03-01 1882 1929 953 734 \n", + "2006-04-01 2208 2882 999 715 \n", + "2006-05-01 2988 3164 1396 630 " + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts_w_exog.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Exogenous data will be merged only when both dataframes are at the same level, so we can perform reconciliation to do this.\n", + "Right now, our dataset is lower, than the exogenous variables, so they aren't merged.\n", + "To perform aggregation to higher levels, we can use `get_level_dataset` method." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
segmentBusHolOthVFR
featuremonthtargetmonthtargetmonthtargetmonthtarget
timestamp
2006-01-0119815.0145906.012740.0126042.0
2006-02-01211823.0229347.023466.0220676.0
2006-03-01313565.0332492.036114.0320582.0
2006-04-01411478.0431813.045976.0421613.0
2006-05-01510027.0546793.053126.0526947.0
\n", + "
" + ], + "text/plain": [ + "segment Bus Hol Oth VFR \n", + "feature month target month target month target month target\n", + "timestamp \n", + "2006-01-01 1 9815.0 1 45906.0 1 2740.0 1 26042.0\n", + "2006-02-01 2 11823.0 2 29347.0 2 3466.0 2 20676.0\n", + "2006-03-01 3 13565.0 3 32492.0 3 6114.0 3 20582.0\n", + "2006-04-01 4 11478.0 4 31813.0 4 5976.0 4 21613.0\n", + "2006-05-01 5 10027.0 5 46793.0 5 3126.0 5 26947.0" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hierarchical_ts_w_exog.get_level_dataset(target_level=\"reason_level\").head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The modeling process stays the same as in the previous cases." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", + "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.2s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.5s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s remaining: 0.0s\n", + "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.8s finished\n" + ] + } + ], + "source": [ + "region_level_ts_w_exog = hierarchical_ts_w_exog.get_level_dataset(target_level=\"region_level\")\n", + "\n", + "pipeline = HierarchicalPipeline(\n", + " transforms=[\n", + " OneHotEncoderTransform(in_column=\"month\"),\n", + " LagTransform(in_column=\"target\", lags=[1, 2, 3, 4, 6, 12]),\n", + " ],\n", + " model=LinearPerSegmentModel(),\n", + " reconciliator=TopDownReconciliator(\n", + " target_level=\"region_level\", source_level=\"reason_level\", period=9, method=\"AHP\"\n", + " ),\n", + ")\n", + "\n", + "metric, _, _ = pipeline.backtest(ts=region_level_ts_w_exog, metrics=[SMAPE()], n_folds=3, aggregate_metrics=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/tests/conftest.py b/tests/conftest.py index 5f2e1d2a2..abc843471 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from etna.datasets import generate_const_df +from etna.datasets.hierarchical_structure import HierarchicalStructure from etna.datasets.tsdataset import TSDataset @@ -473,3 +474,229 @@ def toy_dataset_with_mean_shift_in_target(): "target_0.01": np.concatenate((np.array((-1, 3, 3, -4, -1)), np.array((-2, 3, -4, 5, -2)))).astype(np.float64), } return TSDataset.to_dataset(pd.DataFrame(df)) + + +@pytest.fixture +def hierarchical_structure(): + hs = HierarchicalStructure( + level_structure={"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + level_names=["total", "market", "product"], + ) + return hs + + +@pytest.fixture +def total_level_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"], + "segment": ["total"] * 2, + "target": [11.0, 22.0], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "target": 
[1.0, 2.0] + [10.0, 20.0], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def product_level_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "segment": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [1.0, 1.0] + [0.0, 1.0] + [3.0, 18.0] + [7.0, 2.0], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def product_level_df_w_nans(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"] * 4, + "segment": ["a"] * 4 + ["b"] * 4 + ["c"] * 4 + ["d"] * 4, + "target": [None, 0, 1, 2] + [3, 4, 5, None] + [7, 8, None, 9] + [10, 11, 12, 13], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_df_w_nans(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"] * 2, + "segment": ["X"] * 4 + ["Y"] * 4, + "target": [None, 4, 6, None] + [17, 19, None, 22], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def total_level_df_w_nans(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"], + "segment": ["total"] * 4, + "target": [None, 23, None, None], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def product_level_constant_hierarchical_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"] * 4, + "segment": ["a"] * 4 + ["b"] * 4 + ["c"] * 4 + ["d"] * 4, + "target": [1, 1, 1, 1] + [2, 2, 2, 2] + [3, 3, 3, 3] + [4, 4, 4, 4], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_constant_hierarchical_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04"] * 2, + "segment": ["X"] * 4 + ["Y"] * 4, + "target": [3, 3, 3, 3] + [7, 7, 7, 7], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_constant_hierarchical_df_exog(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03", "2000-01-04", "2000-01-05", "2000-01-06"] * 2, + "segment": ["X"] * 6 + ["Y"] * 6, + "regressor": [1, 1, 1, 1, 1, 1] * 2, + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def total_level_simple_hierarchical_ts(total_level_df, hierarchical_structure): + ts = TSDataset(df=total_level_df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def market_level_simple_hierarchical_ts(market_level_df, hierarchical_structure): + ts = TSDataset(df=market_level_df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def product_level_simple_hierarchical_ts(product_level_df, hierarchical_structure): + ts = TSDataset(df=product_level_df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def simple_no_hierarchy_ts(market_level_df): + ts = TSDataset(df=market_level_df, freq="D") + return ts + + +@pytest.fixture +def market_level_constant_hierarchical_ts(market_level_constant_hierarchical_df, hierarchical_structure): + ts = TSDataset(df=market_level_constant_hierarchical_df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def market_level_constant_hierarchical_ts_w_exog( + market_level_constant_hierarchical_df, market_level_constant_hierarchical_df_exog, hierarchical_structure +): + ts = TSDataset( + df=market_level_constant_hierarchical_df, + 
df_exog=market_level_constant_hierarchical_df_exog, + freq="D", + hierarchical_structure=hierarchical_structure, + known_future="all", + ) + return ts + + +@pytest.fixture +def product_level_constant_hierarchical_ts(product_level_constant_hierarchical_df, hierarchical_structure): + ts = TSDataset( + df=product_level_constant_hierarchical_df, + freq="D", + hierarchical_structure=hierarchical_structure, + ) + return ts + + +@pytest.fixture +def product_level_constant_hierarchical_ts_w_exog( + product_level_constant_hierarchical_df, market_level_constant_hierarchical_df_exog, hierarchical_structure +): + ts = TSDataset( + df=product_level_constant_hierarchical_df, + df_exog=market_level_constant_hierarchical_df_exog, + freq="D", + hierarchical_structure=hierarchical_structure, + known_future="all", + ) + return ts + + +@pytest.fixture +def product_level_constant_forecast_w_quantiles(hierarchical_structure): + df = pd.DataFrame( + { + "timestamp": ["2000-01-05", "2000-01-06"] * 4, + "segment": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [1, 1] + [2, 2] + [3, 3] + [4, 4], + "target_0.25": [1, 1] + [2, 2] + [3, 3] + [4, 4], + "target_0.75": [1, 1] + [2, 2] + [3, 3] + [4, 4], + } + ) + df = TSDataset.to_dataset(df=df) + ts = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def total_level_constant_forecast_w_quantiles(hierarchical_structure): + df = pd.DataFrame( + { + "timestamp": ["2000-01-05", "2000-01-06"], + "segment": ["total"] * 2, + "target": [10, 10], + "target_0.25": [10, 10], + "target_0.75": [10, 10], + } + ) + df = TSDataset.to_dataset(df=df) + ts = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure) + return ts diff --git a/tests/test_datasets/conftest.py b/tests/test_datasets/conftest.py new file mode 100644 index 000000000..203d0546d --- /dev/null +++ b/tests/test_datasets/conftest.py @@ -0,0 +1,21 @@ +import pytest + +from etna.datasets.hierarchical_structure import HierarchicalStructure + + +@pytest.fixture +def long_hierarchical_structure(): + hs = HierarchicalStructure( + level_structure={"total": ["X", "Y"], "X": ["a"], "a": ["c"], "Y": ["b"], "b": ["d"]}, + level_names=["l1", "l2", "l3", "l4"], + ) + return hs + + +@pytest.fixture +def tailed_hierarchical_structure(): + hs = HierarchicalStructure( + level_structure={"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"], "c": ["f"], "d": ["g"], "a": ["e", "h"]}, + level_names=["l1", "l2", "l3", "l4"], + ) + return hs diff --git a/tests/test_datasets/test_datasets_generation.py b/tests/test_datasets/test_datasets_generation.py index ccf6e1b1e..350ff96ee 100644 --- a/tests/test_datasets/test_datasets_generation.py +++ b/tests/test_datasets/test_datasets_generation.py @@ -4,7 +4,9 @@ from etna.datasets.datasets_generation import generate_ar_df from etna.datasets.datasets_generation import generate_const_df from etna.datasets.datasets_generation import generate_from_patterns_df +from etna.datasets.datasets_generation import generate_hierarchical_df from etna.datasets.datasets_generation import generate_periodic_df +from etna.datasets.tsdataset import TSDataset def check_equals(generated_value, expected_value, **kwargs): @@ -90,3 +92,58 @@ def test_simple_from_patterns_df_check(add_noise, checker): assert checker(from_patterns_df[from_patterns_df.segment == "segment_1"].iat[0, 2], patterns[1][0], sigma=sigma) assert checker(from_patterns_df[from_patterns_df.segment == "segment_1"].iat[3, 2], patterns[1][0], sigma=sigma) assert 
checker(from_patterns_df[from_patterns_df.segment == "segment_1"].iat[4, 2], patterns[1][1], sigma=sigma) + + +def test_generate_hierarchical_df_empty_n_segments_error(): + with pytest.raises(ValueError, match="`n_segments` should contain at least one positive integer!"): + generate_hierarchical_df(periods=2, n_segments=[]) + + +@pytest.mark.parametrize("n_segments", [[-1], [0, 2]]) +def test_generate_hierarchical_df_negative_size_error(n_segments): + with pytest.raises(ValueError, match="All `n_segments` elements should be positive!"): + generate_hierarchical_df(periods=2, n_segments=n_segments) + + +@pytest.mark.parametrize( + "periods,n_segments,expected_columns", + ( + (2, [1, 2], {"target", "timestamp", "level_0", "level_1"}), + (2, [2], {"target", "timestamp", "level_0"}), + (4, [3, 4], {"target", "timestamp", "level_0", "level_1"}), + (4, [3, 3], {"target", "timestamp", "level_0", "level_1"}), + ), +) +def test_generate_hierarchical_df_columns_set(periods, n_segments, expected_columns): + hierarchical_df = generate_hierarchical_df(periods=periods, n_segments=n_segments) + assert expected_columns == set(hierarchical_df.columns) + + +@pytest.mark.parametrize("periods,n_segments", ((2, [1, 2]), (2, [2]), (4, [3, 4]), (4, [3, 3]))) +def test_generate_hierarchical_df_periods(periods, n_segments): + hierarchical_df = generate_hierarchical_df(periods=periods, n_segments=n_segments) + assert hierarchical_df["timestamp"].nunique() == periods + + +@pytest.mark.parametrize("periods,n_segments", ((2, [1, 2]), (2, [2]), (4, [3, 4]), (4, [3, 3]))) +def test_generate_hierarchical_df_segments(periods, n_segments): + hierarchical_df = generate_hierarchical_df(periods=periods, n_segments=n_segments) + + for level_id, segment_size in enumerate(n_segments): + assert hierarchical_df[f"level_{level_id}"].nunique() == segment_size + + +@pytest.mark.parametrize("periods,n_segments", ((2, [1, 2]), (2, [2]), (4, [3, 4]), (4, [3, 3]))) +def test_generate_hierarchical_df_segments_names(periods, n_segments): + hierarchical_df = generate_hierarchical_df(periods=periods, n_segments=n_segments) + + num_levels = len(n_segments) + for level_id in range(num_levels): + assert all(hierarchical_df[f"level_{level_id}"].str.match(rf"l{level_id}s\d+")) + + +@pytest.mark.parametrize("periods,n_segments", ((2, [1, 2]), (2, [2]), (4, [3, 4]), (4, [3, 3]))) +def test_generate_hierarchical_df_convert_to_wide_format(periods, n_segments): + hierarchical_df = generate_hierarchical_df(periods=periods, n_segments=n_segments) + level_names = [f"level_{idx}" for idx in range(len(n_segments))] + TSDataset.to_hierarchical_dataset(df=hierarchical_df, level_columns=level_names) diff --git a/tests/test_datasets/test_hierarchical_dataset.py b/tests/test_datasets/test_hierarchical_dataset.py new file mode 100644 index 000000000..90ac295a6 --- /dev/null +++ b/tests/test_datasets/test_hierarchical_dataset.py @@ -0,0 +1,483 @@ +import numpy as np +import pandas as pd +import pytest + +from etna.datasets.hierarchical_structure import HierarchicalStructure +from etna.datasets.tsdataset import TSDataset + + +@pytest.fixture +def hierarchical_structure_complex(): + hs = HierarchicalStructure( + level_structure={ + "total": ["77", "120"], + "77": ["77_X"], + "120": ["120_Y"], + "77_X": ["77_X_1", "77_X_2"], + "120_Y": ["120_Y_3", "120_Y_4"], + }, + level_names=["total", "categorical", "string", "int"], + ) + return hs + + +@pytest.fixture +def different_level_segments_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + 
"segment": ["X"] * 2 + ["a"] * 2, + "target": [1, 2] + [10, 20], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def level_columns_different_types_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "categorical": [77] * 4 + [120] * 4, + "string": ["X"] * 2 + ["X"] * 2 + ["Y"] * 2 + ["Y"] * 2, + "int": [1] * 2 + [2] * 2 + [3] * 2 + [4] * 2, + "target": [1, 2] + [10, 20] + [100, 200] + [1000, 2000], + } + ) + df["categorical"] = df["categorical"].astype("category") + return df + + +@pytest.fixture +def different_level_segments_df_exog(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["a"] * 2, + "exog": [1, 2] + [10, 20], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def product_level_df_long(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "market": ["X"] * 2 + ["X"] * 2 + ["Y"] * 2 + ["Y"] * 2, + "product": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [1, 2] + [10, 20] + [100, 200] + [1000, 2000], + } + ) + return df + + +@pytest.fixture +def missing_segments_df(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"], + "segment": ["X"] * 2, + "target": [1, 2], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def product_level_df_wide(product_level_df_long): + df = product_level_df_long + df["segment"] = ["X_a"] * 2 + ["X_b"] * 2 + ["Y_c"] * 2 + ["Y_d"] * 2 + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_df_exog(): + df_exog = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-03"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "exog": [1.0, 5.0] + [10.0, 5.0], + "regressor": 1, + } + ) + df_exog = TSDataset.to_dataset(df_exog) + return df_exog + + +@pytest.fixture +def l4_level_df_long(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["c"] * 2 + ["d"] * 2, + "target": [0, 1] + [2, 3], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l3_level_df_long(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["a"] * 2 + ["b"] * 2, + "target": [0, 1] + [2, 3], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l2_level_df_long(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "target": [0, 1] + [2, 3], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l1_level_df_long(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"], + "segment": ["total"] * 2, + "target": [2, 4], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l4_level_df_tailed(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "segment": ["e"] * 2 + ["h"] * 2 + ["f"] * 2 + ["g"] * 2, + "target": [0, 1] + [2, 3] + [4, 5] + [6, 7], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l3_level_df_tailed(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 3, + "segment": ["a"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [2, 4] + [4, 5] + [6, 7], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def l2_level_df_tailed(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "target": [2, 4] + [10, 12], + } + ) + df = TSDataset.to_dataset(df) + return df + + 
+@pytest.fixture
+def l1_level_df_tailed():
+    df = pd.DataFrame(
+        {
+            "timestamp": ["2000-01-01", "2000-01-02"],
+            "segment": ["total"] * 2,
+            "target": [12, 16],
+        }
+    )
+    df = TSDataset.to_dataset(df)
+    return df
+
+
+@pytest.fixture
+def simple_hierarchical_ts(market_level_df, hierarchical_structure):
+    df = market_level_df
+    ts = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure)
+    return ts
+
+
+def test_get_dataframe_level_different_level_segments_fails(different_level_segments_df, simple_hierarchical_ts):
+    with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
+        simple_hierarchical_ts._get_dataframe_level(df=different_level_segments_df)
+
+
+def test_get_dataframe_level_missing_segments_fails(missing_segments_df, simple_hierarchical_ts):
+    with pytest.raises(ValueError, match="Some segments of hierarchical level are missing in dataframe!"):
+        simple_hierarchical_ts._get_dataframe_level(df=missing_segments_df)
+
+
+@pytest.mark.parametrize("df, expected_level", [("market_level_df", "market"), ("product_level_df", "product")])
+def test_get_dataframe_level(df, expected_level, simple_hierarchical_ts, request):
+    df = request.getfixturevalue(df)
+    df_level = simple_hierarchical_ts._get_dataframe_level(df=df)
+    assert df_level == expected_level
+
+
+def test_init_different_level_segments_df_fails(different_level_segments_df, hierarchical_structure):
+    df = different_level_segments_df
+    with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
+        _ = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure)
+
+
+def test_init_different_level_segments_df_exog_fails(
+    market_level_df, different_level_segments_df_exog, hierarchical_structure
+):
+    df, df_exog = market_level_df, different_level_segments_df_exog
+    with pytest.raises(ValueError, match="Segments in dataframe are from more than 1 hierarchical levels!"):
+        _ = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
+
+
+def test_init_df_same_level_df_exog(
+    market_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target", "regressor", "exog"}
+):
+    df, df_exog = market_level_df, market_level_df_exog
+    ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
+    df_columns = set(ts.columns.get_level_values("feature"))
+    assert df_columns == expected_columns
+
+
+def test_init_df_different_level_df_exog(
+    product_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target"}
+):
+    df, df_exog = product_level_df, market_level_df_exog
+    ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
+    df_columns = set(ts.columns.get_level_values("feature"))
+    assert df_columns == expected_columns
+
+
+def test_init_missing_segments_df(missing_segments_df, hierarchical_structure):
+    df = missing_segments_df
+    with pytest.raises(ValueError, match="Some segments of hierarchical level are missing in dataframe!"):
+        _ = TSDataset(df=df, freq="D", hierarchical_structure=hierarchical_structure)
+
+
+def test_make_future_df_same_level_df_exog(
+    market_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target", "regressor", "exog"}
+):
+    df, df_exog = market_level_df, market_level_df_exog
+    ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure)
+    future = ts.make_future(future_steps=4)
+    
future_columns = set(future.columns.get_level_values("feature")) + assert future_columns == expected_columns + + +def test_make_future_df_different_level_df_exog( + product_level_df, market_level_df_exog, hierarchical_structure, expected_columns={"target"} +): + df, df_exog = product_level_df, market_level_df_exog + ts = TSDataset(df=df, freq="D", df_exog=df_exog, hierarchical_structure=hierarchical_structure) + future = ts.make_future(future_steps=4) + future_columns = set(future.columns.get_level_values("feature")) + assert future_columns == expected_columns + + +def test_level_names_with_hierarchical_structure(simple_hierarchical_ts, expected_names=["total", "market", "product"]): + ts_level_names = simple_hierarchical_ts.level_names() + assert sorted(ts_level_names) == sorted(expected_names) + + +def test_level_names_without_hierarchical_structure(market_level_df): + df = market_level_df + ts = TSDataset(df=df, freq="D") + ts_level_names = ts.level_names() + assert ts_level_names is None + + +def test_to_hierarchical_dataset_not_change_input_df(product_level_df_long): + df = product_level_df_long + df_before = df.copy() + df_after, _ = TSDataset.to_hierarchical_dataset( + df=df, level_columns=["market", "product"], keep_level_columns=False, return_hierarchy=True + ) + pd.testing.assert_frame_equal(df, df_before) + + +@pytest.mark.parametrize( + "df_fixture, level_columns, sep, expected_segments", + [ + ("product_level_df_long", ["market", "product"], "_", ["X_a", "X_b", "Y_c", "Y_d"]), + ("product_level_df_long", ["market", "product"], "#", ["X#a", "X#b", "Y#c", "Y#d"]), + ("product_level_df_long", ["product"], "_", ["a", "b", "c", "d"]), + ( + "level_columns_different_types_df", + ["categorical", "string", "int"], + "_", + ["77_X_1", "77_X_2", "120_Y_3", "120_Y_4"], + ), + ], +) +def test_to_hierarchical_dataset_correct_segments(df_fixture, level_columns, sep, expected_segments, request): + df = request.getfixturevalue(df_fixture) + df, _ = TSDataset.to_hierarchical_dataset(df=df, level_columns=level_columns, sep=sep, return_hierarchy=True) + df_segments = df.columns.get_level_values("segment").unique() + assert sorted(df_segments) == sorted(expected_segments) + + +@pytest.mark.parametrize( + "keep_level_columns, expected_columns", [(True, ["target", "market", "product"]), (False, ["target"])] +) +def test_to_hierarchical_dataset_correct_columns(product_level_df_long, keep_level_columns, expected_columns): + df = product_level_df_long + df, _ = TSDataset.to_hierarchical_dataset( + df=df, keep_level_columns=keep_level_columns, level_columns=["market", "product"], return_hierarchy=True + ) + df_columns = df.columns.get_level_values("feature").unique() + assert sorted(df_columns) == sorted(expected_columns) + + +def test_to_hierarchical_dataset_correct_dataframe(product_level_df_long, product_level_df_wide): + df_wide_obtained, _ = TSDataset.to_hierarchical_dataset( + df=product_level_df_long, keep_level_columns=True, level_columns=["market", "product"], return_hierarchy=True + ) + pd.testing.assert_frame_equal(df_wide_obtained, product_level_df_wide) + + +def test_to_hierarchical_dataset_hierarchical_structure( + level_columns_different_types_df, hierarchical_structure_complex +): + _, hs = TSDataset.to_hierarchical_dataset( + df=level_columns_different_types_df, level_columns=["categorical", "string", "int"], return_hierarchy=True + ) + assert hs.level_names == hierarchical_structure_complex.level_names + for level_name in hierarchical_structure_complex.level_names: + assert level_name 
in hs.level_names + expected_level_segments = hierarchical_structure_complex.get_level_segments(level_name=level_name) + obtained_level_segments = hs.get_level_segments(level_name=level_name) + assert sorted(obtained_level_segments) == sorted(expected_level_segments) + + +@pytest.mark.parametrize( + "hierarchical_structure_name,source_df_name,target_level,target_df_name", + ( + ("hierarchical_structure", "product_level_df", "market", "market_level_df"), + ("hierarchical_structure", "product_level_df", "total", "total_level_df"), + ("hierarchical_structure", "market_level_df", "total", "total_level_df"), + ("hierarchical_structure", "product_level_df_w_nans", "market", "market_level_df_w_nans"), + ("hierarchical_structure", "product_level_df_w_nans", "total", "total_level_df_w_nans"), + ("hierarchical_structure", "market_level_df_w_nans", "total", "total_level_df_w_nans"), + ("long_hierarchical_structure", "l4_level_df_long", "l3", "l3_level_df_long"), + ("long_hierarchical_structure", "l4_level_df_long", "l2", "l2_level_df_long"), + ("long_hierarchical_structure", "l4_level_df_long", "l1", "l1_level_df_long"), + ("long_hierarchical_structure", "l3_level_df_long", "l2", "l2_level_df_long"), + ("long_hierarchical_structure", "l3_level_df_long", "l1", "l1_level_df_long"), + ("long_hierarchical_structure", "l2_level_df_long", "l1", "l1_level_df_long"), + ("tailed_hierarchical_structure", "l4_level_df_tailed", "l3", "l3_level_df_tailed"), + ("tailed_hierarchical_structure", "l4_level_df_tailed", "l2", "l2_level_df_tailed"), + ("tailed_hierarchical_structure", "l4_level_df_tailed", "l1", "l1_level_df_tailed"), + ("tailed_hierarchical_structure", "l3_level_df_tailed", "l2", "l2_level_df_tailed"), + ("tailed_hierarchical_structure", "l3_level_df_tailed", "l1", "l1_level_df_tailed"), + ("tailed_hierarchical_structure", "l2_level_df_tailed", "l1", "l1_level_df_tailed"), + ), +) +def test_get_level_dataset(hierarchical_structure_name, source_df_name, target_level, target_df_name, request): + hierarchical_structure = request.getfixturevalue(hierarchical_structure_name) + + source_df = request.getfixturevalue(source_df_name) + source_ts = TSDataset(df=source_df, freq="D", hierarchical_structure=hierarchical_structure) + + target_df = request.getfixturevalue(target_df_name) + target_ts = TSDataset(df=target_df, freq="D", hierarchical_structure=hierarchical_structure) + + estimated_target_ts = source_ts.get_level_dataset(target_level) + + # check attributes + assert target_ts.freq == estimated_target_ts.freq + assert target_ts.hierarchical_structure == estimated_target_ts.hierarchical_structure + assert target_ts.current_df_level == estimated_target_ts.current_df_level + + pd.testing.assert_frame_equal(target_ts.df, estimated_target_ts.df) + + +@pytest.mark.parametrize( + "source_df_name,target_level,target_df_name", + ( + ("product_level_df", "market", "market_level_df"), + ("product_level_df", "total", "total_level_df"), + ("market_level_df", "total", "total_level_df"), + ), +) +def test_get_level_dataset_with_exog( + source_df_name, target_level, target_df_name, market_level_df_exog, hierarchical_structure, request +): + source_df = request.getfixturevalue(source_df_name) + source_ts = TSDataset( + df=source_df, + df_exog=market_level_df_exog, + freq="D", + hierarchical_structure=hierarchical_structure, + known_future=["regressor"], + ) + + target_df = request.getfixturevalue(target_df_name) + target_ts = TSDataset( + df=target_df, + df_exog=market_level_df_exog, + freq="D", + 
hierarchical_structure=hierarchical_structure, + known_future=["regressor"], + ) + + estimated_target_ts = source_ts.get_level_dataset(target_level) + + assert target_ts.current_df_exog_level == estimated_target_ts.current_df_exog_level + pd.testing.assert_frame_equal(target_ts.df, estimated_target_ts.df) + + +def test_get_level_dataset_no_hierarchy_error(market_level_df): + ts = TSDataset(df=market_level_df, freq="D") + with pytest.raises(ValueError, match="Method could be applied only to instances with a hierarchy!"): + ts.get_level_dataset(target_level="total") + + +@pytest.mark.parametrize( + "target_level", + ("", "abcd"), +) +def test_get_level_dataset_invalid_level_name_error(simple_hierarchical_ts, target_level): + with pytest.raises(ValueError, match=f"Invalid level name: {target_level}"): + simple_hierarchical_ts.get_level_dataset(target_level=target_level) + + +def test_get_level_dataset_lower_level_error(simple_hierarchical_ts): + with pytest.raises( + ValueError, match="Target level should be higher in the hierarchy than the current level of dataframe!" + ): + simple_hierarchical_ts.get_level_dataset(target_level="product") + + +@pytest.mark.parametrize( + "target_level,answer", + ( + ("product", np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]])), + ("market", np.array([[3, 3, 3, 7, 7, 7], [3, 3, 3, 7, 7, 7]])), + ("total", np.array([[10, 10, 10], [10, 10, 10]])), + ), +) +def test_get_level_dataset_with_quantiles(product_level_constant_forecast_w_quantiles, target_level, answer): + forecast = product_level_constant_forecast_w_quantiles + np.testing.assert_array_almost_equal(forecast.get_level_dataset(target_level=target_level).df.values, answer) diff --git a/tests/test_datasets/test_hierarchical_structure.py b/tests/test_datasets/test_hierarchical_structure.py new file mode 100644 index 000000000..d4f953d4f --- /dev/null +++ b/tests/test_datasets/test_hierarchical_structure.py @@ -0,0 +1,278 @@ +from typing import Dict +from typing import List + +import numpy as np +import pytest + +from etna.datasets import HierarchicalStructure + + +@pytest.fixture +def simple_hierarchical_structure(): + return HierarchicalStructure( + level_structure={"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, level_names=["l1", "l2", "l3"] + ) + + +@pytest.mark.parametrize( + "struct, target,source,answer", + ( + ("simple_hierarchical_structure", "l1", "l1", np.array([[1]])), + ("simple_hierarchical_structure", "l2", "l2", np.array([[1, 0], [0, 1]])), + ("simple_hierarchical_structure", "l1", "l2", np.array([[1, 1]])), + ("simple_hierarchical_structure", "l1", "l3", np.array([[1, 1, 1, 1]])), + ("simple_hierarchical_structure", "l2", "l3", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])), + ("tailed_hierarchical_structure", "l3", "l3", np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])), + ("tailed_hierarchical_structure", "l1", "l2", np.array([[1, 1]])), + ("tailed_hierarchical_structure", "l1", "l3", np.array([[1, 1, 1]])), + ("tailed_hierarchical_structure", "l1", "l4", np.array([[1, 1, 1, 1]])), + ("tailed_hierarchical_structure", "l2", "l3", np.array([[1, 0, 0], [0, 1, 1]])), + ("tailed_hierarchical_structure", "l2", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 1]])), + ("tailed_hierarchical_structure", "l3", "l4", np.array([[1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])), + ("long_hierarchical_structure", "l1", "l2", np.array([[1, 1]])), + ("long_hierarchical_structure", "l1", "l3", np.array([[1, 1]])), + ("long_hierarchical_structure", "l1", "l4", np.array([[1, 1]])), + 
("long_hierarchical_structure", "l2", "l3", np.array([[1, 0], [0, 1]])), + ("long_hierarchical_structure", "l2", "l4", np.array([[1, 0], [0, 1]])), + ("long_hierarchical_structure", "l3", "l4", np.array([[1, 0], [0, 1]])), + ), +) +def test_summing_matrix(struct: str, source: str, target: str, answer: np.ndarray, request: pytest.FixtureRequest): + np.testing.assert_array_almost_equal( + answer, request.getfixturevalue(struct).get_summing_matrix(target_level=target, source_level=source).toarray() + ) + + +@pytest.mark.parametrize( + "target,source,error", + ( + ("l0", "l2", "Invalid level name: l0"), + ("l1", "l0", "Invalid level name: l0"), + ("l2", "l1", "Target level must be higher or equal in hierarchy than source level!"), + ), +) +def test_level_transition_errors( + simple_hierarchical_structure: HierarchicalStructure, + target: str, + source: str, + error: str, +): + with pytest.raises(ValueError, match=error): + simple_hierarchical_structure.get_summing_matrix(target_level=target, source_level=source) + + +@pytest.mark.parametrize( + "structure", + ( + {"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"], "c": ["e", "f"]}, # e f leaves have lower level + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["e"]}, # e has lower level + ), +) +def test_leaves_level_errors(structure: Dict[str, List[str]]): + with pytest.raises(ValueError, match="All hierarchy tree leaves must be on the same level!"): + HierarchicalStructure(level_structure=structure) + + +@pytest.mark.parametrize( + "structure,answer", + ( + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, "total"), + ({"X": ["a", "b"]}, "X"), + ), +) +def test_root_finding(structure: Dict[str, List[str]], answer: str): + assert HierarchicalStructure._find_tree_root(structure) == answer + + +@pytest.mark.parametrize( + "structure,answer", + ( + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, 7), + ({"X": ["a", "b"]}, 3), + ), +) +def test_num_nodes(structure: Dict[str, List[str]], answer: int): + assert HierarchicalStructure._find_num_nodes(structure) == answer + + +@pytest.mark.parametrize( + "level_names,tree_depth,answer", + ( + (None, 3, ["level_0", "level_1", "level_2"]), + (["l1", "l2", "l3", "l4"], 4, ["l1", "l2", "l3", "l4"]), + ), +) +def test_get_level_names(level_names: List[str], tree_depth: int, answer: List[str]): + assert HierarchicalStructure._get_level_names(level_names, tree_depth) == answer + + +@pytest.mark.parametrize( + "structure,answer", + ( + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, [["total"], ["X", "Y"], ["a", "b", "c", "d"]]), + ({"X": ["a", "b"]}, [["X"], ["a", "b"]]), + ), +) +def test_find_hierarchy_levels(structure: Dict[str, List[str]], answer: List[List[str]]): + h = HierarchicalStructure(level_structure=structure) + hierarchy_levels = h._find_hierarchy_levels() + for i, level_segments in enumerate(answer): + assert hierarchy_levels[i] == level_segments + + +@pytest.mark.parametrize( + "structure,answer", + ( + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + {"total": 4, "X": 2, "Y": 2, "a": 1, "b": 1, "c": 1, "d": 1}, + ), + ({"total": ["X", "Y"], "X": ["a"], "Y": ["c", "d"]}, {"total": 3, "X": 1, "Y": 2, "a": 1, "c": 1, "d": 1}), + ({"X": ["a", "b"]}, {"X": 2, "a": 1, "b": 1}), + ), +) +def test_get_num_reachable_leafs(structure: Dict[str, List[str]], answer: Dict[str, int]): + h = HierarchicalStructure(level_structure=structure) + hierarchy_levels = h._find_hierarchy_levels() + reachable_leafs = h._get_num_reachable_leafs(hierarchy_levels) + 
assert len(reachable_leafs) == len(answer) + for segment in answer: + assert reachable_leafs[segment] == answer[segment] + + +@pytest.mark.parametrize( + "structure,level_names,answer", + ( + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + None, + {"level_0": 0, "level_1": 1, "level_2": 2}, + ), + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + ["l1", "l2", "l3"], + {"l1": 0, "l2": 1, "l3": 2}, + ), + ( + {"X": ["a"]}, + None, + {"level_0": 0, "level_1": 1}, + ), + ), +) +def test_level_to_index(structure: Dict[str, List[str]], level_names: List[str], answer: Dict[str, int]): + h = HierarchicalStructure(level_structure=structure, level_names=level_names) + assert len(h._level_to_index) == len(answer) + for level in answer: + assert h._level_to_index[level] == answer[level] + + +@pytest.mark.parametrize( + "structure,answer", + ( + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, + { + "total": "level_0", + "X": "level_1", + "Y": "level_1", + "a": "level_2", + "b": "level_2", + "c": "level_2", + "d": "level_2", + }, + ), + ({"X": ["a"]}, {"X": "level_0", "a": "level_1"}), + ), +) +def test_segment_to_level(structure: Dict[str, List[str]], answer: Dict[str, str]): + h = HierarchicalStructure(level_structure=structure) + assert len(h._segment_to_level) == len(answer) + for segment in answer: + assert h._segment_to_level[segment] == answer[segment] + + +@pytest.mark.parametrize( + "structure", + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "total"]}, # loop to root + {"X": ["a", "b"], "Y": ["c", "d"]}, # 2 trees + dict(), # empty list + ), +) +def test_root_finding_errors(structure: Dict[str, List[str]]): + with pytest.raises(ValueError, match="Invalid tree definition: unable to find root!"): + HierarchicalStructure(level_structure=structure) + + +@pytest.mark.parametrize( + "structure", + ( + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"], "a": ["X"]}, # loop + {"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d", "Y"]}, # self loop + ), +) +def test_invalid_tree_structure_initialization_fails(structure: Dict[str, List[str]]): + with pytest.raises(ValueError, match="Invalid tree definition: invalid number of nodes and edges!"): + HierarchicalStructure(level_structure=structure) + + +@pytest.mark.parametrize( + "structure,names,answer", + ( + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, None, ["level_0", "level_1", "level_2"]), + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3"], ["l1", "l2", "l3"]), + ), +) +def test_level_names(structure: Dict[str, List[str]], names: List[str], answer: List[str]): + h = HierarchicalStructure(level_structure=structure, level_names=names) + assert h.level_names == answer + + +@pytest.mark.parametrize( + "structure,names", + ( + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1"]), + ({"total": ["X", "Y"], "X": ["a", "b"], "Y": ["c", "d"]}, ["l1", "l2", "l3", "l4"]), + ), +) +def test_level_names_length_error(structure: Dict[str, List[str]], names: List[str]): + with pytest.raises(ValueError, match="Length of `level_names` must be equal to hierarchy tree depth!"): + HierarchicalStructure(level_structure=structure, level_names=names) + + +@pytest.mark.parametrize( + "level,answer", + ( + ("l1", ["total"]), + ("l2", ["X", "Y"]), + ("l3", ["a", "b", "c", "d"]), + ), +) +def test_level_segments(simple_hierarchical_structure: HierarchicalStructure, level: str, answer: List[str]): + assert simple_hierarchical_structure.get_level_segments(level) == 
answer + + +@pytest.mark.parametrize( + "segment,answer", + (("total", "l1"), ("Y", "l2"), ("c", "l3")), +) +def test_segments_level(simple_hierarchical_structure: HierarchicalStructure, segment: str, answer: str): + assert simple_hierarchical_structure.get_segment_level(segment) == answer + + +@pytest.mark.parametrize( + "target_level,answer", + (("l2", 1), ("l3", 2), ("l1", 0)), +) +def test_get_level_depth(simple_hierarchical_structure, target_level, answer): + assert simple_hierarchical_structure.get_level_depth(level_name=target_level) == answer + + +@pytest.mark.parametrize( + "target_level", + ("", "abcd"), +) +def test_get_level_depth_invalid_name_error(simple_hierarchical_structure, target_level): + with pytest.raises(ValueError, match=f"Invalid level name: {target_level}"): + simple_hierarchical_structure.get_level_depth(level_name=target_level) diff --git a/tests/test_datasets/test_utils.py b/tests/test_datasets/test_utils.py index 85fb5df10..f44a7f154 100644 --- a/tests/test_datasets/test_utils.py +++ b/tests/test_datasets/test_utils.py @@ -6,6 +6,8 @@ from etna.datasets import duplicate_data from etna.datasets import generate_ar_df from etna.datasets.utils import _TorchDataset +from etna.datasets.utils import get_level_dataframe +from etna.datasets.utils import get_target_with_quantiles from etna.datasets.utils import set_columns_wide @@ -174,3 +176,69 @@ def test_set_columns_wide( # compare values pd.testing.assert_frame_equal(df_obtained, df_expected) + + +@pytest.mark.parametrize("segments", (["s1"], ["s1", "s2"])) +@pytest.mark.parametrize( + "columns,answer", + ( + ({"a", "b"}, set()), + ({"a", "b", "target"}, {"target"}), + ({"a", "b", "target", "target_0.5"}, {"target", "target_0.5"}), + ({"a", "b", "target", "target_0.5", "target1"}, {"target", "target_0.5"}), + ), +) +def test_get_target_with_quantiles(segments, columns, answer): + columns = pd.MultiIndex.from_product([segments, columns], names=["segment", "feature"]) + targets_names = get_target_with_quantiles(columns) + assert targets_names == answer + + +@pytest.mark.parametrize("target_level,answer_name", (("market", "market_level_df"), ("total", "total_level_df"))) +def test_get_level_dataframe(product_level_simple_hierarchical_ts, target_level, answer_name, request): + ts = product_level_simple_hierarchical_ts + answer = request.getfixturevalue(answer_name) + answer.index.freq = "D" + + mapping_matrix = product_level_simple_hierarchical_ts.hierarchical_structure.get_summing_matrix( + target_level=target_level, source_level=ts.current_df_level + ) + + target_level_df = get_level_dataframe( + df=ts.df, + mapping_matrix=mapping_matrix, + source_level_segments=ts.hierarchical_structure.get_level_segments(level_name=ts.current_df_level), + target_level_segments=ts.hierarchical_structure.get_level_segments(level_name=target_level), + ) + + pd.testing.assert_frame_equal(target_level_df, answer) + + +@pytest.mark.parametrize( + "source_level_segments,target_level_segments,message", + ( + (("ABC", "c1"), ("X", "Y"), "Segments mismatch for provided dataframe and `source_level_segments`!"), + (("ABC", "a"), ("X", "Y"), "Segments mismatch for provided dataframe and `source_level_segments`!"), + ( + ("a", "b", "c", "d"), + ("X",), + "Number of target level segments do not match mapping matrix number of columns!", + ), + ), +) +def test_get_level_dataframe_segm_errors( + product_level_simple_hierarchical_ts, source_level_segments, target_level_segments, message +): + ts = product_level_simple_hierarchical_ts + + 
mapping_matrix = product_level_simple_hierarchical_ts.hierarchical_structure.get_summing_matrix( + target_level="market", source_level=ts.current_df_level + ) + + with pytest.raises(ValueError, match=message): + get_level_dataframe( + df=ts.df, + mapping_matrix=mapping_matrix, + source_level_segments=source_level_segments, + target_level_segments=target_level_segments, + ) diff --git a/tests/test_pipeline/test_hierarchical_pipeline.py b/tests/test_pipeline/test_hierarchical_pipeline.py new file mode 100644 index 000000000..8424a29a1 --- /dev/null +++ b/tests/test_pipeline/test_hierarchical_pipeline.py @@ -0,0 +1,262 @@ +from unittest.mock import Mock + +import numpy as np +import pytest + +from etna.datasets.utils import match_target_quantiles +from etna.metrics import MAE +from etna.metrics import Coverage +from etna.metrics import Width +from etna.models import LinearPerSegmentModel +from etna.models import NaiveModel +from etna.pipeline.hierarchical_pipeline import HierarchicalPipeline +from etna.reconciliation import BottomUpReconciliator +from etna.reconciliation import TopDownReconciliator +from etna.transforms import LagTransform +from etna.transforms import LinearTrendTransform +from etna.transforms import MeanTransform + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="a", source_level="b", period=1, method="AHP"), + BottomUpReconciliator(target_level="a", source_level="b"), + ), +) +def test_init_pass(reconciliator): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + assert isinstance(pipeline.reconciliator, type(reconciliator)) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_fit_mapping_matrix(market_level_simple_hierarchical_ts, reconciliator): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + + pipeline.reconciliator.fit = Mock() + pipeline.fit(market_level_simple_hierarchical_ts) + pipeline.reconciliator.fit.assert_called() + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_fit_dataset_level(market_level_simple_hierarchical_ts, reconciliator): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(market_level_simple_hierarchical_ts) + assert pipeline.ts.current_df_level == reconciliator.source_level + + +def test_fit_no_hierarchy(simple_no_hierarchy_ts): + model = NaiveModel() + reconciliator = BottomUpReconciliator(target_level="total", source_level="market") + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"): + pipeline.fit(simple_no_hierarchy_ts) + + +@pytest.mark.parametrize( + "reconciliator,answer", + ( + (TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), 10), + (BottomUpReconciliator(target_level="total", source_level="market"), np.array([[3, 7]])), + ), +) +def test_raw_forecast_correctness(market_level_constant_hierarchical_ts, 
reconciliator, answer): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(ts=market_level_constant_hierarchical_ts) + forecast = pipeline.raw_forecast() + np.testing.assert_array_almost_equal(forecast[..., "target"].values, answer) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_raw_forecast_level(market_level_simple_hierarchical_ts, reconciliator): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(ts=market_level_simple_hierarchical_ts) + forecast = pipeline.raw_forecast() + assert forecast.current_df_level == pipeline.reconciliator.source_level + + +@pytest.mark.parametrize( + "reconciliator,answer", + ( + (TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), np.array([[3, 7]])), + (BottomUpReconciliator(target_level="total", source_level="market"), 10), + ), +) +def test_forecast_correctness(market_level_constant_hierarchical_ts, reconciliator, answer): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(ts=market_level_constant_hierarchical_ts) + forecast = pipeline.forecast() + np.testing.assert_array_almost_equal(forecast[..., "target"].values, answer) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_forecast_level(market_level_simple_hierarchical_ts, reconciliator): + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(ts=market_level_simple_hierarchical_ts) + forecast = pipeline.forecast() + assert forecast.current_df_level == pipeline.reconciliator.target_level + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_forecast_columns_duplicates(market_level_constant_hierarchical_ts_w_exog, reconciliator): + ts = market_level_constant_hierarchical_ts_w_exog + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + pipeline.fit(ts=ts) + forecast = pipeline.forecast() + assert not any(forecast.df.columns.duplicated()) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + TopDownReconciliator(target_level="market", source_level="total", period=1, method="PHA"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_backtest(market_level_constant_hierarchical_ts, reconciliator): + ts = market_level_constant_hierarchical_ts + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + metrics, _, _ = pipeline.backtest(ts=ts, metrics=[MAE()], n_folds=2, aggregate_metrics=True) + np.testing.assert_array_almost_equal(metrics["MAE"], 0) + + +@pytest.mark.parametrize( + "reconciliator", + ( + 
TopDownReconciliator(target_level="market", source_level="total", period=1, method="AHP"), + TopDownReconciliator(target_level="market", source_level="total", period=1, method="PHA"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_backtest_w_transforms(market_level_constant_hierarchical_ts, reconciliator): + ts = market_level_constant_hierarchical_ts + model = LinearPerSegmentModel() + transforms = [ + MeanTransform(in_column="target", window=2), + LinearTrendTransform(in_column="target"), + LagTransform(in_column="target", lags=[1]), + ] + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=transforms, horizon=1) + metrics, _, _ = pipeline.backtest(ts=ts, metrics=[MAE()], n_folds=2, aggregate_metrics=True) + np.testing.assert_array_almost_equal(metrics["MAE"], 0) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), + TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_backtest_w_exog(product_level_constant_hierarchical_ts_w_exog, reconciliator): + ts = product_level_constant_hierarchical_ts_w_exog + model = LinearPerSegmentModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + metrics, _, _ = pipeline.backtest(ts=ts, metrics=[MAE()], n_folds=2, aggregate_metrics=True) + np.testing.assert_array_almost_equal(metrics["MAE"], 0) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"), + BottomUpReconciliator(target_level="market", source_level="product"), + ), +) +def test_forecast_interval_presented(product_level_constant_hierarchical_ts, reconciliator): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=2) + + pipeline.fit(ts=ts) + forecast = pipeline.forecast(prediction_interval=True, n_folds=1, quantiles=[0.025, 0.5, 0.975]) + quantiles = match_target_quantiles(set(forecast.columns.get_level_values(1))) + assert quantiles == {"target_0.025", "target_0.5", "target_0.975"} + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), + TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"), + BottomUpReconciliator(target_level="market", source_level="product"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +def test_forecast_prediction_intervals(product_level_constant_hierarchical_ts, reconciliator): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=2) + + pipeline.fit(ts=ts) + forecast = pipeline.forecast(prediction_interval=True, n_folds=1) + for segment in forecast.segments: + target = forecast[:, segment, "target"] + np.testing.assert_array_almost_equal(target, forecast[:, segment, "target_0.025"]) + np.testing.assert_array_almost_equal(target, forecast[:, segment, "target_0.975"]) + + +@pytest.mark.parametrize( + "metric_type,reconciliator,answer", + ( + (Width, TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), 0), + 
(Width, BottomUpReconciliator(target_level="total", source_level="market"), 0), + (Coverage, TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), 1), + (Coverage, BottomUpReconciliator(target_level="total", source_level="market"), 1), + ), +) +def test_interval_metrics(product_level_constant_hierarchical_ts, metric_type, reconciliator, answer): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + + metric = metric_type() + results, _, _ = pipeline.backtest( + ts=ts, + metrics=[metric], + n_folds=2, + aggregate_metrics=True, + forecast_params={"prediction_interval": True, "n_folds": 1}, + ) + np.testing.assert_array_almost_equal(results[metric.name], answer) diff --git a/tests/test_reconciliation/__init__.py b/tests/test_reconciliation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_reconciliation/conftest.py b/tests/test_reconciliation/conftest.py new file mode 100644 index 000000000..781533416 --- /dev/null +++ b/tests/test_reconciliation/conftest.py @@ -0,0 +1,29 @@ +import pandas as pd +import pytest + +from etna.datasets import TSDataset + + +@pytest.fixture +def market_level_df_w_negatives(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 2, + "segment": ["X"] * 2 + ["Y"] * 2, + "target": [-1.0, 2.0] + [0, -20.0], + } + ) + df = TSDataset.to_dataset(df) + return df + + +@pytest.fixture +def market_level_simple_hierarchical_ts_w_nans(market_level_df_w_nans, hierarchical_structure): + ts = TSDataset(df=market_level_df_w_nans, freq="D", hierarchical_structure=hierarchical_structure) + return ts + + +@pytest.fixture +def simple_hierarchical_ts_w_negatives(market_level_df_w_negatives, hierarchical_structure): + ts = TSDataset(df=market_level_df_w_negatives, freq="D", hierarchical_structure=hierarchical_structure) + return ts diff --git a/tests/test_reconciliation/test_base.py b/tests/test_reconciliation/test_base.py new file mode 100644 index 000000000..386c9e2dc --- /dev/null +++ b/tests/test_reconciliation/test_base.py @@ -0,0 +1,107 @@ +from unittest.mock import Mock + +import numpy as np +import pandas as pd +import pytest +from scipy.sparse import csr_matrix + +from etna.datasets import TSDataset +from etna.reconciliation.base import BaseReconciliator + + +class DummyReconciliator(BaseReconciliator): + def fit(self, ts: TSDataset) -> "DummyReconciliator": + self.mapping_matrix = Mock() + return self + + +@pytest.fixture +def hierarchical_ts(): + df = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02"] * 4, + "market": ["X"] * 2 + ["X"] * 2 + ["Y"] * 2 + ["Y"] * 2, + "product": ["a"] * 2 + ["b"] * 2 + ["c"] * 2 + ["d"] * 2, + "target": [1.0, 2.0] + [3.0, 4.0] + [5.0, 10.0] + [15.0, 20.0], + } + ) + df_exog = pd.DataFrame( + { + "timestamp": ["2000-01-01", "2000-01-02", "2000-01-03"] * 2, + "market": ["X"] * 3 + ["Y"] * 3, + "exog": [10.0, 12.0, 13.0] + [14.0, 15.0, 16.0], + } + ) + df, hs = TSDataset.to_hierarchical_dataset( + df=df, level_columns=["market", "product"], keep_level_columns=False, return_hierarchy=True + ) + df_exog, _ = TSDataset.to_hierarchical_dataset( + df=df_exog, level_columns=["market"], keep_level_columns=False, return_hierarchy=False + ) + ts = TSDataset(df=df, freq="D", df_exog=df_exog, known_future=["exog"], hierarchical_structure=hs) + return ts + + +@pytest.fixture +def market_total_mapping_matrix(): + mapping_matrix = np.array([[1, 
1]]) + mapping_matrix = csr_matrix(mapping_matrix) + return mapping_matrix + + +@pytest.fixture +def total_market_mapping_matrix(): + mapping_matrix = np.array([[1 / 6], [5 / 6]]) + mapping_matrix = csr_matrix(mapping_matrix) + return mapping_matrix + + +@pytest.mark.parametrize("source_level", ("product", "market", "total")) +def test_aggregate(hierarchical_ts, source_level): + reconciliator = DummyReconciliator(target_level="level", source_level=source_level) + ts_aggregated = reconciliator.aggregate(ts=hierarchical_ts) + assert ts_aggregated.current_df_level == source_level + + +def test_aggregate_fails_low_source_level(hierarchical_ts): + ts_market_level = hierarchical_ts.get_level_dataset(target_level="market") + reconciliator = DummyReconciliator(target_level="level", source_level="product") + with pytest.raises( + ValueError, match="Target level should be higher in the hierarchy than the current level of dataframe!" + ): + _ = reconciliator.aggregate(ts=ts_market_level) + + +def test_reconcile_not_fitted_fails(hierarchical_ts): + reconciliator = DummyReconciliator(target_level="level", source_level="product") + with pytest.raises(ValueError, match="Reconciliator is not fitted!"): + _ = reconciliator.reconcile(ts=hierarchical_ts) + + +@pytest.mark.parametrize("cur_level", ("total", "product")) +def test_reconcile_wrong_level_fails(hierarchical_ts, cur_level, source_level="market"): + hierarchical_ts = hierarchical_ts.get_level_dataset(target_level=cur_level) + reconciliator = DummyReconciliator(target_level="level", source_level=source_level) + reconciliator.fit(hierarchical_ts) + with pytest.raises(ValueError, match=f"Dataset should be on the {source_level} level!"): + _ = reconciliator.reconcile(ts=hierarchical_ts) + + +@pytest.mark.parametrize( + "source_level, target_level, mapping_matrix", + [("market", "total", "market_total_mapping_matrix"), ("total", "market", "total_market_mapping_matrix")], +) +def test_reconcile(hierarchical_ts, source_level, target_level, mapping_matrix, request): + source_ts = hierarchical_ts.get_level_dataset(target_level=source_level) + expected_ts = hierarchical_ts.get_level_dataset(target_level=target_level) + + reconciliator = DummyReconciliator(target_level=target_level, source_level=source_level) + reconciliator.mapping_matrix = request.getfixturevalue(mapping_matrix) + obtained_ts = reconciliator.reconcile(ts=source_ts) + + assert obtained_ts.freq == expected_ts.freq + assert obtained_ts.current_df_level == expected_ts.current_df_level + assert obtained_ts.known_future == expected_ts.known_future + assert obtained_ts.regressors == expected_ts.regressors + pd.testing.assert_frame_equal(obtained_ts.df, expected_ts.df) + pd.testing.assert_frame_equal(obtained_ts.df_exog, expected_ts.df_exog) diff --git a/tests/test_reconciliation/test_bottom_up.py b/tests/test_reconciliation/test_bottom_up.py new file mode 100644 index 000000000..534ced802 --- /dev/null +++ b/tests/test_reconciliation/test_bottom_up.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest + +from etna.reconciliation.bottom_up import BottomUpReconciliator + + +@pytest.mark.parametrize( + "target_level,source_level,error_message", + ( + ( + "market", + "total", + "Source level should be lower or equal in the hierarchy than the target level!", + ), + ( + "total", + "product", + "Current TSDataset level should be lower or equal in the hierarchy than the source level!", + ), + ), +) +def test_bottom_up_reconcile_level_order_errors( + market_level_simple_hierarchical_ts, target_level, 
source_level, error_message
+):
+    reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
+    with pytest.raises(ValueError, match=error_message):
+        reconciler.fit(market_level_simple_hierarchical_ts)
+
+
+@pytest.mark.parametrize(
+    "target_level,source_level",
+    (("abc", "total"), ("market", "abc")),
+)
+def test_bottom_up_reconcile_invalid_level_errors(market_level_simple_hierarchical_ts, target_level, source_level):
+    reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
+    with pytest.raises(ValueError, match="Invalid level name: abc"):
+        reconciler.fit(market_level_simple_hierarchical_ts)
+
+
+def test_bottom_up_reconcile_no_hierarchy_error(simple_no_hierarchy_ts):
+    reconciler = BottomUpReconciliator(target_level="market", source_level="total")
+    with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"):
+        reconciler.fit(simple_no_hierarchy_ts)
+
+
+@pytest.mark.parametrize(
+    "ts_name,target_level,source_level,answer",
+    (
+        (
+            "product_level_simple_hierarchical_ts",
+            "product",
+            "product",
+            np.identity(4),
+        ),
+        (
+            "product_level_simple_hierarchical_ts",
+            "market",
+            "product",
+            np.array([[1, 1, 0, 0], [0, 0, 1, 1]]),
+        ),
+        (
+            "product_level_simple_hierarchical_ts",
+            "total",
+            "product",
+            np.array([[1, 1, 1, 1]]),
+        ),
+        (
+            "product_level_simple_hierarchical_ts",
+            "total",
+            "market",
+            np.array([[1, 1]]),
+        ),
+        (
+            "market_level_simple_hierarchical_ts",
+            "total",
+            "market",
+            np.array([[1, 1]]),
+        ),
+        (
+            "total_level_simple_hierarchical_ts",
+            "total",
+            "total",
+            np.array([[1]]),
+        ),
+    ),
+)
+def test_bottom_up_reconcile_fit(ts_name, target_level, source_level, answer, request):
+    ts = request.getfixturevalue(ts_name)
+    reconciler = BottomUpReconciliator(target_level=target_level, source_level=source_level)
+    reconciler.fit(ts)
+    np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)
+
+
+def test_bottom_up_reconcile_fit_w_nans(market_level_simple_hierarchical_ts_w_nans):
+    answer = np.array([[1, 1]])
+    reconciler = BottomUpReconciliator(source_level="market", target_level="total")
+    reconciler.fit(market_level_simple_hierarchical_ts_w_nans)
+    np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4)
+
+
+def test_reconcile_with_quantiles(product_level_constant_forecast_w_quantiles, product_level_constant_hierarchical_ts):
+    answer = np.array([[3, 3, 3, 7, 7, 7], [3, 3, 3, 7, 7, 7]])
+    ts = product_level_constant_hierarchical_ts
+    forecast = product_level_constant_forecast_w_quantiles
+    reconciliator = BottomUpReconciliator(target_level="market", source_level="product")
+    reconciliator.fit(ts=ts)
+    np.testing.assert_array_almost_equal(reconciliator.reconcile(ts=forecast).df.values, answer)
diff --git a/tests/test_reconciliation/test_top_down.py b/tests/test_reconciliation/test_top_down.py
new file mode 100644
index 000000000..f2b36f282
--- /dev/null
+++ b/tests/test_reconciliation/test_top_down.py
@@ -0,0 +1,300 @@
+import numpy as np
+import pytest
+
+from etna.reconciliation.top_down import TopDownReconciliator
+
+
+@pytest.mark.parametrize(
+    "period",
+    (0, -1),
+)
+def test_top_down_reconcile_init_period_error(period):
+    with pytest.raises(ValueError, match="Period length must be positive!"):
+        TopDownReconciliator(period=period, method="AHP", target_level="market", source_level="total")
+
+
+@pytest.mark.parametrize(
+    "method",
+    ("abcd", ""),
+)
+def 
test_top_down_reconcile_init_method_error(method): + with pytest.raises( + ValueError, match=f"Unable to recognize reconciliation method '{method}'! Supported methods: AHP, PHA." + ): + TopDownReconciliator(period=1, method=method, target_level="market", source_level="total") + + +@pytest.mark.parametrize( + "ts_name,target_level,source_level,error_message", + ( + ( + "product_level_simple_hierarchical_ts", + "total", + "market", + "Target level should be lower or equal in the hierarchy than the source level!", + ), + ( + "market_level_simple_hierarchical_ts", + "product", + "total", + "Current TSDataset level should be lower or equal in the hierarchy than the target level!", + ), + ), +) +def test_top_down_reconcile_level_order_errors(ts_name, target_level, source_level, error_message, request): + ts = request.getfixturevalue(ts_name) + reconciler = TopDownReconciliator(period=1, method="AHP", target_level=target_level, source_level=source_level) + with pytest.raises(ValueError, match=error_message): + reconciler.fit(ts) + + +@pytest.mark.parametrize( + "target_level,source_level", + (("abc", "total"), ("market", "abc")), +) +def test_top_down_reconcile_invalid_level_errors(market_level_simple_hierarchical_ts, target_level, source_level): + reconciler = TopDownReconciliator(period=1, method="AHP", target_level=target_level, source_level=source_level) + with pytest.raises(ValueError, match="Invalid level name: abc"): + reconciler.fit(market_level_simple_hierarchical_ts) + + +def test_top_down_reconcile_no_hierarchy_error(simple_no_hierarchy_ts): + reconciler = TopDownReconciliator(method="AHP", period=1, target_level="market", source_level="total") + with pytest.raises(ValueError, match="The method can be applied only to instances with a hierarchy!"): + reconciler.fit(simple_no_hierarchy_ts) + + +def test_top_down_reconcile_negatives_error(simple_hierarchical_ts_w_negatives): + reconciler = TopDownReconciliator(method="AHP", period=1, target_level="market", source_level="total") + with pytest.raises(ValueError, match="Provided dataset should not contain any negative numbers!"): + reconciler.fit(simple_hierarchical_ts_w_negatives) + + +@pytest.mark.parametrize( + "ts_name,reconciler_args,answer", + ( + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "product", + }, + np.identity(4), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "market", + }, + np.array([[0.5, 0], [0.5, 0], [0, 0.9], [0, 0.1]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "total", + }, + np.array([[0.04545], [0.04545], [0.8182], [0.0909]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "market", + "source_level": "total", + }, + np.array([[0.0909], [0.9091]]), + ), + ( + "total_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "total", + "source_level": "total", + }, + np.array([[1]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "product", + "source_level": "market", + }, + np.array([[0.75, 0], [0.25, 0], [0, 0.6], [0, 0.4]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "product", + "source_level": "total", + }, + np.array([[0.0682], [0.0227], [0.5455], [0.3636]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "market", + 
"source_level": "total", + }, + np.array([[0.0909], [0.9091]]), + ), + ), +) +def test_top_down_reconcile_ahp_fit(ts_name, reconciler_args, answer, request): + reconciler_args["method"] = "AHP" + ts = request.getfixturevalue(ts_name) + reconciler = TopDownReconciliator(**reconciler_args) + reconciler.fit(ts) + np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4) + + +@pytest.mark.parametrize( + "ts_name,reconciler_args,answer", + ( + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "product", + }, + np.identity(4), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "market", + }, + np.array([[0.5, 0], [0.5, 0], [0, 0.9], [0, 0.1]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "product", + "source_level": "total", + }, + np.array([[0.04545], [0.04545], [0.8182], [0.0909]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "market", + "source_level": "total", + }, + np.array([[0.0909], [0.9091]]), + ), + ( + "total_level_simple_hierarchical_ts", + { + "period": 1, + "target_level": "total", + "source_level": "total", + }, + np.array([[1]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "product", + "source_level": "market", + }, + np.array([[0.6667, 0], [0.3333, 0], [0, 0.7], [0, 0.3]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "product", + "source_level": "total", + }, + np.array([[0.0606], [0.0303], [0.6364], [0.2727]]), + ), + ( + "product_level_simple_hierarchical_ts", + { + "period": 2, + "target_level": "market", + "source_level": "total", + }, + np.array([[0.0909], [0.9091]]), + ), + ), +) +def test_top_down_reconcile_pha_fit(ts_name, reconciler_args, answer, request): + reconciler_args["method"] = "PHA" + ts = request.getfixturevalue(ts_name) + reconciler = TopDownReconciliator(**reconciler_args) + reconciler.fit(ts) + np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4) + + +@pytest.mark.parametrize( + "method,period,answer", + ( + ( + "AHP", + 1, + np.array([[np.nan], [np.nan]]), + ), + ( + "AHP", + 2, + np.array([[np.nan], [np.nan]]), + ), + ( + "AHP", + 3, + np.array([[0.1739], [0.8261]]), + ), + ( + "AHP", + 4, + np.array([[0.1739], [0.8261]]), + ), + ( + "PHA", + 1, + np.array([[np.nan], [np.nan]]), + ), + ( + "PHA", + 2, + np.array([[np.nan], [np.nan]]), + ), + ( + "PHA", + 3, + np.array([[0.2174], [0.8913]]), + ), + ( + "PHA", + 4, + np.array([[0.2174], [0.8406]]), + ), + ), +) +def test_top_down_reconcile_fit_w_nans(market_level_simple_hierarchical_ts_w_nans, method, period, answer): + reconciler = TopDownReconciliator(method=method, period=period, source_level="total", target_level="market") + reconciler.fit(market_level_simple_hierarchical_ts_w_nans) + np.testing.assert_array_almost_equal(reconciler.mapping_matrix.toarray().round(5), answer, decimal=4) + + +def test_reconcile_with_quantiles(total_level_constant_forecast_w_quantiles, product_level_constant_hierarchical_ts): + answer = np.array([[3, 3, 3, 7, 7, 7], [3, 3, 3, 7, 7, 7]]) + ts = product_level_constant_hierarchical_ts + forecast = total_level_constant_forecast_w_quantiles + reconciliator = TopDownReconciliator(target_level="market", source_level="total", method="AHP", period=1) + reconciliator.fit(ts=ts) + 
np.testing.assert_array_almost_equal(reconciliator.reconcile(ts=forecast).df.values, answer)

From 0c0ce2726861e59fac6272debd288d96cdddcac1 Mon Sep 17 00:00:00 2001
From: Martin Gabdushev <33594071+martins0n@users.noreply.github.com>
Date: Mon, 30 Jan 2023 12:56:58 +0300
Subject: [PATCH 02/13] Issue-1078: missed kwargs in TFT init (#1084)

* issue-1078: missed kwargs

* fix: linters + new test cases

* FIX: changelog
---
 CHANGELOG.md                     |  2 +-
 etna/models/nn/tft.py            |  2 +-
 tests/test_core/test_to_dict.py  |  4 +++-
 tests/test_models/nn/test_tft.py | 11 +++++++++++
 4 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dbca6be42..757690de3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - 
 ### Fixed
 - 
-- 
+- Missed kwargs in TFT init([#1078](https://github.com/tinkoff-ai/etna/pull/1078))
 - 
 - 
 - 
diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py
index e945cbddc..e53cef68b 100644
--- a/etna/models/nn/tft.py
+++ b/etna/models/nn/tft.py
@@ -53,7 +53,6 @@ def __init__(
         loss: "MultiHorizonMetric" = None,
         trainer_kwargs: Optional[Dict[str, Any]] = None,
         quantiles_kwargs: Optional[Dict[str, Any]] = None,
-        *args,
         **kwargs,
     ):
         """
@@ -113,6 +112,7 @@
         self.trainer: Optional[pl.Trainer] = None
         self._last_train_timestamp = None
         self._freq: Optional[str] = None
+        self.kwargs = kwargs
 
     def _from_dataset(self, ts_dataset: TimeSeriesDataSet) -> LightningModule:
         """
diff --git a/tests/test_core/test_to_dict.py b/tests/test_core/test_to_dict.py
index 28b11a699..4a29f2690 100644
--- a/tests/test_core/test_to_dict.py
+++ b/tests/test_core/test_to_dict.py
@@ -17,6 +17,7 @@
 from etna.models import LinearPerSegmentModel
 from etna.models.nn import DeepARModel
 from etna.models.nn import MLPModel
+from etna.models.nn import TFTModel
 from etna.pipeline import Pipeline
 from etna.transforms import AddConstTransform
 from etna.transforms import ChangePointsTrendTransform
@@ -98,10 +99,11 @@ def test_to_dict_transforms_with_expected(target_object, expected):
 @pytest.mark.parametrize(
     "target_model",
     [
-        pytest.param(DeepARModel(), marks=pytest.mark.xfail(reason="some bug")),
+        pytest.param(DeepARModel(), marks=pytest.mark.xfail(raises=AssertionError)),
         LinearPerSegmentModel(),
         CatBoostModelPerSegment(),
         AutoARIMAModel(),
+        pytest.param(TFTModel(max_epochs=2), marks=pytest.mark.xfail(raises=AssertionError)),
     ],
 )
 def test_to_dict_models(target_model):
diff --git a/tests/test_models/nn/test_tft.py b/tests/test_models/nn/test_tft.py
index 5292b5c5c..bb00968ee 100644
--- a/tests/test_models/nn/test_tft.py
+++ b/tests/test_models/nn/test_tft.py
@@ -187,3 +187,14 @@ def test_save_load(example_tsds):
     transform = _get_default_transform(horizon)
     transforms = [transform]
     assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=transforms, horizon=horizon)
+
+
+def test_repr():
+    model = TFTModel(max_epochs=2, learning_rate=[0.1], gpus=0, batch_size=64)
+    expected_repr = (
+        "TFTModel(max_epochs = 2, gpus = 0, gradient_clip_val = 0.1, "
+        "learning_rate = [0.1], batch_size = 64, context_length = None, hidden_size = 16, "
+        "lstm_layers = 1, attention_head_size = 4, dropout = 0.1, hidden_continuous_size = 8, "
+        "loss = QuantileLoss(), trainer_kwargs = {}, quantiles_kwargs = {}, )"
+    )
+    assert repr(model) == expected_repr

From 30f9277917365b93f5f1644b1b5f4ab0e36a57ff Mon Sep 17 00:00:00 2001
From: Maxim Zherelo <60392282+brsnw250@users.noreply.github.com>
Date: Tue, 31 Jan 2023 08:27:33 +0300
Subject: [PATCH 
03/13] Updated docs (#1086) --- docs/source/api.rst | 1 + docs/source/index.rst | 1 + docs/source/reconciliation.rst | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+) create mode 100644 docs/source/reconciliation.rst diff --git a/docs/source/api.rst b/docs/source/api.rst index 761e9599e..6b1e20c7e 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -14,6 +14,7 @@ API transforms ensembles pipeline + reconciliation analysis auto clustering diff --git a/docs/source/index.rst b/docs/source/index.rst index b80010cb2..adb4b29c9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,6 +13,7 @@ Welcome to ETNA's documentation transforms ensembles pipeline + reconciliation analysis auto clustering diff --git a/docs/source/reconciliation.rst b/docs/source/reconciliation.rst new file mode 100644 index 000000000..9710ae75c --- /dev/null +++ b/docs/source/reconciliation.rst @@ -0,0 +1,20 @@ +Reconciliation +============== + +.. _reconciliation: + +.. currentmodule:: etna + +Details of ETNA reconciliators +------------------------------ + +See the API documentation for further details on reconciliators: + +.. currentmodule:: etna + +.. moduleautosummary:: + :toctree: api/ + :template: custom-module-template.rst + :recursive: + + etna.reconciliation From b9246c6484c6560e69c3ac620928b6504c2aa86e Mon Sep 17 00:00:00 2001 From: looopka Date: Tue, 31 Jan 2023 11:19:16 +0300 Subject: [PATCH 04/13] Wape metric (#1085) * add new metric WAPE * add new metric WAPE (changelog) * add link in CHANGELOG.md and add import __init__.py --------- Co-authored-by: looopka --- CHANGELOG.md | 3 ++ etna/metrics/__init__.py | 2 ++ etna/metrics/functional_metrics.py | 31 +++++++++++++++++++ etna/metrics/metrics.py | 31 ++++++++++++++++++- tests/test_metrics/test_functional_metrics.py | 3 ++ tests/test_metrics/test_metrics.py | 19 +++++++----- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 757690de3..33210234d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - - +- Add `WAPE` metric & `wape` functional metric ([#1085](https://github.com/tinkoff-ai/etna/pull/1085)) +- +- - - ### Changed diff --git a/etna/metrics/__init__.py b/etna/metrics/__init__.py index 4bd3d7d97..231931a49 100644 --- a/etna/metrics/__init__.py +++ b/etna/metrics/__init__.py @@ -11,6 +11,7 @@ from etna.metrics.functional_metrics import rmse from etna.metrics.functional_metrics import sign from etna.metrics.functional_metrics import smape +from etna.metrics.functional_metrics import wape from etna.metrics.intervals_metrics import Coverage from etna.metrics.intervals_metrics import Width from etna.metrics.metrics import MAE @@ -20,6 +21,7 @@ from etna.metrics.metrics import R2 from etna.metrics.metrics import RMSE from etna.metrics.metrics import SMAPE +from etna.metrics.metrics import WAPE from etna.metrics.metrics import MaxDeviation from etna.metrics.metrics import MedAE from etna.metrics.metrics import Sign diff --git a/etna/metrics/functional_metrics.py b/etna/metrics/functional_metrics.py index 12995d666..30a6be5c4 100644 --- a/etna/metrics/functional_metrics.py +++ b/etna/metrics/functional_metrics.py @@ -147,3 +147,34 @@ def max_deviation(y_true: ArrayLike, y_pred: ArrayLike) -> float: rmse = partial(mse, squared=False) + + +def wape(y_true: ArrayLike, y_pred: ArrayLike) -> float: + """Weighted average percentage Error metric. + + .. 
math::
+        WAPE(y\_true, y\_pred) = \\frac{\\sum_{i=0}^{n} |y\_true_i - y\_pred_i|}{\\sum_{i=0}^{n}|y\_true_i|}
+
+    Parameters
+    ----------
+    y_true:
+        array-like of shape (n_samples,) or (n_samples, n_outputs)
+
+        Ground truth (correct) target values.
+
+    y_pred:
+        array-like of shape (n_samples,) or (n_samples, n_outputs)
+
+        Estimated target values.
+
+    Returns
+    -------
+    float
+        A floating point value (the best value is 0.0).
+    """
+    y_true_array, y_pred_array = np.asarray(y_true), np.asarray(y_pred)
+
+    if y_true_array.shape != y_pred_array.shape:
+        raise ValueError("Shapes of the labels must be the same")
+
+    return np.sum(np.abs(y_true_array - y_pred_array)) / np.sum(np.abs(y_true_array))
diff --git a/etna/metrics/metrics.py b/etna/metrics/metrics.py
index 40aeb6923..0abe726cf 100644
--- a/etna/metrics/metrics.py
+++ b/etna/metrics/metrics.py
@@ -8,6 +8,7 @@
 from etna.metrics import rmse
 from etna.metrics import sign
 from etna.metrics import smape
+from etna.metrics import wape
 from etna.metrics.base import Metric
 from etna.metrics.base import MetricAggregationMode
 
@@ -302,4 +303,33 @@ def greater_is_better(self) -> bool:
         return False
 
 
-__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign", "MaxDeviation"]
+class WAPE(Metric):
+    """Weighted average percentage Error metric with multi-segment computation support.
+
+    .. math::
+        WAPE(y\_true, y\_pred) = \\frac{\\sum_{i=0}^{n} |y\_true_i - y\_pred_i|}{\\sum_{i=0}^{n}|y\_true_i|}
+
+    Notes
+    -----
+    You can read more about logic of multi-segment metrics in Metric docs.
+    """
+
+    def __init__(self, mode: str = MetricAggregationMode.per_segment, **kwargs):
+        """Init metric.
+
+        Parameters
+        ----------
+        mode: 'macro' or 'per-segment'
+            metrics aggregation mode
+        kwargs:
+            metric's computation arguments
+        """
+        super().__init__(mode=mode, metric_fn=wape, **kwargs)
+
+    @property
+    def greater_is_better(self) -> bool:
+        """Whether higher metric value is better."""
+        return False
+
+
+__all__ = ["MAE", "MSE", "RMSE", "R2", "MSLE", "MAPE", "SMAPE", "MedAE", "Sign", "MaxDeviation", "WAPE"]
diff --git a/tests/test_metrics/test_functional_metrics.py b/tests/test_metrics/test_functional_metrics.py
index 6a5b8f1f3..a27d8fd27 100644
--- a/tests/test_metrics/test_functional_metrics.py
+++ b/tests/test_metrics/test_functional_metrics.py
@@ -10,6 +10,7 @@
 from etna.metrics import rmse
 from etna.metrics import sign
 from etna.metrics import smape
+from etna.metrics import wape
 
 
 @pytest.fixture()
@@ -39,6 +40,7 @@ def y_pred_1d():
         (r2_score, 0),
         (sign, -1),
         (max_deviation, 2),
+        (wape, 1),
     ),
 )
 def test_all_1d_metrics(metric, right_metrics_value, y_true_1d, y_pred_1d):
@@ -74,6 +76,7 @@ def y_pred_2d():
         (r2_score, 0.0),
         (sign, -1),
         (max_deviation, 4),
+        (wape, 1),
     ),
 )
 def test_all_2d_metrics(metric, right_metrics_value, y_true_2d, y_pred_2d):
diff --git a/tests/test_metrics/test_metrics.py b/tests/test_metrics/test_metrics.py
index 82771fc18..f6c343753 100644
--- a/tests/test_metrics/test_metrics.py
+++ b/tests/test_metrics/test_metrics.py
@@ -12,6 +12,7 @@
 from etna.metrics import rmse
 from etna.metrics import sign
 from etna.metrics import smape
+from etna.metrics import wape
 from etna.metrics.base import MetricAggregationMode
 from etna.metrics.metrics import MAE
 from etna.metrics.metrics import MAPE
@@ -20,6 +21,7 @@
 from etna.metrics.metrics import R2
 from etna.metrics.metrics import RMSE
 from etna.metrics.metrics import SMAPE
+from etna.metrics.metrics import WAPE
 from etna.metrics.metrics import MaxDeviation
 from 
etna.metrics.metrics import MedAE from etna.metrics.metrics import Sign @@ -41,6 +43,7 @@ (Sign, "Sign", {}, ""), (MaxDeviation, "MaxDeviation", {}, ""), (DummyMetric, "DummyMetric", {"alpha": 1.0}, "alpha = 1.0, "), + (WAPE, "WAPE", {}, ""), ), ) def test_repr(metric_class, metric_class_repr, metric_params, param_repr): @@ -56,7 +59,7 @@ def test_repr(metric_class, metric_class_repr, metric_params, param_repr): @pytest.mark.parametrize( "metric_class", - (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation), + (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, WAPE), ) def test_name_class_name(metric_class): """Check metrics name property without changing its during inheritance""" @@ -80,7 +83,7 @@ def test_name_repr(metric_class): assert metric_name == true_name -@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation)) +@pytest.mark.parametrize("metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, WAPE)) def test_metrics_macro(metric_class, train_test_dfs): """Check metrics interface in 'macro' mode""" forecast_df, true_df = train_test_dfs @@ -90,7 +93,7 @@ def test_metrics_macro(metric_class, train_test_dfs): @pytest.mark.parametrize( - "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric) + "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE) ) def test_metrics_per_segment(metric_class, train_test_dfs): """Check metrics interface in 'per-segment' mode""" @@ -103,7 +106,7 @@ def test_metrics_per_segment(metric_class, train_test_dfs): @pytest.mark.parametrize( - "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric) + "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE) ) def test_metrics_invalid_aggregation(metric_class): """Check metrics behavior in case of invalid aggregation mode""" @@ -112,7 +115,7 @@ def test_metrics_invalid_aggregation(metric_class): @pytest.mark.parametrize( - "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric) + "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE) ) def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps): """Check metrics behavior in case of invalid timeranges""" @@ -123,7 +126,7 @@ def test_invalid_timestamps(metric_class, two_dfs_with_different_timestamps): @pytest.mark.parametrize( - "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric) + "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE) ) def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets): """Check metrics behavior in case of invalid segments sets""" @@ -134,7 +137,7 @@ def test_invalid_segments(metric_class, two_dfs_with_different_segments_sets): @pytest.mark.parametrize( - "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric) + "metric_class", (MAE, MSE, RMSE, MedAE, MSLE, MAPE, SMAPE, R2, Sign, MaxDeviation, DummyMetric, WAPE) ) def test_invalid_segments_target(metric_class, train_test_dfs): """Check metrics behavior in case of no target column in segment""" @@ -159,6 +162,7 @@ def test_invalid_segments_target(metric_class, train_test_dfs): (Sign, sign), (MaxDeviation, max_deviation), (DummyMetric, 
create_dummy_functional_metric()), + (WAPE, wape), ), ) def test_metrics_values(metric_class, metric_fn, train_test_dfs): @@ -191,6 +195,7 @@ def test_metrics_values(metric_class, metric_fn, train_test_dfs): (Sign(), None), (MaxDeviation(), False), (DummyMetric(), False), + (WAPE(), False), ), ) def test_metrics_greater_is_better(metric, greater_is_better): From d65999ea1d95058baf30e120383707f2db3ec128 Mon Sep 17 00:00:00 2001 From: Maxim Zherelo <60392282+brsnw250@users.noreply.github.com> Date: Tue, 31 Jan 2023 15:34:36 +0300 Subject: [PATCH 05/13] prepare for 1.15.0 release (#1091) --- CHANGELOG.md | 33 ++++++++++++++------------------- pyproject.toml | 2 +- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33210234d..bcece2350 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,35 +5,30 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - ## Unreleased ### Added -- `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051)) -- `MaxDeviation` metric & `max_deviation` functional metric ([#1061](https://github.com/tinkoff-ai/etna/pull/1061)) -- Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068)) -- Add hierarchical time series support([#1083](https://github.com/tinkoff-ai/etna/pull/1083)) -- -- -- -- Add `WAPE` metric & `wape` functional metric ([#1085](https://github.com/tinkoff-ai/etna/pull/1085)) -- -- -- -- -### Changed + - - +### Changed + - - +### Fixed + - - + +## [1.15.0] - 2023-01-31 +### Added +- `RMSE` metric & `rmse` functional metric ([#1051](https://github.com/tinkoff-ai/etna/pull/1051)) +- `MaxDeviation` metric & `max_deviation` functional metric ([#1061](https://github.com/tinkoff-ai/etna/pull/1061)) +- Add saving/loading for transforms, models, pipelines, ensembles; tutorial for saving/loading ([#1068](https://github.com/tinkoff-ai/etna/pull/1068)) +- Add hierarchical time series support([#1083](https://github.com/tinkoff-ai/etna/pull/1083)) +- Add `WAPE` metric & `wape` functional metric ([#1085](https://github.com/tinkoff-ai/etna/pull/1085)) ### Fixed -- - Missed kwargs in TFT init([#1078](https://github.com/tinkoff-ai/etna/pull/1078)) -- -- -- -- + ## [1.14.0] - 2022-12-16 ### Added - Add python 3.10 support ([#1005](https://github.com/tinkoff-ai/etna/pull/1005)) diff --git a/pyproject.toml b/pyproject.toml index 8effc20e6..3c1cc2462 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "etna" -version = "1.14.0" +version = "1.15.0" repository = "https://github.com/tinkoff-ai/etna" readme = "README.md" description = "ETNA is the first python open source framework of Tinkoff.ru AI Center. It is designed to make working with time series simple, productive, and fun." 
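
Taken together, the patches above form one workflow: build a hierarchical dataset, forecast it at one level, reconcile the forecasts to another level, and score the result with the new metric. Below is a minimal sketch of that workflow. It assumes that `generate_hierarchical_df` names the generated level columns `level_0` and `level_1` (so the derived hierarchy levels are `total`, `level_0`, `level_1`); the segment counts, horizon, and model choice are illustrative rather than taken from the patches.

```python
from etna.datasets import TSDataset
from etna.datasets import generate_hierarchical_df
from etna.metrics import WAPE
from etna.models import NaiveModel
from etna.pipeline.hierarchical_pipeline import HierarchicalPipeline
from etna.reconciliation import BottomUpReconciliator

# Two-level hierarchy: 2 mid-level segments, 4 bottom-level segments.
df = generate_hierarchical_df(periods=60, n_segments=[2, 4])

# Level/segment names below assume the generator's default column naming.
df_wide, hierarchical_structure = TSDataset.to_hierarchical_dataset(
    df=df, level_columns=["level_0", "level_1"], return_hierarchy=True
)
ts = TSDataset(df=df_wide, freq="D", hierarchical_structure=hierarchical_structure)

# Forecast at the bottom level, then sum the forecasts up to the mid level.
pipeline = HierarchicalPipeline(
    reconciliator=BottomUpReconciliator(target_level="level_0", source_level="level_1"),
    model=NaiveModel(),
    transforms=[],
    horizon=7,
)
metrics, forecasts, _ = pipeline.backtest(ts=ts, metrics=[WAPE()], n_folds=2, aggregate_metrics=True)
print(metrics)  # WAPE per mid-level segment, averaged over the folds
```

A `TopDownReconciliator` with `method="AHP"` or `method="PHA"` drops into the same `reconciliator` argument, as the pipeline tests in the first patch exercise.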
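The `wape` functional metric added above is also easy to sanity-check by hand; a small numeric sketch with made-up values:

```python
import numpy as np

from etna.metrics import wape

y_true = np.array([10.0, 20.0, 30.0])
y_pred = np.array([12.0, 18.0, 33.0])

# WAPE = sum(|y_true - y_pred|) / sum(|y_true|) = (2 + 2 + 3) / (10 + 20 + 30) = 7 / 60
assert np.isclose(wape(y_true, y_pred), 7 / 60)
```
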
From da08758a147e0952b6d202d1a5c795384a380906 Mon Sep 17 00:00:00 2001 From: Vlad Ilyuhin <91989984+GooseIt@users.noreply.github.com> Date: Tue, 7 Feb 2023 18:19:18 +0300 Subject: [PATCH 06/13] Fix order of columns after to flatten (#1095) --- CHANGELOG.md | 4 +-- etna/datasets/tsdataset.py | 38 ++++++++++++++++++----------- tests/test_datasets/test_dataset.py | 13 +++++----- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bcece2350..2db5e1b52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - ### Changed - -- + +- Impose specific order of columns on return value of TSDataset.to_flatten ([#1095](https://github.com/tinkoff-ai/etna/pull/1095)) - ### Fixed diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py index 09c89edc6..08e3f246c 100644 --- a/etna/datasets/tsdataset.py +++ b/etna/datasets/tsdataset.py @@ -579,6 +579,9 @@ def plot( def to_flatten(df: pd.DataFrame) -> pd.DataFrame: """Return pandas DataFrame with flatten index. + The order of columns is (timestamp, segment, target, + features in alphabetical order). + Parameters ---------- df: @@ -605,12 +608,12 @@ def to_flatten(df: pd.DataFrame) -> pd.DataFrame: 4 2021-06-05 segment_0 1.00 >>> df_ts_format = TSDataset.to_dataset(df) >>> TSDataset.to_flatten(df_ts_format).head(5) - timestamp target segment - 0 2021-06-01 1.0 segment_0 - 1 2021-06-02 1.0 segment_0 - 2 2021-06-03 1.0 segment_0 - 3 2021-06-04 1.0 segment_0 - 4 2021-06-05 1.0 segment_0 + timestamp segment target + 0 2021-06-01 segment_0 1.0 + 1 2021-06-02 segment_0 1.0 + 2 2021-06-03 segment_0 1.0 + 3 2021-06-04 segment_0 1.0 + 4 2021-06-05 segment_0 1.0 """ dtypes = df.dtypes category_columns = dtypes[dtypes == "category"].index.get_level_values(1).unique() @@ -618,8 +621,14 @@ def to_flatten(df: pd.DataFrame) -> pd.DataFrame: # flatten dataframe columns = df.columns.get_level_values("feature").unique() segments = df.columns.get_level_values("segment").unique() + df_dict = {} df_dict["timestamp"] = np.tile(df.index, len(segments)) + df_dict["segment"] = np.repeat(segments, len(df.index)) + if "target" in columns: + # set this value to lock position of key "target" in output dataframe columns + # None is a placeholder, actual column value will be assigned in the following cycle + df_dict["target"] = None for column in columns: df_cur = df.loc[:, pd.IndexSlice[:, column]] if column in category_columns: @@ -628,7 +637,6 @@ def to_flatten(df: pd.DataFrame) -> pd.DataFrame: stacked = df_cur.values.T.ravel() # creating series is necessary for dtypes like "Int64", "boolean", otherwise they will be objects df_dict[column] = pd.Series(stacked, dtype=df_cur.dtypes[0]) - df_dict["segment"] = np.repeat(segments, len(df.index)) df_flat = pd.DataFrame(df_dict) return df_flat @@ -641,7 +649,9 @@ def to_pandas(self, flatten: bool = False) -> pd.DataFrame: flatten: * If False, return pd.DataFrame with multiindex - * If True, return with flatten index + * If True, return with flatten index, + its order of columns is (timestamp, segment, target, + features in alphabetical order). 
Returns ------- @@ -665,12 +675,12 @@ def to_pandas(self, flatten: bool = False) -> pd.DataFrame: >>> df_ts_format = TSDataset.to_dataset(df) >>> ts = TSDataset(df_ts_format, "D") >>> ts.to_pandas(True).head(5) - timestamp target segment - 0 2021-06-01 1.0 segment_0 - 1 2021-06-02 1.0 segment_0 - 2 2021-06-03 1.0 segment_0 - 3 2021-06-04 1.0 segment_0 - 4 2021-06-05 1.0 segment_0 + timestamp segment target + 0 2021-06-01 segment_0 1.00 + 1 2021-06-02 segment_0 1.00 + 2 2021-06-03 segment_0 1.00 + 3 2021-06-04 segment_0 1.00 + 4 2021-06-05 segment_0 1.00 >>> ts.to_pandas(False).head(5) segment segment_0 segment_1 feature target target diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index 645bf3afc..c006aa607 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -623,14 +623,13 @@ def test_to_flatten_with_exog(df_and_regressors_flat): for column, dtype in dtypes.items(): if dtype == "category": expected_df[column] = expected_df[column].astype(dtype) - + # this logic wouldn't work in general case, here we use that all features' names start with 'r' + sorted_columns = ["timestamp", "segment", "target"] + sorted_columns[:-3] + # reindex df to assert correct columns order + expected_df = expected_df[sorted_columns] # get to_flatten result - obtained_df = TSDataset.to_flatten(TSDataset.to_dataset(flat_df))[sorted_columns].sort_values( - by=["segment", "timestamp"] - ) - assert np.all(sorted_columns == obtained_df.columns) - assert np.all(expected_df.dtypes == obtained_df.dtypes) - assert expected_df.equals(obtained_df) + obtained_df = TSDataset.to_flatten(TSDataset.to_dataset(flat_df)) + pd.testing.assert_frame_equal(obtained_df, expected_df) def test_transform_raise_warning_on_diff_endings(ts_diff_endings): From 7b58e3ce72d1f6a5597702a2535439c713863d68 Mon Sep 17 00:00:00 2001 From: Artyom Makhin <48079881+Ama16@users.noreply.github.com> Date: Tue, 14 Feb 2023 14:33:49 +0300 Subject: [PATCH 07/13] BUG Gale-Shapley (#1110) * BUG Gale-Shapley * changelog * black * fix comment --------- Co-authored-by: a.makhin --- CHANGELOG.md | 2 +- .../feature_selection/gale_shapley.py | 2 +- .../test_gale_shapley_transform.py | 37 +++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2db5e1b52..854e628fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed -- +- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) - ## [1.15.0] - 2023-01-31 diff --git a/etna/transforms/feature_selection/gale_shapley.py b/etna/transforms/feature_selection/gale_shapley.py index 61b3b003c..e0bc491d6 100644 --- a/etna/transforms/feature_selection/gale_shapley.py +++ b/etna/transforms/feature_selection/gale_shapley.py @@ -371,7 +371,7 @@ def fit(self, df: pd.DataFrame) -> "GaleShapleyFeatureSelectionTransform": segment_features_ranking=segment_features_ranking, feature_segments_ranking=feature_segments_ranking, ) - if step == gale_shapley_steps_number - 1: + if step == gale_shapley_steps_number - 1 and last_step_features_number != 0: selected_features = self._process_last_step( matches=matches, relevance_table=relevance_table, diff --git a/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py b/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py index 
ba92e786d..891b77f8f 100644 --- a/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py +++ b/tests/test_transforms/test_feature_selection/test_gale_shapley_transform.py @@ -19,6 +19,32 @@ from tests.test_transforms.utils import assert_transformation_equals_loaded_original +@pytest.fixture +def ts_with_exog_galeshapley(random_seed) -> TSDataset: + np.random.seed(random_seed) + + periods = 30 + df_1 = pd.DataFrame({"timestamp": pd.date_range("2020-01-15", periods=periods)}) + df_1["segment"] = "segment_1" + df_1["target"] = np.random.uniform(10, 20, size=periods) + + df_2 = pd.DataFrame({"timestamp": pd.date_range("2020-01-15", periods=periods)}) + df_2["segment"] = "segment_2" + df_2["target"] = np.random.uniform(-15, 5, size=periods) + + df = pd.concat([df_1, df_2]).reset_index(drop=True) + df = TSDataset.to_dataset(df) + tsds = TSDataset(df, freq="D") + df = tsds.to_pandas(flatten=True) + df_exog = df.copy().drop(columns=["target"]) + df_exog["weekday"] = df_exog["timestamp"].dt.weekday + df_exog["monthday"] = df_exog["timestamp"].dt.day + df_exog["month"] = df_exog["timestamp"].dt.month + df_exog["year"] = df_exog["timestamp"].dt.year + ts = TSDataset(df=TSDataset.to_dataset(df), df_exog=TSDataset.to_dataset(df_exog), freq="D") + return ts + + @pytest.fixture def ts_with_large_regressors_number(random_seed) -> TSDataset: df = generate_periodic_df(periods=100, start_time="2020-01-01", n_segments=3, period=7, scale=10) @@ -622,3 +648,14 @@ def test_work_with_non_regressors(ts_with_exog): ) def test_save_load(transform, ts_with_large_regressors_number): assert_transformation_equals_loaded_original(transform=transform, ts=ts_with_large_regressors_number) + + +def test_right_number_features_with_integer_division(ts_with_exog_galeshapley): + top_k = len(ts_with_exog_galeshapley.segments) + transform = GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=top_k) + + transform.fit(ts_with_exog_galeshapley.to_pandas()) + df = transform.transform(ts_with_exog_galeshapley.to_pandas()) + + remaining_columns = df.columns.get_level_values("feature").unique().tolist() + assert len(remaining_columns) == top_k + 1 From 840f3537db5dff5025dc468b158cb489c9b794d6 Mon Sep 17 00:00:00 2001 From: alex-hse-repository <55380696+alex-hse-repository@users.noreply.github.com> Date: Thu, 2 Mar 2023 11:52:36 +0100 Subject: [PATCH 08/13] Release 1.15.1 (#1144) --- CHANGELOG.md | 20 +++++-- pyproject.toml | 149 +++++++++++++++++++++++++------------------------ 2 files changed, 92 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 854e628fd..e5a23c23e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,17 +7,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added - -- - +- +- +- +- ### Changed - -- Impose specific order of columns on return value of TSDataset.to_flatten ([#1095](https://github.com/tinkoff-ai/etna/pull/1095)) - +- +- +- ### Fixed +- +- +- +- +## [1.15.1] - 2023-03-02 +### Changed +- Impose specific order of columns on return value of `TSDataset.to_flatten` ([#1095](https://github.com/tinkoff-ai/etna/pull/1095)) +### Fixed - Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110)) -- ## [1.15.0] - 2023-01-31 ### Added diff --git a/pyproject.toml b/pyproject.toml index 3c1cc2462..dcab867e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "etna" 
-version = "1.15.0" +version = "1.15.1" repository = "https://github.com/tinkoff-ai/etna" readme = "README.md" description = "ETNA is the first python open source framework of Tinkoff.ru AI Center. It is designed to make working with time series simple, productive, and fun." @@ -187,77 +187,82 @@ line_length = 120 [tool.pytest.ini_options] minversion = "6.0" doctest_optionflags = "NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL NUMBER" -filterwarnings = [ - "error", - "ignore: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that", - "ignore: TSDataset freq can't be inferred", - "ignore: test_size, test_start and test_end cannot be", - "ignore: You probably set wrong freq. Discovered freq in you data is None", - "ignore: Some regressors don't have enough values in segment", - "ignore: Segments contains NaNs in the last timestamps.", - "ignore: pandas.util.testing is deprecated. Use the functions in the public API", - "ignore: Call to deprecated class CatBoostModelPerSegment.", # OK - "ignore: Call to deprecated class CatBoostModelMultiSegment.", # OK - "ignore: Attribute 'loss' is an instance of `nn.Module` and is already", - "ignore: Columns from feature_to_use which are out of dataframe columns will", - "ignore: Comparison of Timestamp with datetime.date is deprecated in order to", - "ignore: CountryHoliday is deprecated, use country_holidays instead.", - "ignore: Exogenous or target data contains None! It will be dropped", - "ignore: is less than n_segments. Algo will filter data", - "ignore: Given top_k=30 is bigger than n_features=20. Transform will not filter", - "ignore: Implicitly cleaning up ", - "ignore: Maximum Likelihood optimization failed to converge. Check mle_retvals", - "ignore: Mean of empty slice", - "ignore: No frequency information was provided, so inferred frequency D will", - "ignore: Non-stationary starting autoregressive parameters found. Using zeros as starting parameters.", - "ignore: Slicing a positional slice with .loc is not supported", - "ignore: Some of external objects in input parameters could be not", - "ignore: The 'check_less_precise' keyword in testing.assert_*_equal is deprecated and will be", - "ignore: The default dtype for empty Series will be 'object' instead", - "ignore: This model does not work with exogenous features and regressors.", - "ignore: Transformation will be applied inplace, out_column param will be ignored", - "ignore: You defined a `validation_step` but have no `val_dataloader`. Skipping val", - "ignore: You probably set wrong freq. Discovered freq in you data", - "ignore: _SeasonalMovingAverageModel does not work with any exogenous series or features.", - "ignore: `np.object` is a deprecated alias for the builtin `object`. To", - "ignore: divide by zero encountered in log", - "ignore: inplace is deprecated and will be removed in a future", - "ignore: invalid value encountered in double_scalars", - "ignore: Arrays of bytes/strings is being converted to decimal numbers if", - "ignore: Attribute 'logging_metrics' is an instance of `nn.Module` and is already", - "ignore: Exogenous data contains columns with category type! 
It will be", - "ignore: Features {'unknown'} are not found and will be dropped!", - "ignore: SARIMAX model does not work with exogenous features", - "ignore: Series.dt.weekofyear and Series.dt.week have been deprecated", - "ignore: The dataloader, train_dataloader, does not have many workers which may", - "ignore: Creating a tensor from a list of numpy.ndarrays", - "ignore: Trying to infer the `batch_size` from an ambiguous collection", - "ignore: ReduceLROnPlateau conditioned on metric val_loss which is not available but strict", - "ignore: Checkpoint directory", - "ignore: Objective did not converge. You might want to increase the number", - "ignore: distutils Version classes are deprecated.", - "ignore: invalid escape sequence", - "ignore::pandas.core.common.SettingWithCopyWarning", - "ignore: You haven't set all parameters inside class __init__ method.* 'box_cox_bounds'", - "ignore: You haven't set all parameters inside class __init__ method.* 'use_box_cox'", - "ignore: You haven't set all parameters inside class __init__ method.* 'use_trend'", - "ignore: You haven't set all parameters inside class __init__ method.* 'use_damped_trend'", - "ignore: You haven't set all parameters inside class __init__ method.* 'seasonal_periods'", - "ignore: You haven't set all parameters inside class __init__ method.* 'show_warnings'", - "ignore: You haven't set all parameters inside class __init__ method.* 'n_jobs'", - "ignore: You haven't set all parameters inside class __init__ method.* 'multiprocessing_start_method'", - "ignore: You haven't set all parameters inside class __init__ method.* 'context'", - "ignore: You haven't set all parameters inside class __init__ method.* 'use_arma_errors'", - "ignore: New behaviour in v1.1.5", - "ignore: The 'check_less_precise' keyword in testing", - "ignore: Feature names only support names that are all strings", - "ignore: Given top_k=.* is less than n_segments. Algo will filter data without Gale-Shapley run.", - "ignore: Call to deprecated create function", # protobuf warning - "ignore: Dynamic prediction specified to begin during out-of-sample forecasting period, and so has no effect.", - "ignore: `tsfresh` is not available, to install it, run `pip install tsfresh==0.19.0 && pip install protobuf==3.20.1`", - "ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning", - "ignore: The default method 'yw' can produce PACF values outside", -] + +# TODO: Uncomment after some solution in https://github.com/pytest-dev/pytest/issues/10773 +#filterwarnings = [ +# "error", +# "ignore: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that", +# "ignore: TSDataset freq can't be inferred", +# "ignore: test_size, test_start and test_end cannot be", +# "ignore: You probably set wrong freq. Discovered freq in you data is None", +# "ignore: Some regressors don't have enough values in segment", +# "ignore: Segments contains NaNs in the last timestamps.", +# "ignore: pandas.util.testing is deprecated. 
Use the functions in the public API", +# "ignore: Call to deprecated class CatBoostModelPerSegment.", # OK +# "ignore: Call to deprecated class CatBoostModelMultiSegment.", # OK +# "ignore: Attribute 'loss' is an instance of `nn.Module` and is already", +# "ignore: Columns from feature_to_use which are out of dataframe columns will", +# "ignore: Comparison of Timestamp with datetime.date is deprecated in order to", +# "ignore: CountryHoliday is deprecated, use country_holidays instead.", +# "ignore: Exogenous or target data contains None! It will be dropped", +# "ignore: is less than n_segments. Algo will filter data", +# "ignore: Given top_k=30 is bigger than n_features=20. Transform will not filter", +# "ignore: Implicitly cleaning up ", +# "ignore: Maximum Likelihood optimization failed to converge. Check mle_retvals", +# "ignore: Mean of empty slice", +# "ignore: No frequency information was provided, so inferred frequency D will", +# "ignore: Non-stationary starting autoregressive parameters found. Using zeros as starting parameters.", +# "ignore: Slicing a positional slice with .loc is not supported", +# "ignore: Some of external objects in input parameters could be not", +# "ignore: The 'check_less_precise' keyword in testing.assert_*_equal is deprecated and will be", +# "ignore: The default dtype for empty Series will be 'object' instead", +# "ignore: This model does not work with exogenous features and regressors.", +# "ignore: Transformation will be applied inplace, out_column param will be ignored", +# "ignore: You defined a `validation_step` but have no `val_dataloader`. Skipping val", +# "ignore: You probably set wrong freq. Discovered freq in you data", +# "ignore: SeasonalMovingAverageModel does not work with any exogenous series or features.", +# "ignore: MovingAverageModel does not work with any exogenous series or features.", +# "ignore: NaiveModel does not work with any exogenous series or features.", +# "ignore: `np.object` is a deprecated alias for the builtin `object`. To", +# "ignore: divide by zero encountered in log", +# "ignore: inplace is deprecated and will be removed in a future", +# "ignore: invalid value encountered in double_scalars", +# "ignore: Arrays of bytes/strings is being converted to decimal numbers if", +# "ignore: Attribute 'logging_metrics' is an instance of `nn.Module` and is already", +# "ignore: Exogenous data contains columns with category type! It will be", +# "ignore: Features {'unknown'} are not found and will be dropped!", +# "ignore: SARIMAX model does not work with exogenous features", +# "ignore: Series.dt.weekofyear and Series.dt.week have been deprecated", +# "ignore: The dataloader, train_dataloader, does not have many workers which may", +# "ignore: Creating a tensor from a list of numpy.ndarrays", +# "ignore: Trying to infer the `batch_size` from an ambiguous collection", +# "ignore: ReduceLROnPlateau conditioned on metric val_loss which is not available but strict", +# "ignore: Checkpoint directory", +# "ignore: Objective did not converge. 
You might want to increase the number", +# "ignore: distutils Version classes are deprecated.", +# "ignore: invalid escape sequence", +# "ignore::pandas.core.common.SettingWithCopyWarning", +# "ignore: You haven't set all parameters inside class __init__ method.* 'box_cox_bounds'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'use_box_cox'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'use_trend'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'use_damped_trend'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'seasonal_periods'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'show_warnings'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'n_jobs'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'multiprocessing_start_method'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'context'", +# "ignore: You haven't set all parameters inside class __init__ method.* 'use_arma_errors'", +# "ignore: New behaviour in v1.1.5", +# "ignore: The 'check_less_precise' keyword in testing", +# "ignore: Feature names only support names that are all strings", +# "ignore: Given top_k=.* is less than n_segments. Algo will filter data without Gale-Shapley run.", +# "ignore: Call to deprecated create function", # protobuf warning +# "ignore: Dynamic prediction specified to begin during out-of-sample forecasting period, and so has no effect.", +# "ignore: `tsfresh` is not available, to install it, run `pip install tsfresh==0.19.0 && pip install protobuf==3.20.1`", +# "ignore::pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning", +# "ignore: The default method 'yw' can produce PACF values outside", +# "ignore: All-NaN slice encountered", +#] markers = [ "smoke", "long_1", From fb295036595f8603eaf887fb8b7a675ad9fdc1a9 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 21 Mar 2023 15:59:32 +0300 Subject: [PATCH 09/13] fix: transfer fixes for BasePipeline and HierarchicalPipeline --- etna/pipeline/base.py | 8 +- etna/pipeline/hierarchical_pipeline.py | 86 +++++++++- .../test_hierarchical_pipeline.py | 148 +++++++++++++++++- 3 files changed, 228 insertions(+), 14 deletions(-) diff --git a/etna/pipeline/base.py b/etna/pipeline/base.py index fa5d2ce4a..ff3c81dc6 100644 --- a/etna/pipeline/base.py +++ b/etna/pipeline/base.py @@ -307,7 +307,7 @@ def _forecast_prediction_interval( ) -> TSDataset: """Add prediction intervals to the forecasts.""" with tslogger.disable(): - _, forecasts, _ = self.backtest(ts=self.ts, metrics=[MAE()], n_folds=n_folds) + _, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds) self._add_forecast_borders(backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions) @@ -321,11 +321,7 @@ def _add_forecast_borders( raise ValueError("Pipeline is not fitted!") backtest_forecasts = TSDataset(df=backtest_forecasts, freq=self.ts.freq) - _, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds) - forecasts = TSDataset(df=forecasts, freq=ts.freq) residuals = ( - forecasts.loc[:, pd.IndexSlice[:, "target"]] - - ts[forecasts.index.min() : forecasts.index.max(), :, "target"] backtest_forecasts.loc[:, pd.IndexSlice[:, "target"]] - self.ts[backtest_forecasts.index.min() : backtest_forecasts.index.max(), :, "target"] ) @@ -619,6 +615,7 @@ def _process_fold_forecast( forecast: TSDataset, train: TSDataset, test: TSDataset, + pipeline: 
"BasePipeline", fold_number: int, mask: FoldMask, metrics: List[Metric], @@ -798,6 +795,7 @@ def _run_all_folds( forecast=forecasts_flat[group_idx * refit + idx], train=train, test=test, + pipeline=pipelines[group_idx], fold_number=fold_groups[group_idx]["forecast_fold_numbers"][idx], mask=fold_groups[group_idx]["forecast_masks"][idx], metrics=metrics, diff --git a/etna/pipeline/hierarchical_pipeline.py b/etna/pipeline/hierarchical_pipeline.py index c7ca590ed..1ed4cc265 100644 --- a/etna/pipeline/hierarchical_pipeline.py +++ b/etna/pipeline/hierarchical_pipeline.py @@ -1,3 +1,4 @@ +import pathlib from copy import deepcopy from typing import Dict from typing import List @@ -67,12 +68,18 @@ def fit(self, ts: TSDataset) -> "HierarchicalPipeline": return self def raw_forecast( - self, prediction_interval: bool = False, quantiles: Sequence[float] = (0.25, 0.75), n_folds: int = 3 + self, + ts: TSDataset, + prediction_interval: bool = False, + quantiles: Sequence[float] = (0.25, 0.75), + n_folds: int = 3, ) -> TSDataset: """Make a prediction for target at the source level of hierarchy. Parameters ---------- + ts: + Dataset to forecast prediction_interval: If True returns prediction interval for forecast quantiles: @@ -85,9 +92,15 @@ def raw_forecast( : Dataset with predictions at the source level """ - forecast = super().forecast(prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds) - target_columns = tuple(get_target_with_quantiles(columns=forecast.columns)) + # handle `prediction_interval=True` separately + source_ts = self.reconciliator.aggregate(ts=ts) + forecast = super().forecast(ts=source_ts, prediction_interval=False, n_folds=n_folds) + if prediction_interval: + forecast = self._forecast_prediction_interval( + ts=ts, predictions=forecast, quantiles=quantiles, n_folds=n_folds + ) + target_columns = tuple(get_target_with_quantiles(columns=forecast.columns)) hierarchical_forecast = TSDataset( df=forecast[..., target_columns], freq=forecast.freq, @@ -98,12 +111,18 @@ def raw_forecast( return hierarchical_forecast def forecast( - self, prediction_interval: bool = False, quantiles: Sequence[float] = (0.025, 0.975), n_folds: int = 3 + self, + ts: Optional[TSDataset] = None, + prediction_interval: bool = False, + quantiles: Sequence[float] = (0.025, 0.975), + n_folds: int = 3, ) -> TSDataset: """Make a prediction for target at the source level of hierarchy and make reconciliation to target level. Parameters ---------- + ts: + Dataset to forecast. If not given, dataset given during :py:meth:``fit`` is used. prediction_interval: If True returns prediction interval for forecast quantiles: @@ -116,7 +135,16 @@ def forecast( : Dataset with predictions at the target level of hierarchy. """ - forecast = self.raw_forecast(prediction_interval=prediction_interval, quantiles=quantiles, n_folds=n_folds) + if ts is None: + if self._fit_ts is None: + raise ValueError( + "There is no ts to forecast! Pass ts into forecast method or make sure that pipeline is loaded with ts." 
@@ -136,9 +164,10 @@ def _compute_metrics(
         return metrics_values
 
     def _forecast_prediction_interval(
-        self, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
+        self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
     ) -> TSDataset:
         """Add prediction intervals to the forecasts."""
+        # TODO: fix this: if backtest is interrupted (e.g. by KeyboardInterrupt), `forecast` and `raw_forecast` are left swapped
         self.forecast, self.raw_forecast = self.raw_forecast, self.forecast  # type: ignore
 
         if self.ts is None or self._fit_ts is None:
@@ -147,10 +176,53 @@
         # TODO: rework intervals estimation for `BottomUpReconciliator`
         with tslogger.disable():
-            _, forecasts, _ = self.backtest(ts=self._fit_ts, metrics=[MAE()], n_folds=n_folds)
+            _, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)
 
         self._add_forecast_borders(backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions)
 
         self.forecast, self.raw_forecast = self.raw_forecast, self.forecast  # type: ignore
         return predictions
+
+    def save(self, path: pathlib.Path):
+        """Save the object.
+
+        Parameters
+        ----------
+        path:
+            Path to save object to.
+        """
+        fit_ts = self._fit_ts
+
+        try:
+            # extract attributes we can't easily save
+            delattr(self, "_fit_ts")
+
+            # save the remaining part
+            super().save(path=path)
+        finally:
+            self._fit_ts = fit_ts
+
+    @classmethod
+    def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> "HierarchicalPipeline":
+        """Load an object.
+
+        Parameters
+        ----------
+        path:
+            Path to load object from.
+        ts:
+            TSDataset to set into loaded pipeline.
+
+        Returns
+        -------
+        :
+            Loaded object.
+ """ + obj = super().load(path=path) + obj._fit_ts = deepcopy(ts) + if ts is not None: + obj.ts = obj.reconciliator.aggregate(ts=ts) + else: + obj.ts = None + return obj diff --git a/tests/test_pipeline/test_hierarchical_pipeline.py b/tests/test_pipeline/test_hierarchical_pipeline.py index 8424a29a1..449c83256 100644 --- a/tests/test_pipeline/test_hierarchical_pipeline.py +++ b/tests/test_pipeline/test_hierarchical_pipeline.py @@ -1,20 +1,28 @@ +import pathlib +from copy import deepcopy from unittest.mock import Mock +from unittest.mock import patch import numpy as np +import pandas as pd import pytest from etna.datasets.utils import match_target_quantiles from etna.metrics import MAE from etna.metrics import Coverage from etna.metrics import Width +from etna.models import CatBoostMultiSegmentModel from etna.models import LinearPerSegmentModel from etna.models import NaiveModel +from etna.models import ProphetModel from etna.pipeline.hierarchical_pipeline import HierarchicalPipeline from etna.reconciliation import BottomUpReconciliator from etna.reconciliation import TopDownReconciliator +from etna.transforms import DateFlagsTransform from etna.transforms import LagTransform from etna.transforms import LinearTrendTransform from etna.transforms import MeanTransform +from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original @pytest.mark.parametrize( @@ -79,7 +87,7 @@ def test_raw_forecast_correctness(market_level_constant_hierarchical_ts, reconci model = NaiveModel() pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) pipeline.fit(ts=market_level_constant_hierarchical_ts) - forecast = pipeline.raw_forecast() + forecast = pipeline.raw_forecast(ts=market_level_constant_hierarchical_ts) np.testing.assert_array_almost_equal(forecast[..., "target"].values, answer) @@ -94,7 +102,7 @@ def test_raw_forecast_level(market_level_simple_hierarchical_ts, reconciliator): model = NaiveModel() pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) pipeline.fit(ts=market_level_simple_hierarchical_ts) - forecast = pipeline.raw_forecast() + forecast = pipeline.raw_forecast(ts=market_level_simple_hierarchical_ts) assert forecast.current_df_level == pipeline.reconciliator.source_level @@ -260,3 +268,139 @@ def test_interval_metrics(product_level_constant_hierarchical_ts, metric_type, r forecast_params={"prediction_interval": True, "n_folds": 1}, ) np.testing.assert_array_almost_equal(results[metric.name], answer) + + +@patch("etna.pipeline.pipeline.Pipeline.save") +def test_save(save_mock, product_level_constant_hierarchical_ts, tmp_path): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + reconciliator = BottomUpReconciliator(target_level="market", source_level="product") + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + pipeline.fit(ts) + + def check_no_fit_ts(path): + assert not hasattr(pipeline, "_fit_ts") + + save_mock.side_effect = check_no_fit_ts + + pipeline.save(path) + + save_mock.assert_called_once_with(path=path) + assert hasattr(pipeline, "_fit_ts") + + +@patch("etna.pipeline.pipeline.Pipeline.load") +def test_load_no_ts(load_mock, product_level_constant_hierarchical_ts, tmp_path): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + reconciliator = BottomUpReconciliator(target_level="market", source_level="product") + pipeline = 
HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + pipeline.fit(ts) + + pipeline.save(path) + loaded_pipeline = HierarchicalPipeline.load(path) + + load_mock.assert_called_once_with(path=path) + assert loaded_pipeline._fit_ts is None + assert loaded_pipeline.ts is None + assert loaded_pipeline == load_mock.return_value + + +@patch("etna.pipeline.pipeline.Pipeline.load") +def test_load_with_ts(load_mock, product_level_constant_hierarchical_ts, tmp_path): + ts = product_level_constant_hierarchical_ts + model = NaiveModel() + reconciliator = BottomUpReconciliator(target_level="market", source_level="product") + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + pipeline.fit(ts) + + pipeline.save(path) + loaded_pipeline = HierarchicalPipeline.load(path, ts=ts) + + load_mock.assert_called_once_with(path=path) + load_mock.return_value.reconciliator.aggregate.assert_called_once_with(ts=ts) + pd.testing.assert_frame_equal(loaded_pipeline._fit_ts.to_pandas(), ts.to_pandas()) + assert loaded_pipeline.ts == load_mock.return_value.reconciliator.aggregate.return_value + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), + TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"), + BottomUpReconciliator(target_level="market", source_level="product"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +@pytest.mark.parametrize( + "model, transforms", + [ + ( + CatBoostMultiSegmentModel(iterations=100), + [DateFlagsTransform(), LagTransform(in_column="target", lags=[1])], + ), + ( + LinearPerSegmentModel(), + [DateFlagsTransform(), LagTransform(in_column="target", lags=[1])], + ), + (NaiveModel(), []), + (ProphetModel(), []), + ], +) +def test_save_load(model, transforms, reconciliator, product_level_constant_hierarchical_ts): + horizon = 1 + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=transforms, horizon=horizon) + assert_pipeline_equals_loaded_original(pipeline=pipeline, ts=product_level_constant_hierarchical_ts) + + +@pytest.mark.parametrize( + "reconciliator", + ( + TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"), + TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"), + BottomUpReconciliator(target_level="market", source_level="product"), + BottomUpReconciliator(target_level="total", source_level="market"), + ), +) +@pytest.mark.parametrize( + "model, transforms", + [ + ( + CatBoostMultiSegmentModel(iterations=100), + [DateFlagsTransform(), LagTransform(in_column="target", lags=[1])], + ), + ( + LinearPerSegmentModel(), + [DateFlagsTransform(), LagTransform(in_column="target", lags=[1])], + ), + (NaiveModel(), []), + (ProphetModel(), []), + ], +) +def test_forecast_given_ts(model, transforms, reconciliator, product_level_constant_hierarchical_ts): + """Test that forecast makes forecasts with given ts. + + We don't use :py:func:`tests.test_pipeline.utils.assert_pipeline_forecasts_with_given_ts` here, + because it is difficult to set it up for hierarchy. 
+ """ + horizon = 1 + ts = product_level_constant_hierarchical_ts + pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=transforms, horizon=horizon) + + subset_ts = deepcopy(ts) + subset_ts.df = subset_ts.df.iloc[:-horizon] + + pipeline.fit(ts) + forecast_full = pipeline.forecast() + forecast_subset = pipeline.forecast(ts=subset_ts) + + expected_segments = forecast_full.segments + expected_index = forecast_full.index - pd.DateOffset(days=horizon) + assert forecast_subset.segments == expected_segments + pd.testing.assert_index_equal(forecast_subset.index, expected_index) From be6152c0f3e24c4410b0969839a95fd72c220dd9 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 21 Mar 2023 16:36:12 +0300 Subject: [PATCH 10/13] test: fix inference tests for transforms --- .../test_inference/test_inverse_transform.py | 5 +---- tests/test_transforms/test_inference/test_transform.py | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_transforms/test_inference/test_inverse_transform.py b/tests/test_transforms/test_inference/test_inverse_transform.py index 890056a2c..c85ee8483 100644 --- a/tests/test_transforms/test_inference/test_inverse_transform.py +++ b/tests/test_transforms/test_inference/test_inverse_transform.py @@ -444,7 +444,6 @@ def _test_inverse_transform_train_new_segments(self, ts, transform, train_segmen # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {}), (FilterFeaturesTransform(exclude=["year"], return_features=True), "ts_with_exog", {"create": {"year"}}), - # TODO: this should remove only 2 features, wait for fixing [#1097](https://github.com/tinkoff-ai/etna/issues/1097) ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", @@ -455,7 +454,7 @@ def _test_inverse_transform_train_new_segments(self, ts, transform, train_segmen relevance_table=StatisticsRelevanceTable(), top_k=2, return_features=True ), "ts_with_exog", - {"create": {"monthday", "year", "positive", "weekday", "month"}}, + {"create": {"year", "weekday", "month"}}, ), ( MRMRFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), @@ -752,7 +751,6 @@ def _test_inverse_transform_future_new_segments(self, ts, transform, train_segme ), # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {}), - # TODO: this should remove only 2 features ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", @@ -1147,7 +1145,6 @@ def _test_inverse_transform_future_with_target( # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {}), (FilterFeaturesTransform(exclude=["year"], return_features=True), "ts_with_exog", {"create": {"year"}}), - # TODO: this should remove only 2 features ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", diff --git a/tests/test_transforms/test_inference/test_transform.py b/tests/test_transforms/test_inference/test_transform.py index d5d5e598c..9ed8ace5f 100644 --- a/tests/test_transforms/test_inference/test_transform.py +++ b/tests/test_transforms/test_inference/test_transform.py @@ -430,11 +430,10 @@ def _test_transform_train_new_segments(self, ts, transform, train_segments, expe ), # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {"remove": {"year"}}), - # TODO: this should remove only 2 features, wait for fixing 
[#1097](https://github.com/tinkoff-ai/etna/issues/1097) ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", - {"remove": {"weekday", "year", "month", "monthday", "positive"}}, + {"remove": {"weekday", "year", "month"}}, ), ( MRMRFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), @@ -710,11 +709,10 @@ def _test_transform_future_new_segments(self, ts, transform, train_segments, exp ), # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {"remove": {"year"}}), - # TODO: this should remove only 2 features ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", - {"remove": {"weekday", "year", "month", "monthday", "positive"}}, + {"remove": {"weekday", "year", "month"}}, ), ( MRMRFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), @@ -1054,7 +1052,6 @@ def _test_transform_future_with_target(self, ts, transform, expected_changes, ga (SegmentEncoderTransform(), "regular_ts", {"create": {"segment_code"}}), # feature_selection (FilterFeaturesTransform(exclude=["year"]), "ts_with_exog", {"remove": {"year"}}), - # TODO: this should remove only 2 features ( GaleShapleyFeatureSelectionTransform(relevance_table=StatisticsRelevanceTable(), top_k=2), "ts_with_exog", From 31826958d1c51b08d50606dc7a5aacfb27844539 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 21 Mar 2023 16:59:14 +0300 Subject: [PATCH 11/13] test: fix test_process_fold_forecast --- tests/test_pipeline/test_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index ee87b200c..fd19082c8 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -805,7 +805,7 @@ def test_generate_folds_datasets(ts_name, mask, request): """Check _generate_folds_datasets for correct work.""" ts = request.getfixturevalue(ts_name) pipeline = Pipeline(model=NaiveModel(lag=7)) - mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant", stride=-1)[0] + mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode=CrossValidationMode.expand, stride=-1)[0] train, test = list(pipeline._generate_folds_datasets(ts, [mask], 4))[0] assert train.index.min() == np.datetime64(mask.first_train_timestamp) assert train.index.max() == np.datetime64(mask.last_train_timestamp) @@ -823,7 +823,7 @@ def test_generate_folds_datasets_without_first_date(ts_name, mask, request): """Check _generate_folds_datasets for correct work without first date.""" ts = request.getfixturevalue(ts_name) pipeline = Pipeline(model=NaiveModel(lag=7)) - mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode="constant", stride=-1)[0] + mask = pipeline._prepare_fold_masks(ts=ts, masks=[mask], mode=CrossValidationMode.expand, stride=-1)[0] train, test = list(pipeline._generate_folds_datasets(ts, [mask], 4))[0] assert train.index.min() == np.datetime64(ts.index.min()) assert train.index.max() == np.datetime64(mask.last_train_timestamp) @@ -847,7 +847,7 @@ def test_process_fold_forecast(ts_process_fold_forecast, mask: FoldMask, expecte pipeline = pipeline.fit(ts=train) forecast = pipeline.forecast() fold = pipeline._process_fold_forecast( - forecast=forecast, train=train, test=test, fold_number=1, mask=mask, metrics=[MAE()] + forecast=forecast, train=train, test=test, pipeline=pipeline, fold_number=1, mask=mask, metrics=[MAE()] ) for seg in 
fold["metrics"]["MAE"].keys(): assert fold["metrics"]["MAE"][seg] == expected[seg] From 0a088aba2a9a671c6db4a857ef2c6060fefe7945 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 21 Mar 2023 17:01:54 +0300 Subject: [PATCH 12/13] style: reformat code --- etna/transforms/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/etna/transforms/utils.py b/etna/transforms/utils.py index 6f83b3d96..8b491c752 100644 --- a/etna/transforms/utils.py +++ b/etna/transforms/utils.py @@ -1,9 +1,8 @@ -from etna.datasets.utils import match_target_quantiles # noqa: F401 -import re import reprlib from typing import List from typing import Optional -from typing import Set + +from etna.datasets.utils import match_target_quantiles # noqa: F401 def check_new_segments(transform_segments: List[str], fit_segments: Optional[List[str]]): From 43e1c3c61a1e2bdc4229b23f60d074b9c93cbc4c Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Tue, 21 Mar 2023 17:03:17 +0300 Subject: [PATCH 13/13] fix: set poetry version for ci --- .github/workflows/docs-on-pr.yml | 1 + .github/workflows/docs-unstable.yml | 1 + .github/workflows/notebooks.yml | 1 + .github/workflows/test.yml | 7 ++++++- 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docs-on-pr.yml b/.github/workflows/docs-on-pr.yml index 42bc8c982..c32372d59 100644 --- a/.github/workflows/docs-on-pr.yml +++ b/.github/workflows/docs-on-pr.yml @@ -16,6 +16,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true - name: Load cached venv diff --git a/.github/workflows/docs-unstable.yml b/.github/workflows/docs-unstable.yml index 5dbbe59e5..11fb090da 100644 --- a/.github/workflows/docs-unstable.yml +++ b/.github/workflows/docs-unstable.yml @@ -17,6 +17,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true - name: Load cached venv diff --git a/.github/workflows/notebooks.yml b/.github/workflows/notebooks.yml index 15cb3ba4e..83ba01fce 100644 --- a/.github/workflows/notebooks.yml +++ b/.github/workflows/notebooks.yml @@ -26,6 +26,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true - name: Install dependencies diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b4f40d09..d735cf455 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: - name: Install Dependencies run: | - pip install poetry + pip install poetry==1.4.0 # TODO: remove after poetry fix poetry --version poetry config virtualenvs.in-project true poetry install -E style --no-root @@ -48,6 +48,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true @@ -86,6 +87,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true @@ -123,6 +125,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true @@ -160,6 +163,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix 
virtualenvs-create: true virtualenvs-in-project: true @@ -199,6 +203,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: + version: 1.4.0 # TODO: remove after poetry fix virtualenvs-create: true virtualenvs-in-project: true
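
Taken together, these patches leave `HierarchicalPipeline` persistable end to end: `save` detaches the fitted dataset before delegating to the base implementation, and `load` re-attaches one on request. A sketch of the round trip that the new tests exercise, under the same placeholder assumption for `ts` as above:

```python
import pathlib

from etna.models import NaiveModel
from etna.pipeline.hierarchical_pipeline import HierarchicalPipeline
from etna.reconciliation import BottomUpReconciliator

# Placeholder: a hierarchical TSDataset, e.g. built like the
# `product_level_constant_hierarchical_ts` fixture from the tests.
ts = ...

pipeline = HierarchicalPipeline(
    reconciliator=BottomUpReconciliator(target_level="market", source_level="product"),
    model=NaiveModel(),
    transforms=[],
    horizon=1,
)
pipeline.fit(ts=ts)

path = pathlib.Path("pipeline.zip")
pipeline.save(path)  # `_fit_ts` is detached while saving, then restored

# Loading without `ts` leaves the pipeline datasetless, so `forecast()` would
# raise ValueError; pass `ts=` at load time (as here) or to `forecast` later.
loaded = HierarchicalPipeline.load(path, ts=ts)
reconciled_forecast = loaded.forecast()
```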