diff --git a/siatune/apis/tune.py b/siatune/apis/tune.py
index 40bbc467..8486bc22 100644
--- a/siatune/apis/tune.py
+++ b/siatune/apis/tune.py
@@ -6,11 +6,7 @@
 from mmcv.utils import Config
 
 from siatune.mm.tasks import BaseTask
-from siatune.ray.callbacks import build_callback
-from siatune.ray.schedulers import build_scheduler
-from siatune.ray.searchers import build_searcher
-from siatune.ray.spaces import build_space
-from siatune.ray.stoppers import build_stopper
+from siatune.ray import Tuner
 
 
 def tune(task_processor: BaseTask, tune_config: Config,
@@ -29,51 +25,13 @@
     trainable_cfg = tune_config.get('trainable', dict())
     trainable = task_processor.create_trainable(**trainable_cfg)
 
-    assert hasattr(tune_config, 'metric')
-    assert hasattr(tune_config, 'mode') and tune_config.mode in ['min', 'max']
-
     tune_artifact_dir = osp.join(tune_config.work_dir, 'artifact')
     mmcv.mkdir_or_exist(tune_artifact_dir)
 
-    stopper = tune_config.get('stop', None)
-    if stopper is not None:
-        stopper = build_stopper(stopper)
-
-    space = tune_config.get('space', None)
-    if space is not None:
-        space = build_space(space)
-
-    resources_per_trial = None
-    if not hasattr(trainable, 'default_resource_request'):
-        resources_per_trial = dict(
-            gpu=task_processor.num_workers *
-            task_processor.num_gpus_per_worker,
-            cpu=task_processor.num_workers *
-            task_processor.num_cpus_per_worker)
-
-    searcher = tune_config.get('searcher', None)
-    if searcher is not None:
-        searcher = build_searcher(searcher)
-
-    scheduler = tune_config.get('scheduler', None)
-    if scheduler is not None:
-        scheduler = build_scheduler(scheduler)
+    tuner = Tuner.from_cfg(tune_config, trainable)
 
-    callbacks = tune_config.get('callbacks', None)
-    if callbacks is not None:
-        callbacks = [build_callback(callback) for callback in callbacks]
+    # name=exp_name,
+    # local_dir=tune_artifact_dir,
+    # raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False),
 
-    return ray.tune.run(
-        trainable,
-        name=exp_name,
-        metric=tune_config.metric,
-        mode=tune_config.mode,
-        stop=stopper,
-        config=space,
-        resources_per_trial=resources_per_trial,
-        num_samples=tune_config.get('num_samples', -1),
-        local_dir=tune_artifact_dir,
-        search_alg=searcher,
-        scheduler=scheduler,
-        raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False),
-        callbacks=callbacks)
+    return tuner.fit()
diff --git a/siatune/mm/tasks/mmtrainbase.py b/siatune/mm/tasks/mmtrainbase.py
index fac8c4f8..227ea0b1 100644
--- a/siatune/mm/tasks/mmtrainbase.py
+++ b/siatune/mm/tasks/mmtrainbase.py
@@ -4,9 +4,9 @@
 from functools import partial
 
 import mmcv
-import ray
 import torch
-from ray.tune.integration.torch import DistributedTrainableCreator
+from ray.air.config import ScalingConfig
+from ray.train.torch import TorchTrainer
 
 from .base import BaseTask
 from .builder import TASKS
@@ -78,8 +78,7 @@ def context_aware_run(self,
     def create_trainable(
         self,
         backend: str = 'nccl',
-        timeout_s: int = 1800,
-    ) -> ray.tune.trainable:
+    ) -> TorchTrainer:
         """Get ray trainable task.
 
         Args:
@@ -94,14 +93,9 @@ def create_trainable(
 
         assert backend in ['gloo', 'nccl']
 
-        return DistributedTrainableCreator(
-            partial(
-                self.context_aware_run,
-                backend=backend,
-            ),
-            backend=backend,
-            timeout_s=timeout_s,
-            num_workers=self.num_workers,
-            num_gpus_per_worker=self.num_gpus_per_worker,
-            num_cpus_per_worker=self.num_cpus_per_worker,
-        )
+        return TorchTrainer(
+            partial(self.context_aware_run, backend=backend),
+            scaling_config=ScalingConfig(
+                resources_per_worker=dict(
+                    CPU=self.num_cpus_per_worker,
+                    GPU=self.num_gpus_per_worker)))
diff --git a/siatune/ray/__init__.py b/siatune/ray/__init__.py
index 061afde0..cb03c07b 100644
--- a/siatune/ray/__init__.py
+++ b/siatune/ray/__init__.py
@@ -1,4 +1,9 @@
 # Copyright (c) SI-Analytics. All rights reserved.
+from .callbacks import *  # noqa F403
 from .schedulers import *  # noqa F403
+from .searchers import *  # noqa F403
 from .spaces import *  # noqa F403
 from .stoppers import *  # noqa F403
+from .tuner import Tuner
+
+__all__ = ['Tuner']
diff --git a/siatune/ray/tuner.py b/siatune/ray/tuner.py
new file mode 100644
index 00000000..c470c5c8
--- /dev/null
+++ b/siatune/ray/tuner.py
@@ -0,0 +1,97 @@
+# Copyright (c) SI-Analytics. All rights reserved.
+import copy
+import os.path as osp
+
+from ray.air.config import RunConfig
+from ray.tune.tune_config import TuneConfig
+from ray.tune.tuner import Tuner as RayTuner
+
+from siatune.ray import (build_callback, build_scheduler, build_searcher,
+                         build_space, build_stopper)
+
+
+class Tuner:
+    """Wrapper class of :class:`ray.tune.tuner.Tuner`.
+
+    Args:
+        trainable (Callable): The trainable to be tuned.
+        work_dir (str): The working directory to save tuning artifacts.
+        param_space (dict, optional): Search space of the tuning job.
+        tune_cfg (dict, optional): Tuning algorithm specific configs.
+            Refer to https://github.com/ray-project/ray/blob/ray-2.1.0/python/ray/tune/tune_config.py for details.  # noqa
+        searcher (dict, optional): Search algorithm for optimization.
+        trial_scheduler (dict, optional): Scheduler for executing trials.
+        stopper (dict, optional): Stop conditions to consider.
+        callbacks (list, optional): Callbacks to invoke during tuning.
+    """
+
+    def __init__(
+        self,
+        trainable,
+        work_dir,
+        param_space=None,
+        tune_cfg=None,
+        searcher=None,
+        trial_scheduler=None,
+        stopper=None,
+        callbacks=None,
+    ):
+        work_dir = osp.abspath(work_dir)
+
+        if param_space is not None:
+            param_space = build_space(param_space)
+
+        tune_cfg = copy.deepcopy(tune_cfg or dict())
+
+        if searcher is not None:
+            searcher = build_searcher(searcher)
+
+        if trial_scheduler is not None:
+            trial_scheduler = build_scheduler(trial_scheduler)
+
+        if stopper is not None:
+            stopper = build_stopper(stopper)
+
+        if callbacks is not None:
+            if isinstance(callbacks, dict):
+                callbacks = [callbacks]
+            callbacks = [build_callback(callback) for callback in callbacks]
+
+        self.tuner = RayTuner(
+            trainable,
+            param_space=param_space,
+            tune_config=TuneConfig(
+                search_alg=searcher, scheduler=trial_scheduler,
+                **tune_cfg),
+            run_config=RunConfig(
+                local_dir=work_dir,
+                stop=stopper,
+                callbacks=callbacks,
+                failure_config=None,  # todo
+                sync_config=None,  # todo
+                checkpoint_config=None,  # todo
+            ),
+        )
+
+    @classmethod
+    def from_cfg(cls, cfg, trainable):
+        cfg = copy.deepcopy(cfg)
+        tuner = cls(
+            trainable,
+            work_dir=cfg['work_dir'],
+            param_space=cfg.get('space', None),
+            tune_cfg=cfg.get('tune_cfg', None),
+            searcher=cfg.get('searcher', None),
+            trial_scheduler=cfg.get('trial_scheduler', None),
+            stopper=cfg.get('stopper', None),
+            callbacks=cfg.get('callbacks', None),
+        )
+
+        return tuner
+
+    @classmethod
+    def resume(cls, path, **kwargs):
+        return RayTuner.restore(path, **kwargs)
+
+    def fit(self):
+        return self.tuner.fit()
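Usage note (not part of the diff): the sketch below is a minimal, hypothetical example of driving the new Tuner wrapper added in siatune/ray/tuner.py. The dict keys mirror what Tuner.from_cfg() reads; the toy trainable, the work_dir path, and the tune_cfg values are illustrative assumptions rather than siatune defaults, and the space/searcher/trial_scheduler/stopper/callbacks entries are left out because their registry names depend on what the build_* helpers register.

# Hypothetical usage sketch; assumes this patch is applied and Ray >= 2.0.
from ray.air import session

from siatune.ray import Tuner


def toy_trainable(config: dict) -> None:
    # Stand-in for the trainable returned by task_processor.create_trainable().
    lr = config.get('lr', 0.05)
    session.report({'loss': (lr - 0.01) ** 2})


cfg = dict(
    work_dir='./work_dirs/toy',  # read by Tuner.from_cfg as cfg['work_dir']
    tune_cfg=dict(num_samples=4, metric='loss', mode='min'),
    # Optional keys forwarded to the build_* helpers when present:
    # space=..., searcher=..., trial_scheduler=..., stopper=..., callbacks=...
)

results = Tuner.from_cfg(cfg, toy_trainable).fit()
print(results.get_best_result(metric='loss', mode='min'))

Since from_cfg() only inspects dictionary-style keys, either a plain dict (as here) or an mmcv Config, as passed by apis/tune.py, should work.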