Skip to content

Commit

Permalink
Add test case for ImmutableContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
KKIEEK authored and KKIEEK committed Sep 5, 2022
1 parent dd8922b commit aeec90d
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions tests/test_ray/test_searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mmtune.ray.searchers import (SEARCHERS, AxSearch, BlendSearch, CFOSearch,
HyperOptSearch, NevergradSearch,
TrustRegionSearcher, build_searcher)
from mmtune.utils.container import ImmutableContainer


def test_build_searcher():
Expand All @@ -21,6 +22,11 @@ def config():
steps=10, width=tune.uniform(0, 20), height=tune.uniform(-100, 100))


@pytest.fixture
def config_im(config):
return ImmutableContainer(config)


@pytest.fixture
def trainable():

Expand All @@ -44,6 +50,16 @@ def test_ax(trainable, config):
config=config)


def test_ax_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=AxSearch(),
num_samples=2,
config=config_im)


def test_blend(trainable, config):
tune.run(
trainable,
Expand All @@ -54,6 +70,16 @@ def test_blend(trainable, config):
config=config)


def test_blend_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=BlendSearch(),
num_samples=2,
config=config_im)


def test_cfo(trainable, config):
tune.run(
trainable,
Expand All @@ -64,6 +90,16 @@ def test_cfo(trainable, config):
config=config)


def test_cfo_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=CFOSearch(),
num_samples=2,
config=config_im)


def test_hyperopt(trainable, config):
tune.run(
trainable,
Expand All @@ -74,6 +110,16 @@ def test_hyperopt(trainable, config):
config=config)


def test_hyperopt_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=HyperOptSearch(),
num_samples=2,
config=config_im)


def test_nevergrad(trainable, config):
tune.run(
trainable,
Expand All @@ -84,6 +130,16 @@ def test_nevergrad(trainable, config):
config=config)


def test_nevergrad_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=NevergradSearch(optimizer='PSO', budget=2),
num_samples=2,
config=config_im)


def test_trust_region(trainable, config):
tune.run(
trainable,
Expand All @@ -92,3 +148,13 @@ def test_trust_region(trainable, config):
search_alg=TrustRegionSearcher(),
num_samples=2,
config=config)


def test_trust_region_im(trainable, config_im):
tune.run(
trainable,
metric='mean_loss',
mode='min',
search_alg=TrustRegionSearcher(),
num_samples=2,
config=config_im)

0 comments on commit aeec90d

Please sign in to comment.