Skip to content

Commit

Permalink
Separating out get_map_values helper from MapUnitX transform
Browse files Browse the repository at this point in the history
Summary: This commit separates out a `get_map_values` helper function from the `MapUnitX` transform.

Differential Revision: D69213291
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Feb 6, 2025
1 parent ba33ba4 commit 6c6c507
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions ax/modelbridge/transforms/map_unit_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,7 @@ def __init__(
) -> None:
assert observations is not None, "MapUnitX requires observations"
assert search_space is not None, "MapUnitX requires search space"
# Loop through observation features and identify parameters that
# are not part of the search space. Store all observed values to
# infer bounds
map_values = defaultdict(list)
for obs in observations:
for p in obs.features.parameters:
if p not in search_space.parameters:
map_values[p].append(obs.features.parameters[p])
map_values = get_map_values(search_space, observations)

# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List` to avoid runtime subscripting errors.
Expand Down Expand Up @@ -81,3 +74,30 @@ def untransform_observation_features(
scale_fac = (u - l) / self.target_range
obsf.parameters[p_name] = scale_fac * (param - self.target_lb) + l
return observation_features


def get_map_values(
search_space: SearchSpace,
observations: list[Observation],
) -> dict[str, list[float]]:
"""Computes a dictionary mapping the name of a map parameter to its associated
progression values, in the same order as they occur in the observations.
Args:
search_space: The search space.
observations: A list of observations associated with the search space.
Returns:
The dictionary mapping the name of a map metric to the associated values,
in the same order they occur in `observations`.
"""
# Loop through observation features and identify parameters that
# are not part of the search space. Store all observed values to
# infer bounds
map_values = defaultdict(list)
for obs in observations:
# if we had access to the original data object, could loop over data.map_keys
for p in obs.features.parameters:
if p not in search_space.parameters:
map_values[p].append(obs.features.parameters[p])
return map_values

0 comments on commit 6c6c507

Please sign in to comment.