Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Support MIM trainable #125

Merged
merged 50 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
4c07274
:sparkles: Add irr trainer
yhna940 Dec 26, 2022
60cb8e9
:art: Apply lint
yhna940 Dec 26, 2022
4144551
:bug: Fix entry point
yhna940 Dec 26, 2022
8f91f47
:art: Typo
yhna940 Dec 26, 2022
f344759
:bug: Fix remote job
yhna940 Dec 27, 2022
02fedae
:art: Fix lone run ref
yhna940 Dec 27, 2022
9cf0ba6
:bug: Fix minor bug
yhna940 Dec 27, 2022
e872674
:recycle: Split component
yhna940 Dec 29, 2022
1ed6c70
:bug: Del red context
yhna940 Dec 29, 2022
a825172
:sparkles: Add dist creator
yhna940 Dec 29, 2022
665c426
:art: Apply lint
yhna940 Dec 29, 2022
f0b4425
:art: Rename dp creator
yhna940 Dec 29, 2022
8e851a6
:wheelchair: Fix rewriters
yhna940 Dec 30, 2022
ae21fb6
:bug: Fix dist
yhna940 Dec 30, 2022
75d6540
:test_tube: Test
yhna940 Dec 30, 2022
04ea33f
:test_tube: Test
yhna940 Dec 30, 2022
491147c
:bug: typo
yhna940 Dec 30, 2022
b3e343d
:rewind: Use ddp trainer
yhna940 Jan 3, 2023
9c538bc
:rotating_light: minor bug fix
yhna940 Jan 3, 2023
865c73e
Fix context
yhna940 Jan 3, 2023
75490e1
:bug: Add revert workspace
yhna940 Jan 3, 2023
7b23592
:rewind: Revert custom torch launcher
yhna940 Jan 4, 2023
d3e1183
:chart_with_upwards_trend: Expand Reporter
yhna940 Jan 4, 2023
31c8982
:art: Apply lint
yhna940 Jan 4, 2023
817010a
Resolve conflict
yhna940 Jan 4, 2023
0c68af4
Fix mim dir
yhna940 Jan 4, 2023
f1ea0d8
Fix filtering bug in reporter
yhna940 Jan 4, 2023
691ff34
:bug: Fix report minor bug
yhna940 Jan 5, 2023
16e376e
Update siatune/codebase/base.py
yhna940 Jan 5, 2023
529cea9
Update siatune/codebase/mm.py
yhna940 Jan 5, 2023
b3ca901
Update siatune/core/launch.py
yhna940 Jan 5, 2023
90f9ece
:art: Reflect the review
yhna940 Jan 5, 2023
c067ca3
:fire: Del ckpt
yhna940 Jan 5, 2023
54a1ef8
Add test code
yhna940 Jan 6, 2023
d8ad0af
:bug: Tmp unit test
yhna940 Jan 6, 2023
ae4e16f
:bug: Fix Test code
yhna940 Jan 6, 2023
a7c6848
:art: Test code fix
yhna940 Jan 6, 2023
6d79717
:bug: Add entrypoint test script
yhna940 Jan 6, 2023
7d32db2
:art: Apply lint
yhna940 Jan 6, 2023
dd45fbf
Update siatune/core/rewriters/resume.py
yhna940 Jan 7, 2023
a1c29fa
Update siatune/core/hooks/reporter.py
yhna940 Jan 7, 2023
81064f5
:fire: Rename
yhna940 Jan 7, 2023
cfae668
:memo: Add docstring
yhna940 Jan 7, 2023
85945a5
:art: Fix test code
yhna940 Jan 7, 2023
78090e5
Update siatune/codebase/mim.py
yhna940 Jan 7, 2023
8d8b254
Update siatune/codebase/base.py
yhna940 Jan 7, 2023
a26f4de
Update siatune/codebase/base.py
yhna940 Jan 7, 2023
546a513
Update configs/mmcls/mmcls_cifar_100_asynchb_nevergrad_pso.py
yhna940 Jan 7, 2023
b22ee6b
Update configs/mmcls/mim_cifar_100_asynchb_nevergrad_pso.py
yhna940 Jan 7, 2023
3220cc5
:bug: minor
yhna940 Jan 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions configs/_base_/context/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
task = dict(rewriters=[
dict(type='InstantiateCfg', arg_name='config', key='base_cfg'),
dict(type='InstantiateCfg', key='base_cfg'),
dict(type='BatchConfigPatcher', key='searched_cfg'),
dict(type='SequeunceConfigPatcher', key='searched_cfg'),
dict(
Expand All @@ -12,10 +12,11 @@
key='cfg',
post_custom_hooks=[
dict(
type='RayTuneLoggerHook',
type='RayTuneReporterHook',
filtering_key='val',
priority='VERY_LOW'),
dict(type='RayCheckpointHook', by_epoch=True, interval=1)
]),
dict(type='Dump', key='cfg', arg_name='config')
dict(type='ResumeFromCkpt'),
dict(type='Dump', key='cfg'),
dict(type='AttachTrialInfoToPath')
])
20 changes: 20 additions & 0 deletions configs/mmcls/mim_cifar_100_asynchb_nevergrad_pso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
_base_ = [
'../_base_/context/train.py', '../_base_/searcher/nevergrad_pso.py',
'../_base_/scheduler/asynchb.py', '../_base_/space/mmcls_model.py',
'../_base_/space/optimizer.py', '../_base_/space/batch_size.py'
]

