[Feature] Support using gradient checkpointing in FSDP #1382

Merged · 8 commits · Oct 9, 2023
Changes from all commits
45 changes: 44 additions & 1 deletion mmengine/_strategy/fsdp.py
@@ -5,6 +5,7 @@
import os.path as osp
import time
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch.nn as nn
@@ -25,7 +26,7 @@
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
-from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
+from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
@@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
:meth:`setup_env`. Defaults to None.
- log_kwargs (dict, optional): Logger config passed in
:meth:`build_logger`. Defaults to None.
activation_checkpointing (dict, optional): Config dict for gradient
    checkpointing. Defaults to None.

Examples:
>>> activation_checkpointing = dict(check_fn='CustomCheckFn')
>>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))


The ``check_fn`` field should behave consistently with the
``auto_wrap_policy`` defined in ``model_wrapper``; the remaining
fields are passed to ``apply_activation_checkpointing``.

`New in version 0.9.0.`

.. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
""" # noqa: E501
@@ -100,13 +114,15 @@ def __init__(self,
model_wrapper: Optional[dict] = None,
skip_init_weights=False,
state_dict_cfg: Union[str, dict] = 'local',
activation_checkpointing: Optional[dict] = None,
**kwargs):
super().__init__(model_wrapper=model_wrapper, **kwargs)
self._init_state_dict_cfg(state_dict_cfg)
if not isinstance(skip_init_weights, bool):
raise TypeError('skip_init_weights must be a boolean, but got '
f'{type(skip_init_weights)}')
self.skip_init_weights = skip_init_weights
self.activation_checkpointing = activation_checkpointing

def _wrap_model(self, model: nn.Module) -> None:
"""Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@@ -119,6 +135,12 @@ def _wrap_model(self, model: nn.Module) -> None:
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
or subclass of ``FullyShardedDataParallel``.
"""
try:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
apply_activation_checkpointing # noqa: E501
except ImportError:
apply_activation_checkpointing = None

for module in model.modules():
if isinstance(module, BaseDataPreprocessor):
module.to(get_device())
@@ -138,6 +160,27 @@ def _wrap_model(self, model: nn.Module) -> None:
model.set_state_dict_type(model, self.state_dict_type,
self.state_dict_config,
self.optim_state_dict_config)

        if self.activation_checkpointing is not None:
            if apply_activation_checkpointing is None:
                raise RuntimeError(
                    '`activation_checkpointing` is not supported by the '
                    'current PyTorch version, please upgrade to PyTorch '
                    '2.0 or 2.1 to use `activation_checkpointing`.')
            cfg = copy.deepcopy(self.activation_checkpointing)
            with FUNCTIONS.switch_scope_and_registry(None):
                check_fn = cfg.pop('check_fn')
                if isinstance(check_fn, str):
                    check_fn = FUNCTIONS.get(check_fn)
                elif isinstance(check_fn, dict):
                    fn_type = check_fn.pop('type')
                    if isinstance(fn_type, str):
                        fn_type = FUNCTIONS.get(fn_type)
                    check_fn = partial(fn_type, **check_fn)

                if not callable(check_fn):
                    raise TypeError('`check_fn` must be a callable function')
                apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
return model

def _is_full_state_dict(self):
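As an illustration of the resolution logic in ``_wrap_model`` above, the following sketch (an assumption, not code from this PR) shows how a dict-style ``check_fn`` becomes a plain predicate through a FUNCTIONS lookup plus ``functools.partial``. The function name and threshold are hypothetical.

    # Sketch only: mirrors the check_fn resolution in FSDPStrategy._wrap_model.
    from functools import partial

    import torch.nn as nn
    from mmengine.registry import FUNCTIONS


    @FUNCTIONS.register_module()
    def checkpoint_wide_linear(submodule: nn.Module, min_features: int = 1024) -> bool:
        # Hypothetical check_fn: checkpoint only sufficiently wide Linear layers.
        return isinstance(submodule, nn.Linear) and submodule.in_features >= min_features


    cfg = dict(check_fn=dict(type='checkpoint_wide_linear', min_features=2048))
    check_fn = cfg.pop('check_fn')
    fn_type = FUNCTIONS.get(check_fn.pop('type'))
    check_fn = partial(fn_type, **check_fn)  # remaining keys bind to the function

    assert check_fn(nn.Linear(4096, 16)) is True
    assert check_fn(nn.Linear(512, 16)) is False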
90 changes: 46 additions & 44 deletions mmengine/model/wrappers/fully_sharded_distributed.py
@@ -146,51 +146,53 @@ def __init__(
                '`cpu_offload` should be `None`, `bool`'
                f'or `CPUOffload`, but has type {type(cpu_offload)}')

        with FUNCTIONS.switch_scope_and_registry(None):
            if isinstance(auto_wrap_policy, str):
                auto_wrap_policy = FUNCTIONS.get(  # type: ignore
                    auto_wrap_policy)
                if auto_wrap_policy is None:
                    raise ValueError('`auto_wrap_policy` is not registered!')
            elif isinstance(auto_wrap_policy, dict):
                policy = auto_wrap_policy.pop('type')
                if isinstance(policy, str):
                    policy = FUNCTIONS.get(policy)  # type: ignore
                if policy is None:
                    raise ValueError('`auto_wrap_policy` is not registered!')
                auto_wrap_policy = partial(policy, **auto_wrap_policy)

            if not (auto_wrap_policy is None
                    or callable(auto_wrap_policy)):  # type: ignore
                raise TypeError('`auto_wrap_policy` should be a str, a '
                                'callable, a dict or None, but has type '
                                f'{type(auto_wrap_policy)}')

            if isinstance(backward_prefetch, str):
                backward_prefetch = BackwardPrefetch[backward_prefetch]
            if not (isinstance(backward_prefetch, BackwardPrefetch)
                    or backward_prefetch is None):
                raise TypeError(
                    '`backward_prefetch` should be `None`, string of '
                    '"BACKWARD_PRE" and "BACKWARD_POST", or '
                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}'  # noqa: E501
                )

            if isinstance(param_init_fn, str):
                param_init_fn = FUNCTIONS.get(  # type: ignore
                    param_init_fn)
                if param_init_fn is None:
                    raise ValueError('`param_init_fn` is not registered!')
            elif isinstance(param_init_fn, dict):
                init_fn = param_init_fn.pop('type')
                if isinstance(init_fn, str):
                    init_fn = FUNCTIONS.get(init_fn)  # type: ignore
                if init_fn is None:
                    raise ValueError('`param_init_fn` is not registered!')
                param_init_fn = partial(init_fn, **param_init_fn)

            if not (callable(param_init_fn) or param_init_fn is None):
                raise TypeError('`param_init_fn` should be a str, a '
                                'callable, a dict or None, but has type '
                                f'{type(param_init_fn)}')
def parse_dtype(dtype):
if dtype is None:
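A companion sketch (again an assumption, not from this diff) exercising the same registry-plus-partial pattern for ``auto_wrap_policy``. ``wrap_by_param_count`` is a hypothetical policy; real FSDP policies such as ``size_based_auto_wrap_policy`` are called with the same ``(module, recurse, numel)`` style of arguments.

    # Sketch only: how a dict-style `auto_wrap_policy` is resolved, mirroring
    # the __init__ logic above.
    from functools import partial

    import torch.nn as nn
    from mmengine.registry import FUNCTIONS


    @FUNCTIONS.register_module()
    def wrap_by_param_count(module: nn.Module, recurse: bool,
                            nonwrapped_numel: int,
                            min_numel: int = 1_000_000) -> bool:
        # Hypothetical policy: always recurse, wrap modules holding many params.
        if recurse:
            return True
        return nonwrapped_numel >= min_numel


    auto_wrap_policy = dict(type='wrap_by_param_count', min_numel=2_000_000)
    policy = FUNCTIONS.get(auto_wrap_policy.pop('type'))
    auto_wrap_policy = partial(policy, **auto_wrap_policy)  # as in __init__ above

    # The resolved callable is what MMFullyShardedDataParallel hands on to
    # torch's FullyShardedDataParallel as `auto_wrap_policy`.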