diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index c83c83636d4..a6e1a24d24f 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -11,7 +11,7 @@ from enum import Enum from io import StringIO from logging import Logger -from typing import cast, Union +from typing import Any, cast, Union import pandas as pd from ax.analysis.analysis import AnalysisCard @@ -108,29 +108,25 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa( ) -> dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None: auxiliary_experiments_by_purpose = None if experiment_sqa.auxiliary_experiments_by_purpose: - from ax.storage.sqa_store.load import load_experiment - auxiliary_experiments_by_purpose = {} - aux_exp_name_dict = none_throws( - experiment_sqa.auxiliary_experiments_by_purpose - ) - for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items(): + aux_exps_dict = none_throws(experiment_sqa.auxiliary_experiments_by_purpose) + for aux_exp_purpose_str, aux_exps_json in aux_exps_dict.items(): aux_exp_purpose = next( member for member in self.config.auxiliary_experiment_purpose_enum if member.value == aux_exp_purpose_str ) auxiliary_experiments_by_purpose[aux_exp_purpose] = [] - for aux_exp_name in aux_exp_names: + for aux_exp_json in aux_exps_json: + # keeping this for backward compatibility since previously + # we used to save only the experiment name + if isinstance(aux_exp_json, str): + aux_exp_json = {"experiment_name": aux_exp_json} + aux_experiment = auxiliary_experiment_from_json( + json=aux_exp_json, config=self.config + ) auxiliary_experiments_by_purpose[aux_exp_purpose].append( - AuxiliaryExperiment( - experiment=load_experiment( - aux_exp_name, - config=self.config, - skip_runners_and_metrics=True, - load_auxiliary_experiments=False, - ) - ) + aux_experiment ) return auxiliary_experiments_by_purpose @@ -1321,3 +1317,28 @@ def _get_scalarized_outcome_constraint_children_metrics( ) metrics_sqa = query.all() return metrics_sqa + + +def auxiliary_experiment_from_json( + json: dict[str, Any], config: SQAConfig +) -> AuxiliaryExperiment: + """ + Load an ``AuxiliaryExperiment`` from JSON. + + Args: + json: A dictionary containing the JSON representation of an AuxiliaryExperiment. + config: The SQAConfig object used to load the experiment. + + Returns: + An AuxiliaryExperiment object constructed from the JSON representation. + """ + + from ax.storage.sqa_store.load import load_experiment + + experiment = load_experiment( + json.get("experiment_name"), + config=config, + skip_runners_and_metrics=True, + load_auxiliary_experiments=False, + ) + return AuxiliaryExperiment(experiment) diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index b348c3ec84f..76372fa5e92 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -188,7 +188,13 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: aux_exps, ) in experiment.auxiliary_experiments_by_purpose.items(): aux_exp_type = aux_exp_type_enum.value - aux_exp_jsons = [aux_exp.experiment.name for aux_exp in aux_exps] + aux_exp_jsons = [ + { + "__type": aux_exp.__class__.__name__, + "experiment_name": aux_exp.experiment.name, + } + for aux_exp in aux_exps + ] auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons properties = experiment._properties diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 17175eca535..db6eb8a51fe 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -371,7 +371,7 @@ class SQAExperiment(Base): # pyre-fixme[8]: Incompatible attribute type [8]: Attribute # `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has # type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]` - auxiliary_experiments_by_purpose: dict[str, List[str]] | None = Column( + auxiliary_experiments_by_purpose: dict[str, List[dict[str, Any]]] | None = Column( JSONEncodedTextDict, nullable=True, default={} )