-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
KKIEEK
authored and
KKIEEK
committed
Dec 1, 2022
1 parent
a234607
commit f6e85e4
Showing
4 changed files
with
114 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,9 @@ | ||
# Copyright (c) SI-Analytics. All rights reserved. | ||
from .callbacks import * # noqa F403 | ||
from .schedulers import * # noqa F403 | ||
from .searchers import * # noqa F403 | ||
from .spaces import * # noqa F403 | ||
from .stoppers import * # noqa F403 | ||
from .tuner import Tuner | ||
|
||
__all__ = ['Tuner'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) SI-Analytics. All rights reserved. | ||
import copy | ||
import os.path as osp | ||
|
||
from ray.air.config import RunConfig | ||
from ray.tune.tune_config import TuneConfig | ||
from ray.tune.tuner import Tuner as RayTuner | ||
|
||
from siatune.ray import (build_callback, build_scheduler, build_searcher, | ||
build_space, build_stopper) | ||
|
||
|
||
class Tuner: | ||
"""Wrapper class of :class:`ray.tune.tuner.Tuner`. | ||
Args: | ||
trainable (Callable): | ||
work_dir (str): | ||
param_space (dict, optional): | ||
tune_cfg (dict, optional): | ||
Refer to https://github.com/ray-project/ray/blob/ray-2.1.0/python/ray/tune/tune_config.py for details. # noqa | ||
searcher (dict, optional): | ||
trial_scheduler (dict, optional): | ||
stopper (dict, optional): | ||
callbacks (list, optional): | ||
""" | ||
|
||
def __init__( | ||
self, | ||
trainable, | ||
work_dir, | ||
param_space=None, | ||
tune_cfg=None, | ||
searcher=None, | ||
trial_scheduler=None, | ||
stopper=None, | ||
callbacks=None, | ||
): | ||
work_dir = osp.abspath(work_dir) | ||
|
||
if param_space is not None: | ||
param_space = build_space(param_space) | ||
|
||
tune_cfg = copy.deepcopy(tune_cfg or dict()) | ||
|
||
if searcher is not None: | ||
searcher = build_searcher(searcher) | ||
|
||
if trial_scheduler is not None: | ||
trial_scheduler = build_scheduler(trial_scheduler) | ||
|
||
if stopper is not None: | ||
stopper = build_stopper(stopper) | ||
|
||
if callbacks is not None: | ||
if isinstance(callbacks, dict): | ||
callbacks = [callbacks] | ||
callbacks = [build_callback(callback) for callback in callbacks] | ||
|
||
self.tuner = RayTuner( | ||
trainable, | ||
param_space=param_space, | ||
tune_config=TuneConfig( | ||
searcher=searcher, trial_scheduler=trial_scheduler, | ||
**tune_cfg), | ||
run_config=RunConfig( | ||
local_dir=work_dir, | ||
stop=stopper, | ||
callbacks=callbacks, | ||
failure_config=None, # todo | ||
sync_config=None, # todo | ||
checkpoint_config=None, # todo | ||
), | ||
) | ||
|
||
@classmethod | ||
def from_cfg(cls, cfg, trainable): | ||
cfg = copy.deepcopy(cfg) | ||
tuner = cls( | ||
trainable, | ||
work_dir=cfg['work_dir'], | ||
param_space=cfg.get('space', None), | ||
tune_cfg=cfg.get('tune_cfg', None), | ||
searcher=cfg.get('searcher', None), | ||
trial_scheduler=cfg.get('trial_scheduler', None), | ||
stopper=cfg.get('stopper', None), | ||
callbacks=cfg.get('callbacks', None), | ||
) | ||
|
||
return tuner | ||
|
||
@classmethod | ||
def resume(cls, path, **kwargs): | ||
return RayTuner.restore(path, **kwargs) | ||
|
||
def fit(self): | ||
return self.tuner.fit() |