Skip to content

Commit

Permalink
storage for auxiliary sources (#3305)
Browse files Browse the repository at this point in the history
Summary:

Adding proper encoder and decoder for auxiliary_experiments_by_purpose argument of Experiment object. Previously, auxiliary_experiments_by_purpose used to include only AuxiliaryExperiment object that had an easy encoder and decoder via experiment name. But after allowing AuxiliarySources to be added in auxiliary_experiments_by_purpose we need to encode and decode AuxiliarySource objects as well.

Reviewed By: saitcakmak

Differential Revision: D68542281
  • Loading branch information
Jelena Markovic-Voronov authored and facebook-github-bot committed Feb 4, 2025
1 parent a4c8fe5 commit a91fdfa
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
53 changes: 37 additions & 16 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from enum import Enum
from io import StringIO
from logging import Logger
from typing import cast, Union
from typing import Any, cast, Union

import pandas as pd
from ax.analysis.analysis import AnalysisCard
Expand Down Expand Up @@ -108,29 +108,25 @@ def _auxiliary_experiments_by_purpose_from_experiment_sqa(
) -> dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]] | None:
auxiliary_experiments_by_purpose = None
if experiment_sqa.auxiliary_experiments_by_purpose:
from ax.storage.sqa_store.load import load_experiment

auxiliary_experiments_by_purpose = {}
aux_exp_name_dict = none_throws(
experiment_sqa.auxiliary_experiments_by_purpose
)
for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items():
aux_exps_dict = none_throws(experiment_sqa.auxiliary_experiments_by_purpose)
for aux_exp_purpose_str, aux_exps_json in aux_exps_dict.items():
aux_exp_purpose = next(
member
for member in self.config.auxiliary_experiment_purpose_enum
if member.value == aux_exp_purpose_str
)
auxiliary_experiments_by_purpose[aux_exp_purpose] = []
for aux_exp_name in aux_exp_names:
for aux_exp_json in aux_exps_json:
# keeping this for backward compatibility since previously
# we used to save only the experiment name
if isinstance(aux_exp_json, str):
aux_exp_json = {"experiment_name": aux_exp_json}
aux_experiment = auxiliary_experiment_from_json(
json=aux_exp_json, config=self.config
)
auxiliary_experiments_by_purpose[aux_exp_purpose].append(
AuxiliaryExperiment(
experiment=load_experiment(
aux_exp_name,
config=self.config,
skip_runners_and_metrics=True,
load_auxiliary_experiments=False,
)
)
aux_experiment
)
return auxiliary_experiments_by_purpose

Expand Down Expand Up @@ -1321,3 +1317,28 @@ def _get_scalarized_outcome_constraint_children_metrics(
)
metrics_sqa = query.all()
return metrics_sqa


def auxiliary_experiment_from_json(
json: dict[str, Any], config: SQAConfig
) -> AuxiliaryExperiment:
"""
Load an ``AuxiliaryExperiment`` from JSON.
Args:
json: A dictionary containing the JSON representation of an AuxiliaryExperiment.
config: The SQAConfig object used to load the experiment.
Returns:
An AuxiliaryExperiment object constructed from the JSON representation.
"""

from ax.storage.sqa_store.load import load_experiment

experiment = load_experiment(
json.get("experiment_name"),
config=config,
skip_runners_and_metrics=True,
load_auxiliary_experiments=False,
)
return AuxiliaryExperiment(experiment)
8 changes: 7 additions & 1 deletion ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,13 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
aux_exps,
) in experiment.auxiliary_experiments_by_purpose.items():
aux_exp_type = aux_exp_type_enum.value
aux_exp_jsons = [aux_exp.experiment.name for aux_exp in aux_exps]
aux_exp_jsons = [
{
"__type": aux_exp.__class__.__name__,
"experiment_name": aux_exp.experiment.name,
}
for aux_exp in aux_exps
]
auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons

properties = experiment._properties
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class SQAExperiment(Base):
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute
# `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has
# type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]`
auxiliary_experiments_by_purpose: dict[str, List[str]] | None = Column(
auxiliary_experiments_by_purpose: dict[str, List[dict[str, Any]]] | None = Column(
JSONEncodedTextDict, nullable=True, default={}
)

Expand Down

0 comments on commit a91fdfa

Please sign in to comment.