From b5150278c4a75b5bc06f1c146b08ea0364ff5d2b Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Fri, 16 Dec 2022 17:49:48 +0300 Subject: [PATCH 1/6] Add SaveEnsembleMixin and tests for it --- etna/core/utils.py | 6 +- etna/ensembles/__init__.py | 2 +- etna/ensembles/base.py | 78 --------- etna/ensembles/mixins.py | 171 ++++++++++++++++++++ etna/pipeline/mixins.py | 3 +- tests/test_core/test_utils.py | 19 +++ tests/test_ensembles/test_ensemble_mixin.py | 24 --- tests/test_ensembles/test_mixins.py | 137 ++++++++++++++++ tests/test_pipeline/test_mixins.py | 8 +- 9 files changed, 338 insertions(+), 110 deletions(-) delete mode 100644 etna/ensembles/base.py create mode 100644 etna/ensembles/mixins.py delete mode 100644 tests/test_ensembles/test_ensemble_mixin.py create mode 100644 tests/test_ensembles/test_mixins.py diff --git a/etna/core/utils.py b/etna/core/utils.py index d067e7af9..0dc2ee29b 100644 --- a/etna/core/utils.py +++ b/etna/core/utils.py @@ -10,13 +10,15 @@ from hydra_slayer import get_factory -def load(path: pathlib.Path) -> Any: +def load(path: pathlib.Path, **kwargs: Any) -> Any: """Load saved object by path. Parameters ---------- path: Path to load object from. + kwargs: + Parameters for loading specific for the loaded object. Returns ------- @@ -33,7 +35,7 @@ def load(path: pathlib.Path) -> Any: # create object for that class object_class = get_factory(object_class_name) - loaded_object = object_class.load(path=path) + loaded_object = object_class.load(path=path, **kwargs) return loaded_object diff --git a/etna/ensembles/__init__.py b/etna/ensembles/__init__.py index 835ee4366..a6dcd27ab 100644 --- a/etna/ensembles/__init__.py +++ b/etna/ensembles/__init__.py @@ -1,4 +1,4 @@ -from etna.ensembles.base import EnsembleMixin from etna.ensembles.direct_ensemble import DirectEnsemble +from etna.ensembles.mixins import EnsembleMixin from etna.ensembles.stacking_ensemble import StackingEnsemble from etna.ensembles.voting_ensemble import VotingEnsemble diff --git a/etna/ensembles/base.py b/etna/ensembles/base.py deleted file mode 100644 index 80e0468a0..000000000 --- a/etna/ensembles/base.py +++ /dev/null @@ -1,78 +0,0 @@ -import pathlib -from typing import Any -from typing import List -from typing import Optional - -import pandas as pd - -from etna.datasets import TSDataset -from etna.loggers import tslogger -from etna.pipeline.base import BasePipeline - - -class EnsembleMixin: - """Base mixin for the ensembles.""" - - @staticmethod - def _validate_pipeline_number(pipelines: List[BasePipeline]): - """Check that given valid number of pipelines.""" - if len(pipelines) < 2: - raise ValueError("At least two pipelines are expected.") - - @staticmethod - def _get_horizon(pipelines: List[BasePipeline]) -> int: - """Get ensemble's horizon.""" - horizons = {pipeline.horizon for pipeline in pipelines} - if len(horizons) > 1: - raise ValueError("All the pipelines should have the same horizon.") - return horizons.pop() - - @staticmethod - def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline: - """Fit given pipeline with ``ts``.""" - tslogger.log(msg=f"Start fitting {pipeline}.") - pipeline.fit(ts=ts) - tslogger.log(msg=f"Pipeline {pipeline} is fitted.") - return pipeline - - @staticmethod - def _forecast_pipeline(pipeline: BasePipeline) -> TSDataset: - """Make forecast with given pipeline.""" - tslogger.log(msg=f"Start forecasting with {pipeline}.") - forecast = pipeline.forecast() - tslogger.log(msg=f"Forecast is done with {pipeline}.") - return forecast - - @staticmethod - 
def _predict_pipeline( - ts: TSDataset, - pipeline: BasePipeline, - start_timestamp: Optional[pd.Timestamp], - end_timestamp: Optional[pd.Timestamp], - ) -> TSDataset: - """Make predict with given pipeline.""" - tslogger.log(msg=f"Start prediction with {pipeline}.") - prediction = pipeline.predict(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp) - tslogger.log(msg=f"Prediction is done with {pipeline}.") - return prediction - - def save(self, path: pathlib.Path): - """Save the object. - - Parameters - ---------- - path: - Path to save object to. - """ - raise NotImplementedError() - - @classmethod - def load(cls, path: pathlib.Path) -> Any: - """Load an object. - - Parameters - ---------- - path: - Path to load object from. - """ - raise NotImplementedError() diff --git a/etna/ensembles/mixins.py b/etna/ensembles/mixins.py new file mode 100644 index 000000000..55c5fabe5 --- /dev/null +++ b/etna/ensembles/mixins.py @@ -0,0 +1,171 @@ +import pathlib +import tempfile +import zipfile +from copy import deepcopy +from typing import Any +from typing import List +from typing import Optional +from typing import Sequence + +import pandas as pd + +from etna.core import SaveMixin +from etna.core import load +from etna.datasets import TSDataset +from etna.loggers import tslogger +from etna.pipeline.base import AbstractPipeline +from etna.pipeline.base import BasePipeline + + +class EnsembleMixin: + """Base mixin for the ensembles.""" + + @staticmethod + def _validate_pipeline_number(pipelines: List[BasePipeline]): + """Check that given valid number of pipelines.""" + if len(pipelines) < 2: + raise ValueError("At least two pipelines are expected.") + + @staticmethod + def _get_horizon(pipelines: List[BasePipeline]) -> int: + """Get ensemble's horizon.""" + horizons = {pipeline.horizon for pipeline in pipelines} + if len(horizons) > 1: + raise ValueError("All the pipelines should have the same horizon.") + return horizons.pop() + + @staticmethod + def _fit_pipeline(pipeline: BasePipeline, ts: TSDataset) -> BasePipeline: + """Fit given pipeline with ``ts``.""" + tslogger.log(msg=f"Start fitting {pipeline}.") + pipeline.fit(ts=ts) + tslogger.log(msg=f"Pipeline {pipeline} is fitted.") + return pipeline + + @staticmethod + def _forecast_pipeline(pipeline: BasePipeline) -> TSDataset: + """Make forecast with given pipeline.""" + tslogger.log(msg=f"Start forecasting with {pipeline}.") + forecast = pipeline.forecast() + tslogger.log(msg=f"Forecast is done with {pipeline}.") + return forecast + + @staticmethod + def _predict_pipeline( + ts: TSDataset, + pipeline: BasePipeline, + start_timestamp: Optional[pd.Timestamp], + end_timestamp: Optional[pd.Timestamp], + ) -> TSDataset: + """Make predict with given pipeline.""" + tslogger.log(msg=f"Start prediction with {pipeline}.") + prediction = pipeline.predict(ts=ts, start_timestamp=start_timestamp, end_timestamp=end_timestamp) + tslogger.log(msg=f"Prediction is done with {pipeline}.") + return prediction + + def save(self, path: pathlib.Path): + """Save the object. + + Parameters + ---------- + path: + Path to save object to. + """ + raise NotImplementedError() + + @classmethod + def load(cls, path: pathlib.Path) -> Any: + """Load an object. + + Parameters + ---------- + path: + Path to load object from. + """ + raise NotImplementedError() + + +class SaveEnsembleMixin(SaveMixin): + """Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines. 
+ + It saves object to the zip archive with 3 entities: + + * metadata.json: contains library version and class name. + + * object.pkl: pickled without pipelines and ts. + + * pipelines: folder with saved pipelines. + """ + + pipelines: Sequence[AbstractPipeline] + ts: Optional[TSDataset] + + def save(self, path: pathlib.Path): + """Save the object. + + Parameters + ---------- + path: + Path to save object to. + """ + pipelines = self.pipelines + ts = self.ts + try: + # extract attributes we can't easily save + delattr(self, "pipelines") + delattr(self, "ts") + + # save the remaining part + super().save(path=path) + finally: + self.pipelines = pipelines + self.ts = ts + + with zipfile.ZipFile(path, "a") as archive: + with tempfile.TemporaryDirectory() as _temp_dir: + temp_dir = pathlib.Path(_temp_dir) + + # save transforms separately + pipelines_dir = temp_dir / "pipelines" + pipelines_dir.mkdir() + num_digits = len(str(len(pipelines) - 1)) + for i, pipeline in enumerate(pipelines): + save_name = f"{i:0{num_digits}d}.zip" + pipeline_save_path = pipelines_dir / save_name + pipeline.save(pipeline_save_path) + archive.write(pipeline_save_path, f"pipelines/{save_name}") + + @classmethod + def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any: + """Load an object. + + Parameters + ---------- + path: + Path to load object from. + ts: + TSDataset to set into loaded pipeline. + + Returns + ------- + : + Loaded object. + """ + obj = super().load(path=path) + obj.ts = deepcopy(ts) + + with zipfile.ZipFile(path, "r") as archive: + with tempfile.TemporaryDirectory() as _temp_dir: + temp_dir = pathlib.Path(_temp_dir) + + archive.extractall(temp_dir) + + # load pipelines + pipelines_dir = temp_dir / "pipelines" + pipelines = [] + for path in sorted(pipelines_dir.iterdir()): + pipelines.append(load(path, ts=ts)) + + obj.pipelines = pipelines + + return obj diff --git a/etna/pipeline/mixins.py b/etna/pipeline/mixins.py index fb803d8f2..17ba4b3de 100644 --- a/etna/pipeline/mixins.py +++ b/etna/pipeline/mixins.py @@ -92,7 +92,7 @@ def _predict( class SaveModelPipelineMixin(SaveMixin): """Implementation of ``AbstractSaveable`` abstract class for pipelines with model inside. - It saves object to the zip archive with 2 files: + It saves object to the zip archive with 4 entities: * metadata.json: contains library version and class name. 
@@ -118,6 +118,7 @@ def save(self, path: pathlib.Path): model = self.model transforms = self.transforms ts = self.ts + try: # extract attributes we can't easily save delattr(self, "model") diff --git a/tests/test_core/test_utils.py b/tests/test_core/test_utils.py index bcdc4c873..583f272f2 100644 --- a/tests/test_core/test_utils.py +++ b/tests/test_core/test_utils.py @@ -1,9 +1,12 @@ import pathlib import tempfile +import pandas as pd import pytest from etna.core import load +from etna.models import NaiveModel +from etna.pipeline import Pipeline from etna.transforms import AddConstTransform @@ -21,6 +24,22 @@ def test_load_ok(): transform.save(save_path) new_transform = load(save_path) + assert type(new_transform) == type(transform) for attribute in ["in_column", "value", "inplace"]: assert getattr(new_transform, attribute) == getattr(transform, attribute) + + +def test_load_ok_with_params(example_tsds): + pipeline = Pipeline(model=NaiveModel(), horizon=7) + with tempfile.TemporaryDirectory() as _temp_dir: + temp_dir = pathlib.Path(_temp_dir) + save_path = temp_dir / "pipeline.zip" + pipeline.fit(ts=example_tsds) + pipeline.save(save_path) + + new_pipeline = load(save_path, ts=example_tsds) + + assert new_pipeline.ts is not None + assert type(new_pipeline) == type(pipeline) + pd.testing.assert_frame_equal(new_pipeline.ts.to_pandas(), example_tsds.to_pandas()) diff --git a/tests/test_ensembles/test_ensemble_mixin.py b/tests/test_ensembles/test_ensemble_mixin.py deleted file mode 100644 index 13ce2c4c3..000000000 --- a/tests/test_ensembles/test_ensemble_mixin.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - -from etna.ensembles.stacking_ensemble import StackingEnsemble -from etna.pipeline import Pipeline - -HORIZON = 7 - - -def test_invalid_pipelines_number(catboost_pipeline: Pipeline): - """Test StackingEnsemble behavior in case of invalid pipelines number.""" - with pytest.raises(ValueError, match="At least two pipelines are expected."): - _ = StackingEnsemble(pipelines=[catboost_pipeline]) - - -def test_get_horizon_pass(catboost_pipeline: Pipeline, prophet_pipeline: Pipeline): - """Check that StackingEnsemble._get horizon works correctly in case of valid pipelines list.""" - horizon = StackingEnsemble._get_horizon(pipelines=[catboost_pipeline, prophet_pipeline]) - assert horizon == HORIZON - - -def test_get_horizon_fail(catboost_pipeline: Pipeline, naive_pipeline: Pipeline): - """Check that StackingEnsemble._get horizon works correctly in case of invalid pipelines list.""" - with pytest.raises(ValueError, match="All the pipelines should have the same horizon."): - _ = StackingEnsemble._get_horizon(pipelines=[catboost_pipeline, naive_pipeline]) diff --git a/tests/test_ensembles/test_mixins.py b/tests/test_ensembles/test_mixins.py new file mode 100644 index 000000000..3854e53de --- /dev/null +++ b/tests/test_ensembles/test_mixins.py @@ -0,0 +1,137 @@ +import json +import pathlib +import pickle +import zipfile +from copy import deepcopy +from unittest.mock import patch + +import pandas as pd +import pytest + +from etna.ensembles.mixins import SaveEnsembleMixin +from etna.ensembles.stacking_ensemble import StackingEnsemble +from etna.models import NaiveModel +from etna.pipeline import Pipeline + +HORIZON = 7 + + +def test_ensemble_invalid_pipelines_number(catboost_pipeline: Pipeline): + """Test StackingEnsemble behavior in case of invalid pipelines number.""" + with pytest.raises(ValueError, match="At least two pipelines are expected."): + _ = StackingEnsemble(pipelines=[catboost_pipeline]) 
+ + +def test_ensemble_get_horizon_pass(catboost_pipeline: Pipeline, prophet_pipeline: Pipeline): + """Check that StackingEnsemble._get horizon works correctly in case of valid pipelines list.""" + horizon = StackingEnsemble._get_horizon(pipelines=[catboost_pipeline, prophet_pipeline]) + assert horizon == HORIZON + + +def test_ensemble_get_horizon_fail(catboost_pipeline: Pipeline, naive_pipeline: Pipeline): + """Check that StackingEnsemble._get horizon works correctly in case of invalid pipelines list.""" + with pytest.raises(ValueError, match="All the pipelines should have the same horizon."): + _ = StackingEnsemble._get_horizon(pipelines=[catboost_pipeline, naive_pipeline]) + + +class Dummy(SaveEnsembleMixin): + def __init__(self, a, b, ts, pipelines): + self.a = a + self.b = b + self.ts = ts + self.pipelines = pipelines + + +def test_save_mixin_save(example_tsds, tmp_path): + pipelines = [Pipeline(model=NaiveModel(lag=1), horizon=HORIZON), Pipeline(model=NaiveModel(lag=2), horizon=HORIZON)] + dummy = Dummy(a=1, b=2, ts=example_tsds, pipelines=pipelines) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + + initial_dummy = deepcopy(dummy) + dummy.save(path) + + with zipfile.ZipFile(path, "r") as archive: + files = archive.namelist() + assert sorted(files) == sorted(["metadata.json", "object.pkl", "pipelines/0.zip", "pipelines/1.zip"]) + + with archive.open("metadata.json", "r") as input_file: + metadata_bytes = input_file.read() + metadata_str = metadata_bytes.decode("utf-8") + metadata = json.loads(metadata_str) + assert sorted(metadata.keys()) == ["class", "etna_version"] + assert metadata["class"] == "tests.test_ensembles.test_mixins.Dummy" + + with archive.open("object.pkl", "r") as input_file: + loaded_obj = pickle.load(input_file) + assert loaded_obj.a == dummy.a + assert loaded_obj.b == dummy.b + + # basic check that we didn't break dummy object itself + assert dummy.a == initial_dummy.a + assert pickle.dumps(dummy.ts) == pickle.dumps(initial_dummy.ts) + assert len(dummy.pipelines) == len(initial_dummy.pipelines) + + +def test_save_mixin_load_fail_file_not_found(): + non_existent_path = pathlib.Path("archive.zip") + with pytest.raises(FileNotFoundError): + Dummy.load(non_existent_path) + + +def test_save_mixin_load_ok_no_ts(example_tsds, recwarn, tmp_path): + lag_values = list(range(1, 11)) + pipelines = [Pipeline(model=NaiveModel(lag=lag), horizon=HORIZON) for lag in lag_values] + dummy = Dummy(a=1, b=2, ts=example_tsds, pipelines=pipelines) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + + dummy.save(path) + loaded_dummy = Dummy.load(path) + + assert loaded_dummy.a == dummy.a + assert loaded_dummy.b == dummy.b + assert loaded_dummy.ts is None + assert [pipeline.model.lag for pipeline in loaded_dummy.pipelines] == lag_values + assert len(recwarn) == 0 + + +def test_save_mixin_load_ok_with_ts(example_tsds, recwarn, tmp_path): + lag_values = list(range(1, 11)) + pipelines = [Pipeline(model=NaiveModel(lag=lag), horizon=HORIZON) for lag in lag_values] + dummy = Dummy(a=1, b=2, ts=example_tsds, pipelines=pipelines) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + + dummy.save(path) + loaded_dummy = Dummy.load(path, ts=example_tsds) + + assert loaded_dummy.a == dummy.a + assert loaded_dummy.b == dummy.b + assert loaded_dummy.ts is not example_tsds + pd.testing.assert_frame_equal(loaded_dummy.ts.to_pandas(), dummy.ts.to_pandas()) + assert [pipeline.model.lag for pipeline in loaded_dummy.pipelines] == lag_values + assert 
len(recwarn) == 0 + + +@pytest.mark.parametrize( + "save_version, load_version", [((1, 5, 0), (2, 5, 0)), ((2, 5, 0), (1, 5, 0)), ((1, 5, 0), (1, 3, 0))] +) +@patch("etna.core.mixins.get_etna_version") +def test_save_mixin_load_warning(get_version_mock, save_version, load_version, example_tsds, tmp_path): + pipelines = [Pipeline(model=NaiveModel(lag=1), horizon=HORIZON), Pipeline(model=NaiveModel(lag=2), horizon=HORIZON)] + dummy = Dummy(a=1, b=2, ts=example_tsds, pipelines=pipelines) + dir_path = pathlib.Path(tmp_path) + path = dir_path / "dummy.zip" + + get_version_mock.return_value = save_version + dummy.save(path) + + save_version_str = ".".join([str(x) for x in save_version]) + load_version_str = ".".join([str(x) for x in load_version]) + with pytest.warns( + UserWarning, + match=f"The object was saved under etna version {save_version_str} but running version is {load_version_str}", + ): + get_version_mock.return_value = load_version + _ = Dummy.load(path) diff --git a/tests/test_pipeline/test_mixins.py b/tests/test_pipeline/test_mixins.py index 5cb4f3148..487c2fb62 100644 --- a/tests/test_pipeline/test_mixins.py +++ b/tests/test_pipeline/test_mixins.py @@ -237,8 +237,7 @@ def test_save_mixin_save(example_tsds, tmp_path): dir_path = pathlib.Path(tmp_path) path = dir_path / "dummy.zip" - initial_ts = deepcopy(example_tsds) - initial_model = deepcopy(model) + initial_dummy = deepcopy(dummy) initial_transforms = deepcopy(transforms) dummy.save(path) @@ -259,8 +258,9 @@ def test_save_mixin_save(example_tsds, tmp_path): assert loaded_obj.b == dummy.b # check that we didn't break dummy object itself - assert pickle.dumps(dummy.ts) == pickle.dumps(initial_ts) - assert pickle.dumps(dummy.model) == pickle.dumps(initial_model) + assert dummy.a == initial_dummy.a + assert pickle.dumps(dummy.ts) == pickle.dumps(initial_dummy.ts) + assert pickle.dumps(dummy.model) == pickle.dumps(initial_dummy.model) assert pickle.dumps(dummy.transforms) == pickle.dumps(initial_transforms) From 53ab49c749fdc5347ee438a1e8edb9a870f4fa20 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 19 Dec 2022 14:07:56 +0300 Subject: [PATCH 2/6] Add SaveEnsembleMixin into ensembles, fix tests --- etna/ensembles/direct_ensemble.py | 5 ++- etna/ensembles/mixins.py | 25 +----------- etna/ensembles/stacking_ensemble.py | 5 ++- etna/ensembles/voting_ensemble.py | 5 ++- etna/pipeline/mixins.py | 4 ++ tests/test_ensembles/test_direct_ensemble.py | 40 ++++++++++--------- .../test_ensembles/test_stacking_ensemble.py | 5 +++ tests/test_ensembles/test_voting_ensemble.py | 5 +++ tests/test_pipeline/utils.py | 11 +++-- 9 files changed, 53 insertions(+), 52 deletions(-) diff --git a/etna/ensembles/direct_ensemble.py b/etna/ensembles/direct_ensemble.py index 9c0c0a580..4b20c7e50 100644 --- a/etna/ensembles/direct_ensemble.py +++ b/etna/ensembles/direct_ensemble.py @@ -11,11 +11,12 @@ from joblib import delayed from etna.datasets import TSDataset -from etna.ensembles import EnsembleMixin +from etna.ensembles.mixins import EnsembleMixin +from etna.ensembles.mixins import SaveEnsembleMixin from etna.pipeline.base import BasePipeline -class DirectEnsemble(EnsembleMixin, BasePipeline): +class DirectEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline): """DirectEnsemble is a pipeline that forecasts future values merging the forecasts of base pipelines. Ensemble expects several pipelines during init. These pipelines are expected to have different forecasting horizons. 
diff --git a/etna/ensembles/mixins.py b/etna/ensembles/mixins.py index 55c5fabe5..997382e56 100644 --- a/etna/ensembles/mixins.py +++ b/etna/ensembles/mixins.py @@ -5,7 +5,6 @@ from typing import Any from typing import List from typing import Optional -from typing import Sequence import pandas as pd @@ -13,7 +12,6 @@ from etna.core import load from etna.datasets import TSDataset from etna.loggers import tslogger -from etna.pipeline.base import AbstractPipeline from etna.pipeline.base import BasePipeline @@ -63,27 +61,6 @@ def _predict_pipeline( tslogger.log(msg=f"Prediction is done with {pipeline}.") return prediction - def save(self, path: pathlib.Path): - """Save the object. - - Parameters - ---------- - path: - Path to save object to. - """ - raise NotImplementedError() - - @classmethod - def load(cls, path: pathlib.Path) -> Any: - """Load an object. - - Parameters - ---------- - path: - Path to load object from. - """ - raise NotImplementedError() - class SaveEnsembleMixin(SaveMixin): """Implementation of ``AbstractSaveable`` abstract class for ensemble pipelines. @@ -97,7 +74,7 @@ class SaveEnsembleMixin(SaveMixin): * pipelines: folder with saved pipelines. """ - pipelines: Sequence[AbstractPipeline] + pipelines: List[BasePipeline] ts: Optional[TSDataset] def save(self, path: pathlib.Path): diff --git a/etna/ensembles/stacking_ensemble.py b/etna/ensembles/stacking_ensemble.py index 7361391c9..db5fd31e5 100644 --- a/etna/ensembles/stacking_ensemble.py +++ b/etna/ensembles/stacking_ensemble.py @@ -19,13 +19,14 @@ from typing_extensions import Literal from etna.datasets import TSDataset -from etna.ensembles import EnsembleMixin +from etna.ensembles.mixins import EnsembleMixin +from etna.ensembles.mixins import SaveEnsembleMixin from etna.loggers import tslogger from etna.metrics import MAE from etna.pipeline.base import BasePipeline -class StackingEnsemble(EnsembleMixin, BasePipeline): +class StackingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline): """StackingEnsemble is a pipeline that forecast future using the metamodel to combine the forecasts of the base models. Examples diff --git a/etna/ensembles/voting_ensemble.py b/etna/ensembles/voting_ensemble.py index 0bbc29307..effc1edfa 100644 --- a/etna/ensembles/voting_ensemble.py +++ b/etna/ensembles/voting_ensemble.py @@ -15,13 +15,14 @@ from etna.analysis.feature_relevance.relevance_table import TreeBasedRegressor from etna.datasets import TSDataset -from etna.ensembles import EnsembleMixin +from etna.ensembles.mixins import EnsembleMixin +from etna.ensembles.mixins import SaveEnsembleMixin from etna.loggers import tslogger from etna.metrics import MAE from etna.pipeline.base import BasePipeline -class VotingEnsemble(EnsembleMixin, BasePipeline): +class VotingEnsemble(EnsembleMixin, SaveEnsembleMixin, BasePipeline): """VotingEnsemble is a pipeline that forecast future values with weighted averaging of it's pipelines forecasts. 
Examples diff --git a/etna/pipeline/mixins.py b/etna/pipeline/mixins.py index 17ba4b3de..4f3807ac9 100644 --- a/etna/pipeline/mixins.py +++ b/etna/pipeline/mixins.py @@ -190,4 +190,8 @@ def load(cls, path: pathlib.Path, ts: Optional[TSDataset] = None) -> Any: obj.transforms = transforms + # set transforms in ts + if obj.ts is not None: + obj.ts.transforms = transforms + return obj diff --git a/tests/test_ensembles/test_direct_ensemble.py b/tests/test_ensembles/test_direct_ensemble.py index c48600b04..031b8efdd 100644 --- a/tests/test_ensembles/test_direct_ensemble.py +++ b/tests/test_ensembles/test_direct_ensemble.py @@ -8,6 +8,18 @@ from etna.ensembles import DirectEnsemble from etna.models import NaiveModel from etna.pipeline import Pipeline +from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original + + +@pytest.fixture +def direct_ensemble_pipeline() -> DirectEnsemble: + ensemble = DirectEnsemble( + pipelines=[ + Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1), + Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2), + ] + ) + return ensemble @pytest.fixture @@ -36,32 +48,24 @@ def test_get_horizon_raise_error_on_same_horizons(): _ = DirectEnsemble(pipelines=[Mock(horizon=1), Mock(horizon=1)]) -def test_forecast(simple_ts_train, simple_ts_forecast): - ensemble = DirectEnsemble( - pipelines=[ - Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1), - Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2), - ] - ) - ensemble.fit(simple_ts_train) - forecast = ensemble.forecast() +def test_forecast(direct_ensemble_pipeline, simple_ts_train, simple_ts_forecast): + direct_ensemble_pipeline.fit(simple_ts_train) + forecast = direct_ensemble_pipeline.forecast() pd.testing.assert_frame_equal(forecast.to_pandas(), simple_ts_forecast.to_pandas()) -def test_predict(simple_ts_train): - ensemble = DirectEnsemble( - pipelines=[ - Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1), - Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2), - ] - ) +def test_predict(direct_ensemble_pipeline, simple_ts_train): smallest_pipeline = Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1) - ensemble.fit(simple_ts_train) + direct_ensemble_pipeline.fit(simple_ts_train) smallest_pipeline.fit(simple_ts_train) - prediction = ensemble.predict( + prediction = direct_ensemble_pipeline.predict( ts=simple_ts_train, start_timestamp=simple_ts_train.index[1], end_timestamp=simple_ts_train.index[2] ) expected_prediction = smallest_pipeline.predict( ts=simple_ts_train, start_timestamp=simple_ts_train.index[1], end_timestamp=simple_ts_train.index[2] ) pd.testing.assert_frame_equal(prediction.to_pandas(), expected_prediction.to_pandas()) + + +def test_save_load(direct_ensemble_pipeline, example_tsds): + assert_pipeline_equals_loaded_original(pipeline=direct_ensemble_pipeline, ts=example_tsds) diff --git a/tests/test_ensembles/test_stacking_ensemble.py b/tests/test_ensembles/test_stacking_ensemble.py index a8fe8add4..c10ae042c 100644 --- a/tests/test_ensembles/test_stacking_ensemble.py +++ b/tests/test_ensembles/test_stacking_ensemble.py @@ -14,6 +14,7 @@ from etna.ensembles.stacking_ensemble import StackingEnsemble from etna.metrics import MAE from etna.pipeline import Pipeline +from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original HORIZON = 7 @@ -317,3 +318,7 @@ def test_forecast_raise_error_if_not_fitted(naive_ensemble: StackingEnsemble): """Test that StackingEnsemble raise error when calling forecast without being fit.""" with 
pytest.raises(ValueError, match="StackingEnsemble is not fitted!"): _ = naive_ensemble.forecast() + + +def test_save_load(stacking_ensemble_pipeline, example_tsds): + assert_pipeline_equals_loaded_original(pipeline=stacking_ensemble_pipeline, ts=example_tsds) diff --git a/tests/test_ensembles/test_voting_ensemble.py b/tests/test_ensembles/test_voting_ensemble.py index 996834bc8..72d8d3d8a 100644 --- a/tests/test_ensembles/test_voting_ensemble.py +++ b/tests/test_ensembles/test_voting_ensemble.py @@ -15,6 +15,7 @@ from etna.ensembles.voting_ensemble import VotingEnsemble from etna.metrics import MAE from etna.pipeline import Pipeline +from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original HORIZON = 7 @@ -194,3 +195,7 @@ def test_backtest(voting_ensemble_pipeline: VotingEnsemble, example_tsds: TSData results = voting_ensemble_pipeline.backtest(ts=example_tsds, metrics=[MAE()], n_jobs=n_jobs, n_folds=3) for df in results: assert isinstance(df, pd.DataFrame) + + +def test_save_load(voting_ensemble_pipeline, example_tsds): + assert_pipeline_equals_loaded_original(pipeline=voting_ensemble_pipeline, ts=example_tsds) diff --git a/tests/test_pipeline/utils.py b/tests/test_pipeline/utils.py index 5613e1352..506c286d2 100644 --- a/tests/test_pipeline/utils.py +++ b/tests/test_pipeline/utils.py @@ -1,5 +1,6 @@ import pathlib import tempfile +from copy import deepcopy from typing import Tuple import pandas as pd @@ -8,13 +9,13 @@ from etna.pipeline.base import AbstractPipeline -def get_loaded_pipeline(pipeline: AbstractPipeline) -> AbstractPipeline: +def get_loaded_pipeline(pipeline: AbstractPipeline, ts: TSDataset) -> AbstractPipeline: with tempfile.TemporaryDirectory() as dir_path_str: dir_path = pathlib.Path(dir_path_str) path = dir_path.joinpath("dummy.zip") pipeline.save(path) - loaded_model = pipeline.load(path, ts=pipeline.ts) - return loaded_model + loaded_pipeline = pipeline.load(path, ts=ts) + return loaded_pipeline def assert_pipeline_equals_loaded_original( @@ -22,11 +23,13 @@ def assert_pipeline_equals_loaded_original( ) -> Tuple[AbstractPipeline, AbstractPipeline]: import torch # TODO: remove after fix at issue-802 + initial_ts = deepcopy(ts) + pipeline.fit(ts) torch.manual_seed(11) forecast_ts_1 = pipeline.forecast() - loaded_pipeline = get_loaded_pipeline(pipeline) + loaded_pipeline = get_loaded_pipeline(pipeline, ts=initial_ts) torch.manual_seed(11) forecast_ts_2 = loaded_pipeline.forecast() From d75ba1a4ad59371742858f422c208be879566040 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 19 Dec 2022 14:12:46 +0300 Subject: [PATCH 3/6] Fix test for SaveModelPipelineMixin --- tests/test_pipeline/test_mixins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_pipeline/test_mixins.py b/tests/test_pipeline/test_mixins.py index 487c2fb62..9555ab3d9 100644 --- a/tests/test_pipeline/test_mixins.py +++ b/tests/test_pipeline/test_mixins.py @@ -303,6 +303,7 @@ def test_save_mixin_load_ok_with_ts(example_tsds, recwarn, tmp_path): assert loaded_dummy.a == dummy.a assert loaded_dummy.b == dummy.b assert loaded_dummy.ts is not example_tsds + assert loaded_dummy.ts.transforms is loaded_dummy.transforms pd.testing.assert_frame_equal(loaded_dummy.ts.to_pandas(), dummy.ts.to_pandas()) assert isinstance(loaded_dummy.model, NaiveModel) assert [transform.value for transform in loaded_dummy.transforms] == transform_values From 21d54074b330505b46b54b97ed12df46575fcff6 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Mon, 19 Dec 2022 14:16:34 +0300 Subject: [PATCH 4/6] 
Update changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5b66a0fc..aff2112d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,13 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added - -- Add `SaveModelPipelineMixin`, add `load`, add saving for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036)) +- Add `SaveModelPipelineMixin`, add `load`, add saving and loading for `Pipeline` and `AutoRegressivePipeline` ([#1036](https://github.com/tinkoff-ai/etna/pull/1036)) - Add `SaveMixin` to models and transforms ([#1007](https://github.com/tinkoff-ai/etna/pull/1007)) - Add `plot_change_points_interactive` ([#988](https://github.com/tinkoff-ai/etna/pull/988)) - Add `experimental` module with `TimeSeriesBinaryClassifier` and `PredictabilityAnalyzer` ([#985](https://github.com/tinkoff-ai/etna/pull/985)) - Inference track results: add `predict` method to pipelines, teach some models to work with context, change hierarchy of base models, update notebook examples ([#979](https://github.com/tinkoff-ai/etna/pull/979)) - Add `get_ruptures_regularization` into `experimental` module ([#1001](https://github.com/tinkoff-ai/etna/pull/1001)) -- +- Add `SaveEnsembleMixin`, add saving and loading for `VotingEnsemble`, `StackingEnsemble` and `DirectEnsemble` ([#1046](https://github.com/tinkoff-ai/etna/pull/1046)) ### Changed - - Change returned model in get_model of BATSModel, TBATSModel ([#987](https://github.com/tinkoff-ai/etna/pull/987)) From d7060048296c566d1ef56c68abcd087cf86453f9 Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 21 Dec 2022 10:30:36 +0300 Subject: [PATCH 5/6] Fix number of digits --- etna/ensembles/mixins.py | 2 +- etna/pipeline/mixins.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/etna/ensembles/mixins.py b/etna/ensembles/mixins.py index 997382e56..256ddd602 100644 --- a/etna/ensembles/mixins.py +++ b/etna/ensembles/mixins.py @@ -105,7 +105,7 @@ def save(self, path: pathlib.Path): # save transforms separately pipelines_dir = temp_dir / "pipelines" pipelines_dir.mkdir() - num_digits = len(str(len(pipelines) - 1)) + num_digits = 8 for i, pipeline in enumerate(pipelines): save_name = f"{i:0{num_digits}d}.zip" pipeline_save_path = pipelines_dir / save_name diff --git a/etna/pipeline/mixins.py b/etna/pipeline/mixins.py index 4f3807ac9..b0954ca1a 100644 --- a/etna/pipeline/mixins.py +++ b/etna/pipeline/mixins.py @@ -144,7 +144,7 @@ def save(self, path: pathlib.Path): # save transforms separately transforms_dir = temp_dir / "transforms" transforms_dir.mkdir() - num_digits = len(str(len(transforms) - 1)) + num_digits = 8 for i, transform in enumerate(transforms): save_name = f"{i:0{num_digits}d}.zip" transform_save_path = transforms_dir / save_name From 68283a7fb179cfa08293a04e14c4e5a673cc42cb Mon Sep 17 00:00:00 2001 From: "d.a.bunin" Date: Wed, 21 Dec 2022 11:02:06 +0300 Subject: [PATCH 6/6] Fix tests, reformat code --- tests/test_ensembles/test_mixins.py | 4 +++- tests/test_pipeline/test_mixins.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_ensembles/test_mixins.py b/tests/test_ensembles/test_mixins.py index 3854e53de..95a4d7c1c 100644 --- a/tests/test_ensembles/test_mixins.py +++ b/tests/test_ensembles/test_mixins.py @@ -53,7 +53,9 @@ def test_save_mixin_save(example_tsds, tmp_path): with zipfile.ZipFile(path, "r") as archive: files = 
archive.namelist() - assert sorted(files) == sorted(["metadata.json", "object.pkl", "pipelines/0.zip", "pipelines/1.zip"]) + assert sorted(files) == sorted( + ["metadata.json", "object.pkl", "pipelines/00000000.zip", "pipelines/00000001.zip"] + ) with archive.open("metadata.json", "r") as input_file: metadata_bytes = input_file.read() diff --git a/tests/test_pipeline/test_mixins.py b/tests/test_pipeline/test_mixins.py index 9555ab3d9..284314ca2 100644 --- a/tests/test_pipeline/test_mixins.py +++ b/tests/test_pipeline/test_mixins.py @@ -243,7 +243,7 @@ def test_save_mixin_save(example_tsds, tmp_path): with zipfile.ZipFile(path, "r") as archive: files = archive.namelist() - assert sorted(files) == sorted(["metadata.json", "object.pkl", "model.zip", "transforms/0.zip"]) + assert sorted(files) == sorted(["metadata.json", "object.pkl", "model.zip", "transforms/00000000.zip"]) with archive.open("metadata.json", "r") as input_file: metadata_bytes = input_file.read()
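
A minimal usage sketch of what this patch series enables, assuming a prepared TSDataset named `ts` (built the same way as the `example_tsds` fixture used in the tests); it reuses the two-pipeline `DirectEnsemble` from `test_direct_ensemble.py` and only the `save`/`load` calls that appear in the diffs above, so treat it as an illustration rather than part of the patches:

    import pathlib
    import tempfile

    from etna.core import load
    from etna.ensembles import DirectEnsemble
    from etna.models import NaiveModel
    from etna.pipeline import Pipeline

    # `ts` is assumed to be a TSDataset prepared beforehand
    # (analogous to the example_tsds fixture in the tests).

    # Two base pipelines with different horizons, as in the DirectEnsemble fixture.
    ensemble = DirectEnsemble(
        pipelines=[
            Pipeline(model=NaiveModel(lag=1), transforms=[], horizon=1),
            Pipeline(model=NaiveModel(lag=3), transforms=[], horizon=2),
        ]
    )
    ensemble.fit(ts=ts)

    with tempfile.TemporaryDirectory() as _temp_dir:
        path = pathlib.Path(_temp_dir) / "ensemble.zip"

        # SaveEnsembleMixin writes a single zip archive containing metadata.json,
        # object.pkl (the ensemble pickled without `pipelines` and `ts`) and a
        # pipelines/ folder with one nested zip per base pipeline.
        ensemble.save(path)

        # Loading works either through the class or through etna.core.load,
        # which now forwards extra keyword arguments (here `ts`) to the class `load`.
        loaded_ensemble = DirectEnsemble.load(path, ts=ts)
        loaded_ensemble = load(path, ts=ts)

    # The loaded ensemble holds a deep copy of `ts` and freshly loaded pipelines,
    # so it can forecast right away.
    forecast = loaded_ensemble.forecast()

Note on the archive layout shown in the diffs: each base pipeline is saved as its own nested zip under pipelines/ with a zero-padded name (8 digits after patch 5/6, e.g. pipelines/00000000.zip), so that sorting the directory entries on load restores the original pipeline order.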