Skip to content

Commit

Permalink
Make recursive calls in object_from_json less verbose (#3267)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3267

As titled.

Reviewed By: Balandat

Differential Revision: D68582616

fbshipit-source-id: 0755088cd06d66d2f4c9a4f53626b073a91c5005
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Jan 26, 2025
1 parent cb5f5a3 commit 124f6e6
Showing 1 changed file with 35 additions and 83 deletions.
118 changes: 35 additions & 83 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import datetime
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from functools import partial
from inspect import isclass
from io import StringIO
from logging import Logger
Expand Down Expand Up @@ -82,6 +84,12 @@
}


@dataclass
class RegistryKwargs:
decoder_registry: TDecoderRegistry
class_decoder_registry: TClassDecoderRegistry


# pyre-fixme[3]: Return annotation cannot be `Any`.
def object_from_json(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
Expand All @@ -90,40 +98,26 @@ def object_from_json(
class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
) -> Any:
"""Recursively load objects from a JSON-serializable dictionary."""

registry_kwargs = RegistryKwargs(
decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry
)

_object_from_json = partial(object_from_json, **vars(registry_kwargs))

if type(object_json) in (str, int, float, bool, type(None)) or isinstance(
object_json, Enum
):
return object_json
elif isinstance(object_json, list):
return [
object_from_json(
i,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
for i in object_json
]
return [_object_from_json(i) for i in object_json]
elif isinstance(object_json, tuple):
return tuple(
object_from_json(
i,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
for i in object_json
)
return tuple(_object_from_json(i) for i in object_json)
elif isinstance(object_json, dict):
if "__type" not in object_json:
# this is just a regular dictionary, e.g. the one in Parameter
# containing parameterizations
return {
k: object_from_json(
v,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
for k, v in object_json.items()
}
return {k: _object_from_json(v) for k, v in object_json.items()}

_type = object_json.pop("__type")

Expand All @@ -133,17 +127,7 @@ def object_from_json(
)
elif _type == "OrderedDict":
return OrderedDict(
[
(
k,
object_from_json(
v,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
),
)
for k, v in object_json["value"]
]
[(k, _object_from_json(v)) for k, v in object_json["value"]]
)
elif _type == "DataFrame":
# Need dtype=False, otherwise infers arm_names like "4_1"
Expand All @@ -160,9 +144,7 @@ def object_from_json(
)
elif _type == "ListSurrogate":
return surrogate_from_list_surrogate_json(
list_surrogate_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
list_surrogate_json=object_json, **vars(registry_kwargs)
)
elif _type == "set":
return set(object_json["value"])
Expand Down Expand Up @@ -191,67 +173,46 @@ def object_from_json(
return botorch_component_from_json(botorch_class=_class, json=object_json)
elif _class == GeneratorRun:
return generator_run_from_json(
object_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
object_json=object_json, **vars(registry_kwargs)
)
elif _class == GenerationStep:
return generation_step_from_json(
generation_step_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
generation_step_json=object_json, **vars(registry_kwargs)
)
elif _class == GenerationNode:
return generation_node_from_json(
generation_node_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
generation_node_json=object_json, **vars(registry_kwargs)
)
elif _class == ModelSpec:
return model_spec_from_json(
model_spec_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
model_spec_json=object_json, **vars(registry_kwargs)
)
elif _class == GenerationStrategy:
return generation_strategy_from_json(
generation_strategy_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
generation_strategy_json=object_json, **vars(registry_kwargs)
)
elif _class == MultiTypeExperiment:
return multi_type_experiment_from_json(
object_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
object_json=object_json, **vars(registry_kwargs)
)
elif _class == Experiment:
return experiment_from_json(
object_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
object_json=object_json, **vars(registry_kwargs)
)
elif _class == SearchSpace:
return search_space_from_json(
search_space_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
search_space_json=object_json, **vars(registry_kwargs)
)
elif _class == Objective:
return objective_from_json(
object_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
return objective_from_json(object_json=object_json, **vars(registry_kwargs))
elif _class in (SurrogateSpec, Surrogate, ModelConfig):
if "input_transform" in object_json:
(
input_transform_classes_json,
input_transform_options_json,
) = get_input_transform_json_components(
input_transforms_json=object_json.pop("input_transform"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
**vars(registry_kwargs),
)
object_json["input_transform_classes"] = input_transform_classes_json
object_json["input_transform_options"] = input_transform_options_json
Expand All @@ -261,8 +222,7 @@ def object_from_json(
outcome_transform_options_json,
) = get_outcome_transform_json_components(
outcome_transforms_json=object_json.pop("outcome_transform"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
**vars(registry_kwargs),
)
object_json["outcome_transform_classes"] = (
outcome_transform_classes_json
Expand All @@ -280,8 +240,7 @@ def object_from_json(
return unpack_transition_criteria_from_json(
class_=_class,
transition_criteria_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
**vars(registry_kwargs),
)
elif isclass(_class) and issubclass(_class, SerializationMixin):
return _class(
Expand All @@ -291,9 +250,7 @@ def object_from_json(
# another Ax class that needs serialization should implement its own
# _to_json and _from_json methods and register them appropriately.
**_class.deserialize_init_args(
args=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
args=object_json, **vars(registry_kwargs)
)
)

Expand All @@ -302,16 +259,11 @@ def object_from_json(
# we want to have the input & outcome transform arguments updated
# before we call surrogate_spec_from_json.
return surrogate_spec_from_json(
surrogate_spec_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
surrogate_spec_json=object_json, **vars(registry_kwargs)
)

return ax_class_from_json_dict(
_class=_class,
object_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
_class=_class, object_json=object_json, **vars(registry_kwargs)
)
else:
err = (
Expand Down

0 comments on commit 124f6e6

Please sign in to comment.