[Refactor] argument logic #27

Merged · 4 commits · May 26, 2022
3 changes: 2 additions & 1 deletion configs/mmtune/mmdet_asynchb_nevergrad_pso.py
@@ -10,7 +10,8 @@
 'data.samples_per_gpu': {{_base_.batch_size}},
 }

+task = dict(type='MMDetection')
 metric = 'val/AP'
 mode = 'max'
-raise_on_failed_trial = False,
+raise_on_failed_trial = False
 num_samples = 256
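Note the single deletion in this hunk: with the trailing comma, raise_on_failed_trial was actually the one-element tuple (False,), which is truthy, so failed trials would still raise. The mmseg config below gets the same one-character fix.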
3 changes: 2 additions & 1 deletion configs/mmtune/mmseg_asynchb_nevergrad_pso.py
@@ -12,7 +12,8 @@
 'model.auxiliary_head.num_classes': dict(type='Constant', value=21),
 }

+task = dict(type='MMSegmentation')
 metric = 'val/mIoU'
 mode = 'max'
-raise_on_failed_trial = False,
+raise_on_failed_trial = False
 num_samples = 256
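Both tuning configs now share the same flat layout, with the task declared in the config itself instead of being resolved from a bare name elsewhere. A sketch of the resulting file (base imports and the search space are elided; values are the ones shown above):

# configs/mmtune/mmseg_asynchb_nevergrad_pso.py after this PR (sketch)
task = dict(type='MMSegmentation')  # consumed by build_task_processor(cfg.task)
metric = 'val/mIoU'
mode = 'max'
raise_on_failed_trial = False
num_samples = 256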
26 changes: 18 additions & 8 deletions mmtune/apis/analysis.py
@@ -1,26 +1,36 @@
 import time
 from os import path as osp
+from pprint import pformat
+from typing import Optional

 from mmcv.utils import Config, get_logger
 from ray import tune

 from mmtune.utils import ImmutableContainer


-def log_analysis(analysis: tune.ExperimentAnalysis, tune_config: Config,
-                 task_config: Config, log_dir: str) -> None:
+def log_analysis(analysis: tune.ExperimentAnalysis,
+                 tune_config: Config,
+                 task_config: Optional[Config] = None,
+                 log_dir: Optional[str] = None) -> None:
+    log_dir = log_dir or tune_config.work_dir
     with open(osp.join(log_dir, 'tune_config.py'), 'w', encoding='utf-8') as f:
         f.write(tune_config.pretty_text)
-    with open(osp.join(log_dir, 'task_config.py'), 'w', encoding='utf-8') as f:
-        f.write(task_config.pretty_text)
+
+    if task_config is not None:
+        with open(
+                osp.join(log_dir, 'task_config.py'), 'w',
+                encoding='utf-8') as f:
+            f.write(task_config.pretty_text)

     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
     logger = get_logger(
         'mmtune', log_file=osp.join(log_dir, f'{timestamp}.log'))

     logger.info(
-        ('Best Hyperparam', ImmutableContainer.decouple(analysis.best_config)))
+        f'Best Hyperparam: \n'
+        f'{pformat(ImmutableContainer.decouple(analysis.best_config))}')
     logger.info(
-        ('Best Results', ImmutableContainer.decouple(analysis.best_result)))
-    logger.info(('Best Logdir', analysis.best_logdir))
-    logger.info(analysis.results)
+        f'Best Results: \n'
+        f'{pformat(ImmutableContainer.decouple(analysis.best_result))}')
+    logger.info(f'Best Logdir: {analysis.best_logdir}')
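A usage sketch of the relaxed signature; the import path is assumed from this file's location, and both task_config and log_dir may now be omitted, with log_dir falling back to tune_config.work_dir:

from mmcv.utils import Config
from ray import tune as ray_tune

from mmtune.apis.analysis import log_analysis  # import path assumed


def report(analysis: ray_tune.ExperimentAnalysis, config_path: str) -> None:
    tune_config = Config.fromfile(config_path)  # must define work_dir
    # task_config and log_dir omitted: the task_config dump is skipped
    # and the logs land under tune_config.work_dir.
    log_analysis(analysis, tune_config)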
55 changes: 35 additions & 20 deletions mmtune/apis/tune.py
@@ -14,34 +14,49 @@
 def tune(task_processor: BaseTask, tune_config: Config,
          exp_name: str) -> ray.tune.ExperimentAnalysis:
     trainable = task_processor.create_trainable(
-        **getattr(tune_config, 'trainable', dict()))
+        **tune_config.get('trainable', dict()))

     assert hasattr(tune_config, 'metric')
     assert hasattr(tune_config, 'mode') and tune_config.mode in ['min', 'max']

-    tune_artifact_dir = osp.join(task_processor.args.work_dir, 'artifact')
+    tune_artifact_dir = osp.join(tune_config.work_dir, 'artifact')
     mmcv.mkdir_or_exist(tune_artifact_dir)

+    stopper = tune_config.get('stop', None)
+    if stopper is not None:
+        stopper = build_stopper(stopper)
+
+    space = tune_config.get('space', None)
+    if space is not None:
+        space = build_space(space)
+
+    resources_per_trial = None
+    if not hasattr(trainable, 'default_resource_request'):
+        num_workers = trainable.get('num_workers', 1)
+        num_gpus_per_worker = trainable.get('num_gpus_per_worker', 1)
+        num_cpus_per_worker = trainable.get('num_cpus_per_worker', 1)
+        resources_per_trial = dict(
+            gpu=num_workers * num_gpus_per_worker,
+            cpu=num_workers * num_cpus_per_worker)
+
+    searcher = tune_config.get('searcher', None)
+    if searcher is not None:
+        searcher = build_searcher(searcher)
+
+    scheduler = tune_config.get('scheduler', None)
+    if scheduler is not None:
+        scheduler = build_scheduler(scheduler)
+
     return ray.tune.run(
         trainable,
+        name=exp_name,
         metric=tune_config.metric,
         mode=tune_config.mode,
-        name=exp_name,
-        resources_per_trial=None
-        if hasattr(trainable, 'default_resource_request') else dict(
-            cpu=task_processor.args.num_workers *  # noqa W504
-            task_processor.args.num_cpus_per_worker,
-            gpu=task_processor.args.num_workers *  # noqa W504
-            task_processor.args.num_gpus_per_worker),
-        stop=build_stopper(tune_config.stop)
-        if hasattr(tune_config, 'stop') else None,
-        config=build_space(tune_config.space)
-        if hasattr(tune_config, 'space') else None,
-        num_samples=getattr(tune_config, 'num_samples', -1),
+        stop=stopper,
+        config=space,
+        resources_per_trial=resources_per_trial,
+        num_samples=tune_config.get('num_samples', -1),
         local_dir=tune_artifact_dir,
-        search_alg=build_searcher(tune_config.searcher) if hasattr(
-            tune_config, 'searcher') else None,
-        scheduler=build_scheduler(tune_config.scheduler) if hasattr(
-            tune_config, 'scheduler') else None,
-        raise_on_failed_trial=getattr(tune_config, 'raise_on_failed_trial',
-                                      False))
+        search_alg=searcher,
+        scheduler=scheduler,
+        raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False))
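Because every optional block is now read with tune_config.get(...), a config that omits stop, space, searcher, and scheduler simply passes None for each. A driver sketch tying the pieces together (import paths and file names are assumptions, not part of this PR):

from mmcv.utils import Config

from mmtune.apis.tune import tune                 # import path assumed
from mmtune.mm.tasks import build_task_processor  # import path assumed

tune_config = Config.fromfile('configs/mmtune/mmdet_asynchb_nevergrad_pso.py')
tune_config.work_dir = 'work_dirs/pr27_demo'  # artifacts and logs go here

task_processor = build_task_processor(tune_config.task)
# Raw CLI-style arguments; the task parses them via its own parse_args.
task_processor.set_args(['configs/det/faster_rcnn_r50_fpn_1x_coco.py'])

analysis = tune(task_processor, tune_config, exp_name='pr27_demo')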
3 changes: 3 additions & 0 deletions mmtune/mm/__init__.py
@@ -0,0 +1,3 @@
+from .context import *  # noqa F403
+from .hooks import *  # noqa F403
+from .tasks import *  # noqa F403
17 changes: 5 additions & 12 deletions mmtune/mm/tasks/base.py
@@ -1,9 +1,8 @@
 import argparse
 from abc import ABCMeta, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Sequence

 import ray
-from mmcv.utils.config import Config

 from mmtune.mm.context import ContextManager
 from mmtune.utils import ImmutableContainer
@@ -19,23 +18,17 @@ def __init__(self):
         self.args: Optional[argparse.Namespace] = None
         self.rewriters: List[dict] = []

-    def set_base_cfg(self, base_cfg: Config) -> None:
-        self.base_cfg = ImmutableContainer(base_cfg, 'base')
-
-    def set_args(self, args: argparse.Namespace) -> None:
-        self.args = args
+    def set_args(self, args: Sequence[str]) -> None:
+        self.args = self.parse_args(args)

     def set_rewriters(self, rewriters: List[dict] = []) -> None:
         self.rewriters = rewriters

     @abstractmethod
-    def add_arguments(
-            self,
-            parser: Optional[argparse.ArgumentParser] = None
-    ) -> argparse.ArgumentParser:
+    def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
         pass

-    def contextaware_run(self, status, *args, **kwargs) -> None:
+    def context_aware_run(self, status, *args, **kwargs) -> None:
         context_manager = ContextManager(**status)
         return context_manager(self.run)(*args, **kwargs)
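The net effect of the base-class change: set_args now takes the raw argument sequence and funnels it through the task's own abstract parse_args, instead of receiving a pre-built Namespace (and parser injection via add_arguments is gone). A toy subclass sketch, assuming run and create_trainable are the remaining abstract hooks:

import argparse
from typing import Sequence


class ToyTask(BaseTask):  # BaseTask as refactored above

    def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
        parser = argparse.ArgumentParser(description='toy task')
        parser.add_argument('config', help='config file path')
        return parser.parse_args(args)

    def run(self, *args, **kwargs) -> None:
        pass

    def create_trainable(self):
        pass


task = ToyTask()
task.set_args(['toy_config.py'])  # parsing now happens inside set_args
assert task.args.config == 'toy_config.py'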
15 changes: 5 additions & 10 deletions mmtune/mm/tasks/blackbox.py
@@ -1,7 +1,7 @@
 import argparse
 from abc import ABCMeta
 from functools import partial
-from typing import Callable, Optional
+from typing import Callable, Sequence

 from .base import BaseTask
 from .builder import TASKS
@@ -10,18 +10,13 @@
 @TASKS.register_module()
 class BloackBoxTask(BaseTask, metaclass=ABCMeta):

-    def add_arguments(
-            self,
-            parser: Optional[argparse.ArgumentParser] = None
-    ) -> argparse.ArgumentParser:
-
-        if parser is None:
-            parser = argparse.ArgumentParser(description='black box')
-        return parser
+    def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
+        parser = argparse.ArgumentParser(description='black box')
+        return parser.parse_args(args)

     def create_trainable(self) -> Callable:
         return partial(
-            self.contextaware_run,
+            self.context_aware_run,
             dict(
                 base_cfg=self.base_cfg,
                 args=self.args,
4 changes: 2 additions & 2 deletions mmtune/mm/tasks/builder.py
@@ -3,5 +3,5 @@
 TASKS = Registry('tasks')


-def build_task_processor(task_name: str):
-    return TASKS.build(dict(type=task_name))
+def build_task_processor(task: dict):
+    return TASKS.build(task)
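The builder now takes the whole task dict from the config rather than a bare type name, so any extra keys would be forwarded to the task's constructor by the registry. Roughly:

# before this PR: only a registered name could cross the boundary
task_processor = build_task_processor('MMDetection')

# after: the config's task dict is passed through as-is; extra keys (if a
# task ever declares constructor arguments) become kwargs via TASKS.build
task_processor = build_task_processor(dict(type='MMDetection'))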
19 changes: 7 additions & 12 deletions mmtune/mm/tasks/mmdet.py
@@ -2,7 +2,7 @@
 import copy
 import time
 from os import path as osp
-from typing import Optional
+from typing import Optional, Sequence

 import mmcv
 import torch
@@ -17,13 +17,9 @@
 @TASKS.register_module()
 class MMDetection(MMTrainBasedTask):

-    def add_arguments(
-            self,
-            parser: Optional[argparse.ArgumentParser] = None
-    ) -> argparse.ArgumentParser:
-
-        if parser is None:
-            parser = argparse.ArgumentParser(description='Train a detector')
+    def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
+        parser = argparse.ArgumentParser(description='Train a detector')
         parser.add_argument('config', help='train config file path')
         parser.add_argument(
             '--work-dir', help='the dir to save logs and models')
         parser.add_argument(
@@ -63,13 +59,12 @@ def add_arguments(
             'like key="[a,b]" or key=a,b It also allows nested list/tuple '
             'values, e.g. key="[(a,b),(c,d)]" Note that the quotation marks '
             'are necessary and that no white space is allowed.')
-        parser.add_argument('--local_rank', type=int, default=0)
         parser.add_argument(
             '--auto-scale-lr',
             action='store_true',
             help='enable automatically scaling LR.')
-
-        return parser
+        args = parser.parse_args(args)
+        return args

     def build_model(self,
                     cfg: Config,
@@ -102,7 +97,7 @@ def run(self, *args, **kwargs):
         from mmdet.apis import init_random_seed, set_random_seed
         from mmdet.utils import (collect_env, get_root_logger,
                                  setup_multi_processes)
-        args = kwargs['args']
+        args = self.args

         cfg = Config.fromfile(args.config)
         if args.cfg_options is not None:
17 changes: 7 additions & 10 deletions mmtune/mm/tasks/mmseg.py
@@ -2,7 +2,7 @@
 import copy
 import time
 from os import path as osp
-from typing import Optional
+from typing import Optional, Sequence

 import mmcv
 import torch
@@ -17,13 +17,9 @@
 @TASKS.register_module()
 class MMSegmentation(MMTrainBasedTask):

-    def add_arguments(
-            self,
-            parser: Optional[argparse.ArgumentParser] = None
-    ) -> argparse.ArgumentParser:
-
-        if parser is None:
-            parser = argparse.ArgumentParser(description='Train a segmentor')
+    def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
+        parser = argparse.ArgumentParser(description='Train a segmentor')
         parser.add_argument('config', help='train config file path')
         parser.add_argument(
             '--work-dir', help='the dir to save logs and models')
         parser.add_argument(
@@ -58,7 +54,8 @@ def add_arguments(
             '--auto-resume',
             action='store_true',
             help='resume from the latest checkpoint automatically.')
-        return parser
+        args = parser.parse_args(args)
+        return args

     def build_model(self,
                     cfg: Config,
@@ -92,7 +89,7 @@ def run(self, *args, **kwargs):
         from mmseg.apis import init_random_seed, set_random_seed
         from mmseg.utils import (collect_env, get_root_logger,
                                  setup_multi_processes)
-        args = kwargs['args']
+        args = self.args

         cfg = Config.fromfile(args.config)
         if args.cfg_options is not None:
23 changes: 15 additions & 8 deletions mmtune/mm/tasks/mmtrainbase.py
@@ -5,9 +5,11 @@
 import mmcv
 import ray
 import torch
+from mmcv import Config
 from ray.tune.integration.torch import DistributedTrainableCreator

 from mmtune.mm.context import ContextManager
+from mmtune.utils import ImmutableContainer
 from .base import BaseTask
 from .builder import TASKS

@@ -30,24 +32,29 @@ def train_model(self, model: torch.nn.Module,
                     **kwargs) -> None:
         pass

-    def contextaware_run(self, status, backend, *args, **kwargs) -> None:
-        from mmtune.mm import hooks  # noqa F401
+    def context_aware_run(self, status, backend, *args, **kwargs) -> None:
         if backend == 'nccl' and os.getenv('NCCL_BLOCKING_WAIT') is None:
             os.environ['NCCL_BLOCKING_WAIT'] = '0'
         context_manager = ContextManager(**status)
         return context_manager(self.run)(*args, **kwargs)

-    def create_trainable(self, backend: str = 'nccl') -> ray.tune.trainable:
+    def create_trainable(self,
+                         backend: str = 'nccl',
+                         num_workers: int = 1,
+                         num_gpus_per_worker: int = 1,
+                         num_cpus_per_worker: int = 1) -> ray.tune.trainable:
         assert backend in ['gloo', 'nccl']

+        base_cfg = Config.fromfile(self.args.config)
+        base_cfg = ImmutableContainer(base_cfg, 'base')
         return DistributedTrainableCreator(
             partial(
-                self.contextaware_run,
+                self.context_aware_run,
                 dict(
-                    base_cfg=self.base_cfg,
+                    base_cfg=base_cfg,
                     args=self.args,
                     rewriters=self.rewriters), backend),
             backend=backend,
-            num_workers=self.args.num_workers,
-            num_gpus_per_worker=self.args.num_cpus_per_worker,
-            num_cpus_per_worker=self.args.num_cpus_per_worker)
+            num_workers=num_workers,
+            num_gpus_per_worker=num_gpus_per_worker,
+            num_cpus_per_worker=num_cpus_per_worker)
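Two things moved into create_trainable here: the base config is now loaded from self.args.config and wrapped in an ImmutableContainer on the spot (replacing the deleted set_base_cfg), and worker/resource counts arrive as keyword arguments instead of CLI flags. In passing, this also fixes the old copy-paste bug that passed num_cpus_per_worker where GPUs were meant. Combined with the **tune_config.get('trainable', dict()) splat in tune() above, the resource shape now lives in the tuning config; a hedged sketch of that block:

# In a tune config (sketch): the keys mirror create_trainable's signature
# and are splatted into it by tune().
trainable = dict(
    backend='nccl',
    num_workers=2,
    num_gpus_per_worker=1,
    num_cpus_per_worker=2)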
2 changes: 1 addition & 1 deletion mmtune/mm/tasks/sphere.py
@@ -10,7 +10,7 @@
 class Sphere(BloackBoxTask):

     def run(self, *args, **kwargs):
-        args = kwargs['args']
+        args = self.args
         cfg = Config.fromfile(args.config)

         inputs = []