space = {
'model': {{_base_.model}},
'model.head.num_classes': 100,
'optimizer': {{_base_.optimizer}},
'data.samples_per_gpu': {{_base_.batch_size}},
}

task = dict(type='MIM', pkg_name='mmcls')
tune_cfg = dict(
num_samples=8,
metric='val/accuracy_top-1',
mode='max',
reuse_actors=False,
chdir_to_trial_dir=False)
15 changes: 4 additions & 11 deletions siatune/codebase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,15 @@
from .builder import TASKS, build_task
from .cont_test_func import ContinuousTestFunction
from .disc_test_func import DiscreteTestFunction
from .mim import MIM
from .mm import MMBaseTask
from .mmcls import MMClassification
from .mmdet import MMDetection
from .mmedit import MMEditing
from .mmseg import MMSegmentation

__all__ = [
'TASKS',
'build_task',
'BaseTask',
'BlackBoxTask',
'ContinuousTestFunction',
'DiscreteTestFunction',
'MMBaseTask',
'MMClassification',
'MMDetection',
'MMEditing',
'MMSegmentation',
'TASKS', 'build_task', 'BaseTask', 'BlackBoxTask',
'ContinuousTestFunction', 'DiscreteTestFunction', 'MMBaseTask',
'MMClassification', 'MMDetection', 'MMEditing', 'MMSegmentation', 'MIM'
]
23 changes: 13 additions & 10 deletions siatune/codebase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import argparse
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from typing import Callable, Optional, Sequence, Union
from typing import Optional, Sequence, Union

from ray.tune import Trainable

Expand All @@ -26,7 +26,8 @@ class BaseTask(metaclass=ABCMeta):
Aggregate the information we define as context,
convert it into a refined argparse namespace, and input it to run.
The context consists of:
1. args (argparse.Namespace): The low level CLI arguments.
1. args (argparse.Namespace | Sequence[str]):
The low level CLI arguments.
2. searched_cfg (Dict):
The configuration searched by the algorithm.
Inputs: searched_cfg (Dict)
Expand All @@ -46,11 +47,15 @@ class BaseTask(metaclass=ABCMeta):

def __init__(self,
args: Sequence[str],
num_workers: int,
num_workers: int = 1,
num_cpus_per_worker: int = 1,
num_gpus_per_worker: int = 1,
rewriters: Optional[Union[list, dict]] = None):
self.args = self.parse_args(args)
rewriters: Optional[Union[list, dict]] = None,
should_parse: bool = True):

if should_parse:
args = self.parse_args(args)
self.args = args

self.num_workers = num_workers
self.num_cpus_per_worker = num_cpus_per_worker
Expand All @@ -61,7 +66,8 @@ def __init__(self,
self.rewriters = rewriters

@abstractmethod
def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
def parse_args(self,
args: Sequence[str]) -> Union[argparse.Namespace, None]:
"""Define and parse the necessary arguments for the task.

Args:
Expand All @@ -72,15 +78,12 @@ def parse_args(self, args: Sequence[str]) -> argparse.Namespace:
"""
pass

def context_aware_run(self, searched_cfg: dict) -> Callable:
def context_aware_run(self, searched_cfg: dict):
"""Gather and refine the information received by users and Ray.tune to
execute the objective task.

Args:
searched_cfg (Dict): The searched configuration.

Returns:
Callable: The result of the objective task.
"""

context_manager = ContextManager(self.rewriters)
Expand Down
68 changes: 68 additions & 0 deletions siatune/codebase/mim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) SI-Analytics. All rights reserved.
from importlib.machinery import SourceFileLoader
from typing import Sequence

