diff --git a/ax/storage/json_store/encoder.py b/ax/storage/json_store/encoder.py index a49e53c2223..e8a1af4ae9e 100644 --- a/ax/storage/json_store/encoder.py +++ b/ax/storage/json_store/encoder.py @@ -11,6 +11,7 @@ import enum from collections import OrderedDict from collections.abc import Callable +from functools import partial from inspect import isclass from typing import Any @@ -28,19 +29,15 @@ from ax.utils.common.typeutils_torch import torch_type_to_str -# pyre-fixme[3]: Return annotation cannot be `Any`. -def object_to_json( # noqa C901 - # pyre-fixme[2]: Parameter annotation cannot be `Any`. +# pyre-ignore[3]: Missing return annotation +def object_to_json( + # pyre-ignore[2]: Missing parameter annotation obj: Any, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. + # pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters encoder_registry: dict[ type, Callable[[Any], dict[str, Any]] ] = CORE_ENCODER_REGISTRY, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. + # pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters class_encoder_registry: dict[ type, Callable[[Any], dict[str, Any]] ] = CORE_CLASS_ENCODER_REGISTRY, @@ -59,6 +56,11 @@ def object_to_json( # noqa C901 We then pass each item of the dictionary back into this function to recursively convert the entire object. """ + _object_to_json = partial( + object_to_json, + encoder_registry=encoder_registry, + class_encoder_registry=class_encoder_registry, + ) obj = numpy_type_to_python_type(obj) _type = type(obj) @@ -67,14 +69,7 @@ def object_to_json( # noqa C901 for class_type in class_encoder_registry: if issubclass(obj, class_type): obj_dict = class_encoder_registry[class_type](obj) - return { - k: object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for k, v in obj_dict.items() - } + return {k: _object_to_json(v) for k, v in obj_dict.items()} raise ValueError( f"{obj} is a class. Add it to the CLASS_ENCODER_REGISTRY " @@ -83,87 +78,36 @@ def object_to_json( # noqa C901 if _type in encoder_registry: obj_dict = encoder_registry[_type](obj) - return { - k: object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for k, v in obj_dict.items() - } - + return {k: _object_to_json(v) for k, v in obj_dict.items()} # Python built-in types + `typing` module types if _type in (str, int, float, bool, type(None)): return obj elif _type is list: - return [ - object_to_json( - x, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for x in obj - ] + return [_object_to_json(x) for x in obj] elif _type is tuple: - return tuple( - object_to_json( - x, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for x in obj - ) + return tuple(_object_to_json(x) for x in obj) elif _type is dict: - return { - k: object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for k, v in obj.items() - } + return {k: _object_to_json(v) for k, v in obj.items()} elif _is_named_tuple(obj): return { "__type": _type.__name__, - **{ - k: object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) - for k, v in obj._asdict().items() - }, + **{k: _object_to_json(v) for k, v in obj._asdict().items()}, } elif dataclasses.is_dataclass(obj): field_names = [f.name for f in dataclasses.fields(obj)] return { "__type": _type.__name__, **{ - k: object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ) + k: _object_to_json(v) for k, v in obj.__dict__.items() if k in field_names }, } - # Types from libraries, commonly used in Ax (e.g., numpy, pandas, torch) elif _type is OrderedDict: return { "__type": _type.__name__, - "value": [ - ( - k, - object_to_json( - v, - encoder_registry=encoder_registry, - class_encoder_registry=class_encoder_registry, - ), - ) - for k, v in obj.items() - ], + "value": [(k, _object_to_json(v)) for k, v in obj.items()], } elif _type is datetime.datetime: return { @@ -183,7 +127,6 @@ def object_to_json( # noqa C901 elif _type.__module__ == "torch": # Torch does not support saving to string, so save to buffer first return {"__type": f"torch_{_type.__name__}", "value": torch_type_to_str(obj)} - err = ( f"Object {obj} passed to `object_to_json` (of type {_type}, module: " f"{_type.__module__}) is not registered with a corresponding encoder "