Skip to content

Commit

Permalink
Fix Rotbaum serialization and deserialization (awslabs#3068)
Browse files Browse the repository at this point in the history
  • Loading branch information
pantanurag555 authored and lostella committed Dec 6, 2023
1 parent 3c434d8 commit 5a005e4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 6 deletions.
9 changes: 8 additions & 1 deletion src/gluonts/ext/rotbaum/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import gc
from collections import defaultdict

from gluonts.core.component import validated
from gluonts.core.component import equals, validated


class QRF:
Expand Down Expand Up @@ -121,6 +121,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None):
}
return xgboost.sklearn.XGBModel(**model_params)

def __eq__(self, that):
"""
Two QRX instances are considered equal if they have the same
constructor arguments.
"""
return equals(self, that)

def fit(
self,
x_train: Union[pd.DataFrame, List],
Expand Down
27 changes: 27 additions & 0 deletions src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@

import concurrent.futures
import logging
import pickle
from itertools import chain
from typing import Iterator, List, Optional, Any, Dict
from toolz import first

import numpy as np
import pandas as pd
from pathlib import Path
from itertools import compress

from gluonts.core.component import validated
Expand Down Expand Up @@ -340,6 +342,31 @@ def predict( # type: ignore
item_id=ts.get("item_id"),
)

def serialize(self, path: Path) -> None:
"""
This function calls parent class serialize() in order to serialize
the class name, version information and constuctor arguments. It
persists the tree predictor by pickling the model list that is
generated when pickling the TreePredictor.
"""
super().serialize(path)
with (path / "predictor.pkl").open("wb") as f:
pickle.dump(self.model_list, f)

@classmethod
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
"""
This function loads and returns the serialized model. It loads
the predictor class with the serialized arguments. It then loads
the trained model list by reading the pickle file.
"""

predictor = super().deserialize(path)
assert isinstance(predictor, cls)
with (path / "predictor.pkl").open("rb") as f:
predictor.model_list = pickle.load(f)
return predictor

def explain(
self, importance_type: str = "gain", percentage: bool = True
) -> ExplanationResult:
Expand Down
5 changes: 4 additions & 1 deletion src/gluonts/ext/rotbaum/_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
if self.use_feat_static_real
else []
)
if self.cardinality:
if (
self.cardinality
and time_series.get("feat_static_cat", None) is not None
):
feat_static_cat = (
self.encode_one_hot_all(time_series["feat_static_cat"])
if self.one_hot_encode
Expand Down
24 changes: 20 additions & 4 deletions test/ext/rotbaum/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.


from pathlib import Path
import pytest
import tempfile

from gluonts.ext.rotbaum import TreeEstimator
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor


@pytest.fixture()
Expand All @@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles):
accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20)


def test_serialize(serialize_test, hyperparameters):
serialize_test(TreeEstimator, hyperparameters)
def test_serialize(serialize_test, hyperparameters, dsinfo):
forecaster = TreeEstimator.from_hyperparameters(
freq=dsinfo.freq,
**{
"prediction_length": dsinfo.prediction_length,
"num_parallel_samples": dsinfo.num_parallel_samples,
},
**hyperparameters,
)

predictor_act = forecaster.train(dsinfo.train_ds)

with tempfile.TemporaryDirectory() as temp_dir:
predictor_act.serialize(Path(temp_dir))
predictor_exp = TreePredictor.deserialize(Path(temp_dir))
assert predictor_act == predictor_exp
assert predictor_act.model_list == predictor_exp.model_list

0 comments on commit 5a005e4

Please sign in to comment.