Skip to content

Add SaveEnsembleMixin #1046

Merged
merged 6 commits into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
-
- Add `SaveModelPipelineMixin`, add `load`, add saving for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036))
- Add `SaveModelPipelineMixin`, add `load`, add saving and loading for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036))
- Add `SaveMixin` to models and transforms ([#1007](https://github.com/tinkoff-ai/etna/pull/1007))
- Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988))
- Add `experimental` module with `TimeSeriesBinaryClassifier` and `PredictabilityAnalyzer` ([#985](https://github.com/tinkoff-ai/etna/pull/985))
- Inference track results: add `predict` method to pipelines, teach some models to work with context, change hierarchy of base models, update notebook examples ([#979](https://github.com/tinkoff-ai/etna/pull/979))
- Add `get_ruptures_regularization` into `experimental` module ([#1001](https://github.com/tinkoff-ai/etna/pull/1001))
-
- Add `SaveEnsembleMixin`, add saving and loading for `VotingEnsemble`, `StackingEnsemble` and `DirectEnsemble` ([#1046](https://github.com/tinkoff-ai/etna/pull/1046))
### Changed
-
- Change returned model in get_model of BATSModel, TBATSModel ([#987](https://github.com/tinkoff-ai/etna/pull/987))
Expand Down
6 changes: 4 additions & 2 deletions etna/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from hydra_slayer import get_factory


def load(path: pathlib.Path) -> Any:
def load(path: pathlib.Path, **kwargs: Any) -> Any:
"""Load saved object by path.

Parameters
----------
path:
Path to load object from.
kwargs:
Parameters for loading specific for the loaded object.

Returns
-------
Expand All @@ -33,7 +35,7 @@ def load(path: pathlib.Path) -> Any:

# create object for that class
object_class = get_factory(object_class_name)
loaded_object = object_class.load(path=path)
loaded_object = object_class.load(path=path, **kwargs)

return loaded_object

Expand Down
2 changes: 1 addition & 1 deletion etna/ensembles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from etna.ensembles.base import EnsembleMixin
from etna.ensembles.direct_ensemble import DirectEnsemble
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.stacking_ensemble import StackingEnsemble
from etna.ensembles.voting_ensemble import VotingEnsemble
78 changes: 0 additions & 78 deletions etna/ensembles/base.py

This file was deleted.

5 changes: 3 additions & 2 deletions etna/ensembles/direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from joblib import delayed

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.pipeline.base import BasePipeline


class DirectEnsemble(EnsembleMixin, BasePipeline):
class DirectEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""DirectEnsemble is a pipeline that forecasts future values merging the forecasts of base pipelines.

Ensemble expects several pipelines during init. These pipelines are expected to have different forecasting horizons.
Expand Down
148 changes: 148 additions & 0 deletions etna/ensembles/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import pathlib
import tempfile
import zipfile
from copy import deepcopy
from typing import Any
from typing import List
from typing import Optional

import pandas as pd

from etna.core import SaveMixin
from etna.core import load
from etna.datasets import TSDataset
from etna.loggers import tslogger
from etna.pipeline.base import BasePipeline


class EnsembleMixin:
"""Base mixin for the ensembles."""

@staticmethod
def _validate_pipeline_number(pipelines: List[BasePipeline]):
"""Check that given valid number of pipelines."""
if len(pipelines) < 2:
raise ValueError("At least two pipelines are expected.")

@staticmethod
def _get_horizon(pipelines: List[BasePipeline]) -> int:
"""Get ensemble's horizon."""
horizons = {pipeline.horizon for pipeline in pipelines}
if len(horizons) > 1:
raise ValueError("All the pipelines should have the same horizon.")
return horizons.pop()

@staticmethod
def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline:
"""Fit given pipeline with ``ts``."""
tslogger.log(msg=f"Start fitting {pipeline}.")
pipeline.fit(ts=ts)
tslogger.log(msg=f"Pipeline {pipeline} is fitted.")
return pipeline

@staticmethod
def _forecast_pipeline(pipeline: BasePipeline) -> TSDataset:
"""Make forecast with given pipeline."""
tslogger.log(msg=f"Start forecasting with {pipeline}.")
forecast = pipeline.forecast()
tslogger.log(msg=f"Forecast is done with {pipeline}.")
return forecast

@staticmethod
def _predict_pipeline(
ts: TSDataset,
pipeline: BasePipeline,
start_timestamp: Optional[pd.Timestamp],
end_timestamp: Optional[pd.Timestamp],
) -> TSDataset:
"""Make predict with given pipeline."""
tslogger.log(msg=f"Start prediction with {pipeline}.")
prediction = pipeline.predict(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp)
tslogger.log(msg=f"Prediction is done with {pipeline}.")
return prediction


class SaveEnsembleMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines.

It saves object to the zip archive with 3 entities:

* metadata.json: contains library version and class name.

* object.pkl: pickled without pipelines and ts.

* pipelines: folder with saved pipelines.
"""

pipelines: List[BasePipeline]
ts: Optional[TSDataset]

def save(self, path: pathlib.Path):
"""Save the object.

Parameters
----------
path:
Path to save object to.
"""
pipelines = self.pipelines
ts = self.ts
try:
# extract attributes we can't easily save
delattr(self, "pipelines")
delattr(self, "ts")

# save the remaining part
super().save(path=path)
finally:
self.pipelines = pipelines
self.ts = ts

with zipfile.ZipFile(path, "a") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)

# save transforms separately
pipelines_dir = temp_dir / "pipelines"
pipelines_dir.mkdir()
num_digits = len(str(len(pipelines) - 1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💣
What do you want to do?
We always can use fix len - for example 8 symbols - it will be enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will also change that for transforms (we have the same logic where).

for i, pipeline in enumerate(pipelines):
save_name = f"{i:0{num_digits}d}.zip"
pipeline_save_path = pipelines_dir / save_name
pipeline.save(pipeline_save_path)
archive.write(pipeline_save_path, f"pipelines/{save_name}")

@classmethod
def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:
"""Load an object.

Parameters
----------
path:
Path to load object from.
ts:
TSDataset to set into loaded pipeline.

Returns
-------
:
Loaded object.
"""
obj = super().load(path=path)
obj.ts = deepcopy(ts)

with zipfile.ZipFile(path, "r") as archive:
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)

archive.extractall(temp_dir)

# load pipelines
pipelines_dir = temp_dir / "pipelines"
pipelines = []
for path in sorted(pipelines_dir.iterdir()):
pipelines.append(load(path, ts=ts))

obj.pipelines = pipelines

return obj
5 changes: 3 additions & 2 deletions etna/ensembles/stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from typing_extensions import Literal

from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.pipeline.base import BasePipeline


class StackingEnsemble(EnsembleMixin, BasePipeline):
class StackingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""StackingEnsemble is a pipeline that forecast future using the metamodel to combine the forecasts of the base models.

Examples
Expand Down
5 changes: 3 additions & 2 deletions etna/ensembles/voting_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@

from etna.analysis.feature_relevance.relevance_table import TreeBasedRegressor
from etna.datasets import TSDataset
from etna.ensembles import EnsembleMixin
from etna.ensembles.mixins import EnsembleMixin
from etna.ensembles.mixins import SaveEnsembleMixin
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.pipeline.base import BasePipeline


class VotingEnsemble(EnsembleMixin, BasePipeline):
class VotingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline):
"""VotingEnsemble is a pipeline that forecast future values with weighted averaging of it's pipelines forecasts.

Examples
Expand Down
7 changes: 6 additions & 1 deletion etna/pipeline/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _predict(
class SaveModelPipelineMixin(SaveMixin):
"""Implementation of ``AbstractSaveable`` abstract class for pipelines with model inside.

It saves object to the zip archive with 2 files:
It saves object to the zip archive with 4 entities:

* metadata.json: contains library version and class name.

Expand All @@ -118,6 +118,7 @@ def save(self, path: pathlib.Path):
model = self.model
transforms = self.transforms
ts = self.ts

try:
# extract attributes we can't easily save
delattr(self, "model")
Expand Down Expand Up @@ -189,4 +190,8 @@ def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any:

obj.transforms = transforms

# set transforms in ts
if obj.ts is not None:
obj.ts.transforms = transforms

return obj
19 changes: 19 additions & 0 deletions tests/test_core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pathlib
import tempfile

import pandas as pd
import pytest

from etna.core import load
from etna.models import NaiveModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform


Expand All @@ -21,6 +24,22 @@ def test_load_ok():
transform.save(save_path)

new_transform = load(save_path)

assert type(new_transform) == type(transform)
for attribute in ["in_column", "value", "inplace"]:
assert getattr(new_transform, attribute) == getattr(transform, attribute)


def test_load_ok_with_params(example_tsds):
pipeline = Pipeline(model=NaiveModel(), horizon=7)
with tempfile.TemporaryDirectory() as _temp_dir:
temp_dir = pathlib.Path(_temp_dir)
save_path = temp_dir / "pipeline.zip"
pipeline.fit(ts=example_tsds)
pipeline.save(save_path)

new_pipeline = load(save_path, ts=example_tsds)

assert new_pipeline.ts is not None
assert type(new_pipeline) == type(pipeline)
pd.testing.assert_frame_equal(new_pipeline.ts.to_pandas(), example_tsds.to_pandas())
Loading