diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index c2924a8171..29bd1e3cf7 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -20,7 +20,7 @@ import gc from collections import defaultdict -from gluonts.core.component import validated +from gluonts.core.component import equals, validated class QRF: @@ -121,6 +121,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None): } return xgboost.sklearn.XGBModel(**model_params) + def __eq__(self, that): + """ + Two QRX instances are considered equal if they have the same + constructor arguments. + """ + return equals(self, that) + def fit( self, x_train: Union[pd.DataFrame, List], diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py index 5a769b3fee..63daea1a97 100644 --- a/src/gluonts/ext/rotbaum/_predictor.py +++ b/src/gluonts/ext/rotbaum/_predictor.py @@ -13,12 +13,14 @@ import concurrent.futures import logging +import pickle from itertools import chain from typing import Iterator, List, Optional, Any, Dict from toolz import first import numpy as np import pandas as pd +from pathlib import Path from itertools import compress from gluonts.core.component import validated @@ -340,6 +342,31 @@ def predict( # type: ignore item_id=ts.get("item_id"), ) + def serialize(self, path: Path) -> None: + """ + This function calls parent class serialize() in order to serialize + the class name, version information and constuctor arguments. It + persists the tree predictor by pickling the model list that is + generated when pickling the TreePredictor. + """ + super().serialize(path) + with (path / "predictor.pkl").open("wb") as f: + pickle.dump(self.model_list, f) + + @classmethod + def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor": + """ + This function loads and returns the serialized model. It loads + the predictor class with the serialized arguments. It then loads + the trained model list by reading the pickle file. + """ + + predictor = super().deserialize(path) + assert isinstance(predictor, cls) + with (path / "predictor.pkl").open("rb") as f: + predictor.model_list = pickle.load(f) + return predictor + def explain( self, importance_type: str = "gain", percentage: bool = True ) -> ExplanationResult: diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py index 0f39095976..b068730829 100644 --- a/src/gluonts/ext/rotbaum/_preprocess.py +++ b/src/gluonts/ext/rotbaum/_preprocess.py @@ -464,7 +464,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: if self.use_feat_static_real else [] ) - if self.cardinality: + if ( + self.cardinality + and time_series.get("feat_static_cat", None) is not None + ): feat_static_cat = ( self.encode_one_hot_all(time_series["feat_static_cat"]) if self.one_hot_encode diff --git a/test/ext/rotbaum/test_model.py b/test/ext/rotbaum/test_model.py index f4feaad2d9..51869034c7 100644 --- a/test/ext/rotbaum/test_model.py +++ b/test/ext/rotbaum/test_model.py @@ -11,10 +11,11 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. - +from pathlib import Path import pytest +import tempfile -from gluonts.ext.rotbaum import TreeEstimator +from gluonts.ext.rotbaum import TreeEstimator, TreePredictor @pytest.fixture() @@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles): accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20) -def test_serialize(serialize_test, hyperparameters): - serialize_test(TreeEstimator, hyperparameters) +def test_serialize(serialize_test, hyperparameters, dsinfo): + forecaster = TreeEstimator.from_hyperparameters( + freq=dsinfo.freq, + **{ + "prediction_length": dsinfo.prediction_length, + "num_parallel_samples": dsinfo.num_parallel_samples, + }, + **hyperparameters, + ) + + predictor_act = forecaster.train(dsinfo.train_ds) + + with tempfile.TemporaryDirectory() as temp_dir: + predictor_act.serialize(Path(temp_dir)) + predictor_exp = TreePredictor.deserialize(Path(temp_dir)) + assert predictor_act == predictor_exp + assert predictor_act.model_list == predictor_exp.model_list