Commit

Init
KKIEEK authored and KKIEEK committed Dec 1, 2022
1 parent a234607 commit f6e85e4
Showing 4 changed files with 114 additions and 70 deletions.
58 changes: 3 additions & 55 deletions siatune/apis/tune.py
@@ -1,16 +1,10 @@
# Copyright (c) SI-Analytics. All rights reserved.
from os import path as osp

import mmcv
import ray
from mmcv.utils import Config

from siatune.mm.tasks import BaseTask
from siatune.ray.callbacks import build_callback
from siatune.ray.schedulers import build_scheduler
from siatune.ray.searchers import build_searcher
from siatune.ray.spaces import build_space
from siatune.ray.stoppers import build_stopper
from siatune.ray import Tuner


def tune(task_processor: BaseTask, tune_config: Config,
@@ -29,51 +23,5 @@ def tune(task_processor: BaseTask, tune_config: Config,
    trainable_cfg = tune_config.get('trainable', dict())
    trainable = task_processor.create_trainable(**trainable_cfg)

    assert hasattr(tune_config, 'metric')
    assert hasattr(tune_config, 'mode') and tune_config.mode in ['min', 'max']

    tune_artifact_dir = osp.join(tune_config.work_dir, 'artifact')
    mmcv.mkdir_or_exist(tune_artifact_dir)

    stopper = tune_config.get('stop', None)
    if stopper is not None:
        stopper = build_stopper(stopper)

    space = tune_config.get('space', None)
    if space is not None:
        space = build_space(space)

    resources_per_trial = None
    if not hasattr(trainable, 'default_resource_request'):
        resources_per_trial = dict(
            gpu=task_processor.num_workers *
            task_processor.num_gpus_per_worker,
            cpu=task_processor.num_workers *
            task_processor.num_cpus_per_worker)

    searcher = tune_config.get('searcher', None)
    if searcher is not None:
        searcher = build_searcher(searcher)

    scheduler = tune_config.get('scheduler', None)
    if scheduler is not None:
        scheduler = build_scheduler(scheduler)

    callbacks = tune_config.get('callbacks', None)
    if callbacks is not None:
        callbacks = [build_callback(callback) for callback in callbacks]

    return ray.tune.run(
        trainable,
        name=exp_name,
        metric=tune_config.metric,
        mode=tune_config.mode,
        stop=stopper,
        config=space,
        resources_per_trial=resources_per_trial,
        num_samples=tune_config.get('num_samples', -1),
        local_dir=tune_artifact_dir,
        search_alg=searcher,
        scheduler=scheduler,
        raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False),
        callbacks=callbacks)
    tuner = Tuner.from_cfg(tune_config, trainable)
    return tuner.fit()
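For reference, the configuration keys the old tune() consumed at the top level map onto the keys that Tuner.from_cfg (added below in siatune/ray/tuner.py) now reads. A minimal sketch of a reshaped config; the metric name and values are placeholders, and None stands in for the cfg dicts each build_* helper would receive:

# Old tune() read these top-level keys directly:
#   metric, mode, num_samples, stop, space, searcher, scheduler, callbacks
# Tuner.from_cfg reads the keys below instead (placeholder values shown).
tune_config = dict(
    work_dir='./work_dirs/demo',       # required
    space=None,                        # was top-level `space`
    tune_cfg=dict(
        metric='val/accuracy',         # was top-level `metric`
        mode='max',                    # was top-level `mode`
        num_samples=8),                # was top-level `num_samples`
    searcher=None,                     # same key as before
    trial_scheduler=None,              # was `scheduler`
    stopper=None,                      # was `stop`
    callbacks=None,                    # same key as before
)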
24 changes: 9 additions & 15 deletions siatune/mm/tasks/mmtrainbase.py
@@ -4,9 +4,9 @@
from functools import partial

import mmcv
import ray
import torch
from ray.tune.integration.torch import DistributedTrainableCreator
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

from .base import BaseTask
from .builder import TASKS
@@ -78,8 +78,7 @@ def context_aware_run(self,
    def create_trainable(
        self,
        backend: str = 'nccl',
        timeout_s: int = 1800,
    ) -> ray.tune.trainable:
    ) -> TorchTrainer:
        """Get ray trainable task.
        Args:
@@ -94,14 +93,9 @@ def create_trainable(

        assert backend in ['gloo', 'nccl']

        return DistributedTrainableCreator(
            partial(
                self.context_aware_run,
                backend=backend,
            ),
            backend=backend,
            timeout_s=timeout_s,
            num_workers=self.num_workers,
            num_gpus_per_worker=self.num_gpus_per_worker,
            num_cpus_per_worker=self.num_cpus_per_worker,
        )
        return TorchTrainer(
            partial(self.context_aware_run, backend=backend),
            scaling_config=ScalingConfig(
                resources_per_worker=dict(
                    CPU=self.num_cpus_per_worker,
                    GPU=self.num_gpus_per_worker)))
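The removed DistributedTrainableCreator belongs to the pre-2.x Ray Tune function-trainable API; its replacement is Ray AIR's TorchTrainer driven by a ScalingConfig. A minimal, standalone sketch of the new pattern under Ray 2.1, with hypothetical worker counts standing in for the task's num_workers / num_cpus_per_worker / num_gpus_per_worker fields:

from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # A real task would build its model and dataloaders here and report
    # metrics back to Tune; this placeholder only reports a dummy value.
    session.report(dict(loss=0.0))


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(
        num_workers=2,                             # hypothetical worker count
        use_gpu=True,
        resources_per_worker=dict(CPU=4, GPU=1)))
result = trainer.fit()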
5 changes: 5 additions & 0 deletions siatune/ray/__init__.py
@@ -1,4 +1,9 @@
# Copyright (c) SI-Analytics. All rights reserved.
from .callbacks import * # noqa F403
from .schedulers import * # noqa F403
from .searchers import * # noqa F403
from .spaces import * # noqa F403
from .stoppers import * # noqa F403
from .tuner import Tuner

__all__ = ['Tuner']
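These wildcard re-exports are what let the new tuner.py below import every builder from the package root in one statement; since __all__ only restricts `from siatune.ray import *`, explicit imports of the builders keep working, e.g.:

# Both the new Tuner and the re-exported builders resolve from the package root.
from siatune.ray import Tuner, build_searcher, build_space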
97 changes: 97 additions & 0 deletions siatune/ray/tuner.py
@@ -0,0 +1,97 @@
# Copyright (c) SI-Analytics. All rights reserved.
import copy
import os.path as osp

from ray.air.config import RunConfig
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner as RayTuner

from siatune.ray import (build_callback, build_scheduler, build_searcher,
                         build_space, build_stopper)


class Tuner:
"""Wrapper class of :class:`ray.tune.tuner.Tuner`.
Args:
trainable (Callable):
work_dir (str):
param_space (dict, optional):
tune_cfg (dict, optional):
Refer to https://github.com/ray-project/ray/blob/ray-2.1.0/python/ray/tune/tune_config.py for details. # noqa
searcher (dict, optional):
trial_scheduler (dict, optional):
stopper (dict, optional):
callbacks (list, optional):
"""

    def __init__(
        self,
        trainable,
        work_dir,
        param_space=None,
        tune_cfg=None,
        searcher=None,
        trial_scheduler=None,
        stopper=None,
        callbacks=None,
    ):
        work_dir = osp.abspath(work_dir)

        if param_space is not None:
            param_space = build_space(param_space)

        tune_cfg = copy.deepcopy(tune_cfg or dict())

        if searcher is not None:
            searcher = build_searcher(searcher)

        if trial_scheduler is not None:
            trial_scheduler = build_scheduler(trial_scheduler)

        if stopper is not None:
            stopper = build_stopper(stopper)

        if callbacks is not None:
            if isinstance(callbacks, dict):
                callbacks = [callbacks]
            callbacks = [build_callback(callback) for callback in callbacks]

        self.tuner = RayTuner(
            trainable,
            param_space=param_space,
            tune_config=TuneConfig(
                # TuneConfig names these `search_alg` and `scheduler`.
                search_alg=searcher,
                scheduler=trial_scheduler,
                **tune_cfg),
            run_config=RunConfig(
                local_dir=work_dir,
                stop=stopper,
                callbacks=callbacks,
                failure_config=None,  # todo
                sync_config=None,  # todo
                checkpoint_config=None,  # todo
            ),
        )

    @classmethod
    def from_cfg(cls, cfg, trainable):
        cfg = copy.deepcopy(cfg)
        tuner = cls(
            trainable,
            work_dir=cfg['work_dir'],
            param_space=cfg.get('space', None),
            tune_cfg=cfg.get('tune_cfg', None),
            searcher=cfg.get('searcher', None),
            trial_scheduler=cfg.get('trial_scheduler', None),
            stopper=cfg.get('stopper', None),
            callbacks=cfg.get('callbacks', None),
        )

        return tuner

    @classmethod
    def resume(cls, path, **kwargs):
        return RayTuner.restore(path, **kwargs)

    def fit(self):
        return self.tuner.fit()
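A minimal usage sketch of the new wrapper, assuming `task` is an already-built BaseTask instance (see mmtrainbase.py above); the work_dir and metric values are placeholders:

from siatune.ray import Tuner

trainable = task.create_trainable()              # `task`: a BaseTask instance (assumed)

cfg = dict(
    work_dir='./work_dirs/demo',                 # required
    tune_cfg=dict(metric='val/accuracy',         # placeholder metric name
                  mode='max',
                  num_samples=4),
)

tuner = Tuner.from_cfg(cfg, trainable)
result_grid = tuner.fit()                        # a ray.tune.ResultGrid

# Resuming restores the underlying ray.tune.Tuner from the experiment directory:
restored = Tuner.resume('./work_dirs/demo/<experiment_name>')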
