Skip to content

Commit

Permalink
Merge pull request #699 from bouthilx/hotfix/asha_hyperband_identic_o…
Browse files Browse the repository at this point in the history
…bjectives

Fix ASHA for identical objectives
  • Loading branch information
bouthilx authored Nov 25, 2021
2 parents b30a478 + ed034de commit ed9a57a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/orion/algo/asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,12 @@ def get_candidates(self, rung_id):

rung = list(
sorted(
(objective, trial)
for objective, trial in rung.values()
if objective is not None
(
(objective, trial)
for objective, trial in rung.values()
if objective is not None
),
key=lambda item: item[0],
)
)
k = len(rung) // self.hyperband.reduction_factor
Expand Down
3 changes: 2 additions & 1 deletion src/orion/algo/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def suggest(self, num):
)
else:
logger.warning(
f"{self.__class__.__name__} cannot suggest new samples, exit."
f"{self.__class__.__name__} cannot suggest new samples and must wait "
"for trials to complete."
)

return []
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/algo/test_asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,36 @@ def test_suggest_promote_many_plus_random(
== 20 - 2 - 3 * 3
)

def test_suggest_promote_identic_objectives(
self, asha, bracket, big_rung_0, big_rung_1
):
"""Test that identic objectives are handled properly"""
asha.brackets = [bracket]
bracket.asha = asha

n_trials = 9
resources = 1

results = {}
for param in np.linspace(0, 8, 9):
trial = create_trial_for_hb((resources, param), objective=0)
trial_hash = trial.compute_trial_hash(
trial,
ignore_fidelity=True,
ignore_experiment=True,
)
results[trial_hash] = (trial.objective.value, trial)

bracket.rungs[0] = dict(n_trials=n_trials, resources=resources, results=results)

candidates = asha.suggest(2)

assert len(candidates) == 2
assert (
sum(1 for trial in candidates if trial.params[asha.fidelity_index] == 3)
== 2
)


class TestGenericASHA(BaseAlgoTests):
algo_name = "asha"
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/algo/test_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,36 @@ def test_suggest_promote(self, hyperband, bracket, rung_0):
assert points[1].params == {"epoch": 3, "lr": 1}
assert points[2].params == {"epoch": 3, "lr": 2}

def test_suggest_promote_identic_objectives(self, hyperband, bracket):
"""Test that identic objectives are handled properly"""
hyperband.brackets = [bracket]
bracket.hyperband = hyperband

n_trials = 9
resources = 1

results = {}
for param in np.linspace(0, 8, 9):
trial = create_trial_for_hb((resources, param), objective=0)
trial_hash = trial.compute_trial_hash(
trial,
ignore_fidelity=True,
ignore_experiment=True,
)
results[trial_hash] = (trial.objective.value, trial)

bracket.rungs[0] = dict(n_trials=n_trials, resources=resources, results=results)

candidates = hyperband.suggest(2)

assert len(candidates) == 2
assert (
sum(
1 for trial in candidates if trial.params[hyperband.fidelity_index] == 3
)
== 2
)

def test_is_filled(self, hyperband, bracket, rung_0, rung_1, rung_2):
"""Test that Hyperband bracket detects when rung is filled."""
hyperband.brackets = [bracket]
Expand Down

0 comments on commit ed9a57a

Please sign in to comment.