Separating out get_map_values helper from MapUnitX transform #3313

Open
wants to merge 2 commits into base: main
34 changes: 27 additions & 7 deletions ax/modelbridge/base.py
@@ -23,11 +23,13 @@
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import extract_arm_predictions, GeneratorRun
from ax.core.map_data import MapData
from ax.core.observation import (
Observation,
ObservationData,
ObservationFeatures,
observations_from_data,
observations_from_map_data,
recombine_observations,
separate_observations,
)
@@ -112,6 +114,7 @@ def __init__(
fit_abandoned: bool = False,
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
) -> None:
"""
Applies transforms and fits model.
@@ -156,6 +159,9 @@ def __init__(
To fit the model afterwards, use `_process_and_transform_data`
to get the transformed inputs and call `_fit_if_implemented` with
the transformed inputs.
fit_only_completed_map_metrics: Whether to fit map metrics only on data
from completed trials. Setting this to False is useful for applications
like modeling partially completed learning curves in AutoML.
"""
t_fit_start = time.monotonic()
transforms = transforms or []
@@ -184,6 +190,7 @@ def __init__(
self._fit_out_of_design = fit_out_of_design
self._fit_abandoned = fit_abandoned
self._fit_tracking_metrics = fit_tracking_metrics
self._fit_only_completed_map_metrics = fit_only_completed_map_metrics
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
experiment is not None and experiment.immutable_search_space_and_opt_config
@@ -292,12 +299,21 @@ def _prepare_observations(
) -> list[Observation]:
if experiment is None or data is None:
return []
return observations_from_data(
experiment=experiment,
data=data,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
)
if not self._fit_only_completed_map_metrics and isinstance(data, MapData):
return observations_from_map_data(
experiment=experiment,
map_data=data,
map_keys_as_parameters=True,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
)
else:
return observations_from_data(
experiment=experiment,
data=data,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
)

def _transform_data(
self,
@@ -557,7 +573,11 @@ def statuses_to_fit(self) -> set[TrialStatus]:
@property
def statuses_to_fit_map_metric(self) -> set[TrialStatus]:
"""Statuses to fit the model on."""
return {TrialStatus.COMPLETED}
return (
{TrialStatus.COMPLETED}
if self._fit_only_completed_map_metrics
else self.statuses_to_fit
)

@training_in_design.setter
def training_in_design(self, training_in_design: list[bool]) -> None:
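
Reviewer note: a minimal, self-contained sketch of the behavior the base.py change introduces, using a stand-in TrialStatus enum rather than the real Ax types (everything below is illustrative, not part of the diff). With the new flag left at its default of True, map metrics are still fit only on COMPLETED trials; with it set to False, statuses_to_fit_map_metric falls back to the regular statuses_to_fit set, and MapData is routed through observations_from_map_data with map_keys_as_parameters=True.

# Illustrative sketch only: stand-in enum, not the Ax TrialStatus.
from enum import Enum, auto

class TrialStatus(Enum):
    COMPLETED = auto()
    RUNNING = auto()
    EARLY_STOPPED = auto()

def statuses_to_fit_map_metric(
    fit_only_completed_map_metrics: bool,
    statuses_to_fit: set[TrialStatus],
) -> set[TrialStatus]:
    # Mirrors the property change in this hunk: restrict map metrics to
    # COMPLETED trials only when the flag is True; otherwise use the same
    # statuses as for regular metrics.
    return (
        {TrialStatus.COMPLETED}
        if fit_only_completed_map_metrics
        else statuses_to_fit
    )

all_statuses = {TrialStatus.COMPLETED, TrialStatus.RUNNING, TrialStatus.EARLY_STOPPED}
print(statuses_to_fit_map_metric(True, all_statuses))   # only COMPLETED
print(statuses_to_fit_map_metric(False, all_statuses))  # all three statuses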
44 changes: 44 additions & 0 deletions ax/modelbridge/tests/test_base_modelbridge.py
@@ -15,6 +15,7 @@
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.metric import Metric
from ax.core.objective import Objective, ScalarizedObjective
from ax.core.observation import ObservationData, ObservationFeatures
@@ -1040,3 +1041,46 @@ def test_SetModelSpace(self) -> None:
)
self.assertEqual(sum(m.training_in_design), 7)
self.assertEqual(m.model_space.parameters["x2"].upper, 20)

@mock.patch(
"ax.modelbridge.base.observations_from_map_data",
autospec=True,
return_value=([get_observation1()]),
)
@mock.patch(
"ax.modelbridge.base.observations_from_data",
autospec=True,
return_value=([get_observation1(), get_observation2()]),
)
def test_fit_only_completed_map_metrics(
self, mock_observations_from_data: Mock, mock_observations_from_map_data: Mock
) -> None:
# NOTE: If an empty data object is not passed, observations are not
# extracted, even with the mock.
# _prepare_observations is called in the constructor and itself calls
# observations_from_map_data.
Adapter(
search_space=get_search_space_for_value(),
model=0,
experiment=get_experiment_for_value(),
data=MapData(),
status_quo_name="1_1",
fit_only_completed_map_metrics=False,
)
self.assertTrue(mock_observations_from_map_data.called)
self.assertFalse(mock_observations_from_data.called)

# Calling without map data uses the regular observations_from_data even
# if fit_only_completed_map_metrics is False.
mock_observations_from_data.reset_mock()
mock_observations_from_map_data.reset_mock()
Adapter(
search_space=get_search_space_for_value(),
model=0,
experiment=get_experiment_for_value(),
data=Data(),
status_quo_name="1_1",
fit_only_completed_map_metrics=False,
)
self.assertFalse(mock_observations_from_map_data.called)
self.assertTrue(mock_observations_from_data.called)
1 change: 1 addition & 0 deletions ax/modelbridge/tests/test_registry.py
@@ -171,6 +171,7 @@ def test_enum_sobol_legacy_GPEI(self) -> None:
"fit_tracking_metrics": True,
"fit_on_init": True,
"default_model_gen_options": None,
"fit_only_completed_map_metrics": True,
},
)
prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)}
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
@@ -119,6 +119,7 @@ def __init__(
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
default_model_gen_options: TConfig | None = None,
fit_only_completed_map_metrics: bool = True,
) -> None:
# This warning is being added while we are on 0.4.3, so it will be
# released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed
@@ -161,6 +162,7 @@
fit_abandoned=fit_abandoned,
fit_tracking_metrics=fit_tracking_metrics,
fit_on_init=fit_on_init,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
)

def feature_importances(self, metric_name: str) -> dict[str, float]:
36 changes: 28 additions & 8 deletions ax/modelbridge/transforms/map_unit_x.py
@@ -41,14 +41,7 @@
) -> None:
assert observations is not None, "MapUnitX requires observations"
assert search_space is not None, "MapUnitX requires search space"
# Loop through observation features and identify parameters that
# are not part of the search space. Store all observed values to
# infer bounds
map_values = defaultdict(list)
for obs in observations:
for p in obs.features.parameters:
if p not in search_space.parameters:
map_values[p].append(obs.features.parameters[p])
map_values = get_map_values(search_space, observations)

# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List` to avoid runtime subscripting errors.
@@ -81,3 +74,30 @@ def untransform_observation_features(
scale_fac = (u - l) / self.target_range
obsf.parameters[p_name] = scale_fac * (param - self.target_lb) + l
return observation_features


def get_map_values(
search_space: SearchSpace,
observations: list[Observation],
) -> dict[str, list[float]]:
"""Computes a dictionary mapping the name of a map parameter to its associated
progression values, in the same order as they occur in the observations.

Args:
search_space: The search space.
observations: A list of observations associated with the search space.

Returns:
The dictionary mapping the name of a map parameter to the associated values,
in the same order they occur in `observations`.
"""
# Loop through observation features and identify parameters that
# are not part of the search space. Store all observed values to
# infer bounds
map_values = defaultdict(list)
for obs in observations:
# If we had access to the original data object, we could loop over data.map_keys instead.
for p in obs.features.parameters:
if p not in search_space.parameters:
map_values[p].append(obs.features.parameters[p])
return map_values
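
For reference, a minimal, self-contained sketch of the contract of the new get_map_values helper, using stand-in observation and search-space classes rather than the real Ax Observation / SearchSpace (the Fake* names below are illustrative assumptions): any parameter that appears in an observation's features but not in the search space, e.g. a map key such as "steps", has its values collected in observation order.

# Stand-in sketch of the get_map_values contract; not the Ax implementation.
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class FakeObservationFeatures:
    parameters: dict

@dataclass
class FakeObservation:
    features: FakeObservationFeatures

@dataclass
class FakeSearchSpace:
    parameters: dict

def get_map_values_sketch(search_space, observations):
    # Collect values of any parameter that is not in the search space; these
    # are the map keys (e.g. a training progression such as "steps").
    map_values = defaultdict(list)
    for obs in observations:
        for p, value in obs.features.parameters.items():
            if p not in search_space.parameters:
                map_values[p].append(value)
    return map_values

space = FakeSearchSpace(parameters={"x1": None, "x2": None})
observations = [
    FakeObservation(FakeObservationFeatures({"x1": 0.1, "x2": 1.0, "steps": 10})),
    FakeObservation(FakeObservationFeatures({"x1": 0.2, "x2": 2.0, "steps": 20})),
]
print(dict(get_map_values_sketch(space, observations)))  # {'steps': [10, 20]}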