from siatune.utils import get_train_script
from .builder import TASKS
from .mm import MMBaseTask


class _EntrypointExecutor:
"""Execute the entrypoint of open mm train-based projects.

Args:
pkg_name (str): The abbreviation of the package.
argv (Sequence[str]): The arguments for `tools/train.py`
module_name (str):
The name of the module to execute. Defaults to 'main'.
"""

def __init__(self,
pkg_name: str,
argv: Sequence[str],
module_name: str = 'main'):
self._train_script = get_train_script(pkg_name)
self._module_name = module_name
self._argv = argv
self._entrypoint = SourceFileLoader(self._module_name,
self._train_script).load_module()

def _hijack_argv(self, argv: Sequence[str]):
"""Hijack the command line arguments.

Args:
argv (Sequence[str]): The arguments for `tools/train.py`
"""
import sys
sys.argv[1:] = argv
return

def execute(self):
"""Run the task."""
self._hijack_argv(self._argv)
getattr(self._entrypoint, self._module_name)()


@TASKS.register_module()
class MIM(MMBaseTask):
"""Wrapper class execute any script provided by all OpenMMLab codebases.

Args:
pkg_name (str): The abbreviation of the package.
"""

def __init__(self, pkg_name: str, **kwargs):
self._pkg_name = pkg_name
super().__init__(should_parse=False, **kwargs)

def parse_args(self, *args, **kwargs) -> None:
pass

def run(self, args: Sequence[str]):
"""This method runs a task in the MIM framework.

Args:
args (Sequence[str]): A list of command-line arguments.
"""
executor = _EntrypointExecutor(self._pkg_name, args)
executor.execute()
48 changes: 34 additions & 14 deletions siatune/codebase/mm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) SI-Analytics. All rights reserved.
from abc import ABCMeta
from copy import deepcopy
from typing import Callable

import torch
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.tune import with_resources

from siatune.tune import MMBackendConfig
from siatune.core import ContextManager, DistributedTorchLauncher
from siatune.utils import ImmutableContainer
from .base import BaseTask
from .builder import TASKS

Expand All @@ -14,17 +15,36 @@
class MMBaseTask(BaseTask, metaclass=ABCMeta):
"""Wrap the apis of open mm train-based projects."""

def create_trainable(self) -> DataParallelTrainer:
"""Get a :class:`DataParallelTrainer` instance.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.num_gpus_per_worker == 1
self.launcher = DistributedTorchLauncher(
self.num_cpus_per_worker,
self.num_workers,
)

def create_trainable(self) -> Callable:
"""Get a trainable task.

Returns:
DataParallelTrainer: Trainer to optimize hyperparameter.
Callable: Callable object to optimize hyperparameter.
"""
return with_resources(self.context_aware_run, self.launcher.resources)

def context_aware_run(self, searched_cfg: dict):
"""Gather and refine the information received by users and Ray.tune to
execute the objective task.

Args:
searched_cfg (Dict): The searched configuration.
"""

context_manager = ContextManager(self.rewriters)
context = dict(
args=deepcopy(self.args),
searched_cfg=deepcopy(ImmutableContainer.decouple(searched_cfg)),
)
return context_manager(self.dist_run)(**context)

return DataParallelTrainer(
self.context_aware_run,
backend_config=MMBackendConfig(),
scaling_config=ScalingConfig(
trainer_resources=dict(CPU=self.num_cpus_per_worker),
num_workers=self.num_workers,
use_gpu=torch.cuda.is_available()))
def dist_run(self, *args, **kwargs):
self.launcher.launch(self.run, *args, **kwargs)
8 changes: 7 additions & 1 deletion siatune/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Copyright (c) SI-Analytics. All rights reserved.
from .context import ContextManager
from .hooks import * # noqa F403
from .launch import DistributedTorchLauncher
from .rewriters import REWRITERS, build_rewriter

__all__ = ['ContextManager', 'REWRITERS', 'build_rewriter']
__all__ = [
'ContextManager',
'REWRITERS',
'build_rewriter',
'DistributedTorchLauncher',
]
5 changes: 2 additions & 3 deletions siatune/core/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) SI-Analytics. All rights reserved.
from .checkpoint import RayCheckpointHook
from .reporter import RayTuneLoggerHook
from .reporter import RayTuneReporterHook

__all__ = ['RayCheckpointHook', 'RayTuneLoggerHook']
__all__ = ['RayTuneReporterHook']
Loading