diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 202d51a2eae..f2af62835a7 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -8,7 +8,9 @@ import datetime from collections import OrderedDict +from dataclasses import dataclass from enum import Enum +from functools import partial from inspect import isclass from io import StringIO from logging import Logger @@ -82,6 +84,12 @@ } +@dataclass +class RegistryKwargs: + decoder_registry: TDecoderRegistry + class_decoder_registry: TClassDecoderRegistry + + # pyre-fixme[3]: Return annotation cannot be `Any`. def object_from_json( # pyre-fixme[2]: Parameter annotation cannot be `Any`. @@ -90,40 +98,26 @@ def object_from_json( class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, ) -> Any: """Recursively load objects from a JSON-serializable dictionary.""" + + registry_kwargs = RegistryKwargs( + decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry + ) + + _object_from_json = partial(object_from_json, **vars(registry_kwargs)) + if type(object_json) in (str, int, float, bool, type(None)) or isinstance( object_json, Enum ): return object_json elif isinstance(object_json, list): - return [ - object_from_json( - i, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for i in object_json - ] + return [_object_from_json(i) for i in object_json] elif isinstance(object_json, tuple): - return tuple( - object_from_json( - i, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for i in object_json - ) + return tuple(_object_from_json(i) for i in object_json) elif isinstance(object_json, dict): if "__type" not in object_json: # this is just a regular dictionary, e.g. the one in Parameter # containing parameterizations - return { - k: object_from_json( - v, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for k, v in object_json.items() - } + return {k: _object_from_json(v) for k, v in object_json.items()} _type = object_json.pop("__type") @@ -133,17 +127,7 @@ def object_from_json( ) elif _type == "OrderedDict": return OrderedDict( - [ - ( - k, - object_from_json( - v, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ), - ) - for k, v in object_json["value"] - ] + [(k, _object_from_json(v)) for k, v in object_json["value"]] ) elif _type == "DataFrame": # Need dtype=False, otherwise infers arm_names like "4_1" @@ -160,9 +144,7 @@ def object_from_json( ) elif _type == "ListSurrogate": return surrogate_from_list_surrogate_json( - list_surrogate_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + list_surrogate_json=object_json, **vars(registry_kwargs) ) elif _type == "set": return set(object_json["value"]) @@ -191,58 +173,38 @@ def object_from_json( return botorch_component_from_json(botorch_class=_class, json=object_json) elif _class == GeneratorRun: return generator_run_from_json( - object_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + object_json=object_json, **vars(registry_kwargs) ) elif _class == GenerationStep: return generation_step_from_json( - generation_step_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + generation_step_json=object_json, **vars(registry_kwargs) ) elif _class == GenerationNode: return generation_node_from_json( - generation_node_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + generation_node_json=object_json, **vars(registry_kwargs) ) elif _class == ModelSpec: return model_spec_from_json( - model_spec_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + model_spec_json=object_json, **vars(registry_kwargs) ) elif _class == GenerationStrategy: return generation_strategy_from_json( - generation_strategy_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + generation_strategy_json=object_json, **vars(registry_kwargs) ) elif _class == MultiTypeExperiment: return multi_type_experiment_from_json( - object_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + object_json=object_json, **vars(registry_kwargs) ) elif _class == Experiment: return experiment_from_json( - object_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + object_json=object_json, **vars(registry_kwargs) ) elif _class == SearchSpace: return search_space_from_json( - search_space_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + search_space_json=object_json, **vars(registry_kwargs) ) elif _class == Objective: - return objective_from_json( - object_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) + return objective_from_json(object_json=object_json, **vars(registry_kwargs)) elif _class in (SurrogateSpec, Surrogate, ModelConfig): if "input_transform" in object_json: ( @@ -250,8 +212,7 @@ def object_from_json( input_transform_options_json, ) = get_input_transform_json_components( input_transforms_json=object_json.pop("input_transform"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + **vars(registry_kwargs), ) object_json["input_transform_classes"] = input_transform_classes_json object_json["input_transform_options"] = input_transform_options_json @@ -261,8 +222,7 @@ def object_from_json( outcome_transform_options_json, ) = get_outcome_transform_json_components( outcome_transforms_json=object_json.pop("outcome_transform"), - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + **vars(registry_kwargs), ) object_json["outcome_transform_classes"] = ( outcome_transform_classes_json @@ -280,8 +240,7 @@ def object_from_json( return unpack_transition_criteria_from_json( class_=_class, transition_criteria_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + **vars(registry_kwargs), ) elif isclass(_class) and issubclass(_class, SerializationMixin): return _class( @@ -291,9 +250,7 @@ def object_from_json( # another Ax class that needs serialization should implement its own # _to_json and _from_json methods and register them appropriately. **_class.deserialize_init_args( - args=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + args=object_json, **vars(registry_kwargs) ) ) @@ -302,16 +259,11 @@ def object_from_json( # we want to have the input & outcome transform arguments updated # before we call surrogate_spec_from_json. return surrogate_spec_from_json( - surrogate_spec_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + surrogate_spec_json=object_json, **vars(registry_kwargs) ) return ax_class_from_json_dict( - _class=_class, - object_json=object_json, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, + _class=_class, object_json=object_json, **vars(registry_kwargs) ) else: err = (