Add DeepSpeed Baseline #1
base: main
Conversation
# load DeepSpeed configuration file
self.ds_config = self.cfg.get('ds_config', None)
assert self.ds_config is not None, 'ds_config should be specified.'
Load the DeepSpeed configuration (`ds_config`).
self.ds_config = self.cfg.get('ds_config', None)
assert self.ds_config is not None, 'ds_config should be specified.'

self.check_ds_config(self.ds_config)
Some `ds_config` options should be ignored since MMEngine already supports them.
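For illustration, a minimal sketch of what `check_ds_config` could do, assuming a hand-picked list of keys (`optimizer`, `scheduler`, `gradient_accumulation_steps`) that overlap with features MMEngine already provides; the exact key list is an assumption, not taken from this PR:

import warnings

# Keys assumed to conflict with features MMEngine already provides;
# the list below is illustrative only.
_IGNORED_DS_KEYS = ('optimizer', 'scheduler', 'gradient_accumulation_steps')


def check_ds_config(ds_config: dict) -> dict:
    """Warn about and drop ds_config options that MMEngine already handles."""
    cleaned = dict(ds_config)
    for key in _IGNORED_DS_KEYS:
        if key in cleaned:
            warnings.warn(f'`{key}` in ds_config is ignored because '
                          'MMEngine already supports it.')
            cleaned.pop(key)
    return cleaned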
# initialize the model weights before wrapping it with deepspeed
self._weights_initialized = False
self._init_model_weights()
Model weights should be initialized before wrapping the model with `DeepSpeedEngine`.
Is there any documentation on the reason?
I couldn't find any documentation yet, but we couldn't reproduce the performance before changing it (the results were similar to training from scratch).
We could reproduce the performance after fixing it.
Maybe there is another way to handle this.
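For reference, a sketch of the ordering being described here, assuming an MMEngine-style model exposing `init_weights()`; apart from the names taken from the diff above (`_weights_initialized`, `deepspeed.initialize`), the details are illustrative:

def _init_model_weights(self):
    """Initialize weights exactly once, before DeepSpeed wraps the model."""
    if not self._weights_initialized and hasattr(self.model, 'init_weights'):
        self.model.init_weights()
        self._weights_initialized = True

# Intended call order: weights are initialized first, and only then is the
# model handed to deepspeed.initialize(), which partitions/wraps parameters.
# self._init_model_weights()
# self.model, optimizer, _, _ = deepspeed.initialize(...)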
if model_wrapper_cfg is None:
    # Model will be wrapped in `deepspeed.initialize`.
    pass
Wrapping the model with `DeepSpeedEngine` is done inside `deepspeed.initialize`. We may be able to factor this logic out by exposing `DeepSpeedEngine` as a `model_wrapper`.
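A rough sketch of that idea, assuming a hypothetical wrapper registered in MMEngine's `MODEL_WRAPPERS` registry that simply delegates to `deepspeed.initialize`; the class name and constructor arguments are not part of this PR:

import deepspeed
import torch.nn as nn
from mmengine.registry import MODEL_WRAPPERS


@MODEL_WRAPPERS.register_module()
class DeepSpeedEngineWrapper(nn.Module):
    """Hypothetical model wrapper that moves `deepspeed.initialize` out of
    the runner."""

    def __init__(self, module: nn.Module, optimizer, ds_config: dict):
        super().__init__()
        self.engine, self.optimizer, _, _ = deepspeed.initialize(
            model=module,
            optimizer=optimizer,
            model_parameters=module.parameters(),
            config=ds_config)

    def forward(self, *args, **kwargs):
        return self.engine(*args, **kwargs)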
# TODO: Model Sequentializing
# sequential_model = convert_to_sequential_model(model)
# model = PipelineModule(
#     layers=[model], num_stages=int(os.environ['WORLD_SIZE']))
raise NotImplementedError(
    'Pipeline Parallel is not implemented yet.')
Pipeline parallelism (PP) cannot be supported yet.
def consolidate_state_dict(self,
                           state_dict: Dict[str, Any],
                           to: int = 0) -> None:
    r"""Consolidate a list of ``state_dict`` s (one per rank) on the target
    rank.

    Arguments:
        to (int): the rank that receives the optimizer states (default: 0).

    Raises:
        RuntimeError: if ``overlap_with_ddp=True`` and this method is
            called before this :class:`ZeroRedundancyOptimizer` instance
            has been fully initialized, which happens once
            :class:`DistributedDataParallel` gradient buckets have been
            rebuilt.

    .. warning:: This needs to be called on all ranks.
    """
    from torch.distributed.optim.zero_redundancy_optimizer import (
        _broadcast_object, _recursive_copy_to_device)
Does `deepspeed` provide APIs to deal with this? Borrowing from another library may cause incompatibility in the future.
I'll try to check it!
I checked the above link. It provides similar functions but has limitations:
- It only supports ZeRO3.
- It only supports the model state.

I think this `consolidate` logic is general and can be used for other similar purposes. How about adding it to `mmengine/dist/utils`?
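If this went into `mmengine/dist/utils`, one option is to rely only on public `torch.distributed` collectives instead of the private ZeRO helpers; a minimal sketch under that assumption (the function name is hypothetical):

from typing import Any, Dict, List, Optional

import torch.distributed as dist


def gather_state_dicts(state_dict: Dict[str, Any],
                       dst: int = 0) -> Optional[List[Dict[str, Any]]]:
    """Gather one ``state_dict`` per rank onto ``dst``.

    Must be called on all ranks; returns the list of per-rank state dicts
    on ``dst`` and ``None`` on the other ranks.
    """
    if not dist.is_available() or not dist.is_initialized():
        return [state_dict]
    gathered = ([None] * dist.get_world_size()
                if dist.get_rank() == dst else None)
    dist.gather_object(state_dict, gathered, dst=dst)
    return gathered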
# initialize DeepSpeed Engine
self.model, optimizer, _, _ = deepspeed.initialize(
    model=self.model,
    optimizer=self.optim_wrapper.optimizer,
    model_parameters=self.model.parameters(),
    config=self.ds_config)
self.optim_wrapper.optimizer = optimizer
This is the problematic line.
Since the optimizer is wrapped in `DeepSpeedEngine`, update the optimizer held by `optim_wrapper`.
I suspect this update operation might be buggy in some special situations... Maybe it's better to build the optimizer first, add it to a `dict`, and then call `build_optim_wrapper`?
I think your idea is desirable. I'll check it.
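A sketch of the suggested ordering, reusing the `deepspeed.initialize` call from this PR and MMEngine's `OPTIM_WRAPPERS` registry; the helper name and the placeholder SGD optimizer are illustrative:

import deepspeed
from mmengine.registry import OPTIM_WRAPPERS
from torch.optim import SGD


def build_deepspeed_optim_wrapper(model, ds_config: dict):
    """Hypothetical helper: let DeepSpeed wrap the bare optimizer first,
    then build the optim wrapper around the already-wrapped optimizer,
    so ``optim_wrapper.optimizer`` never needs to be patched afterwards."""
    optimizer = SGD(model.parameters(), lr=0.01)  # placeholder optimizer
    engine, ds_optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        model_parameters=model.parameters(),
        config=ds_config)
    optim_wrapper = OPTIM_WRAPPERS.build(
        dict(type='OptimWrapper', optimizer=ds_optimizer))
    return engine, optim_wrapper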
if self.model.zero_optimization_partition_weights():
    device = get_device()
    checkpoint = _load_checkpoint(filename, map_location=device)
When using ZeRO, first load the `state_dict` without changing it; e.g. the `module.` prefix must not be deleted in order to resume ZeRO training.
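A minimal sketch of what "without changing it" means in practice, assuming a checkpoint whose `state_dict` keys carry the `module.` prefix that `DeepSpeedEngine` adds; the helper name is hypothetical:

from collections import OrderedDict


def maybe_strip_module_prefix(state_dict, resume_zero: bool = False):
    """Drop the ``module.`` prefix only when NOT resuming a ZeRO run.

    DeepSpeedEngine holds the user model as ``self.module``, so its
    partitioned parameters are addressed with the prefixed names;
    stripping the prefix before ``load_state_dict`` would break resuming.
    """
    if resume_zero:
        # keep keys untouched, e.g. 'module.backbone.conv1.weight'
        return state_dict
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items())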
if self.model.zero_optimization_partition_weights():
    # Prepare for checkpoint save by
    # ensuring all parameters are partitioned
    self.model.optimizer.checkpoint_event_prologue()

checkpoint = {
    'meta': meta,
    'message_hub': self.message_hub.state_dict(),
}
# save optimizer state dict to checkpoint
if save_optimizer:
    if not self.model.zero_optimization():
        checkpoint['optimizer'] = self.optim_wrapper.state_dict()
    else:
        self.consolidate_state_dict(self.optim_wrapper.state_dict())
        # Only the main process needs to load the optimizer's state.
        optim_state = self.get_zero_state_dict()
        checkpoint['optimizer'] = optim_state

# model state is stored after pulling optimizer state to handle ZeRO 3.
checkpoint['state_dict'] = weights_to_cpu(self.get_state_dict(model))

if self.model.zero_optimization_partition_weights():
    self.model.optimizer.checkpoint_event_epilogue()
if self.model.zero_optimization():
    self.optim_wrapper.load_state_dict(  # type: ignore
        checkpoint['optimizer'],
        load_from_fp32_weights=self.model.zero_load_from_fp32_weights())
This tries to stay compatible with https://github.com/microsoft/DeepSpeed/blob/211055216792cbb52ab6d355f698c194f9c55efb/deepspeed/runtime/zero/stage_1_and_2.py#L2228
if self.model.zero_optimization_partition_weights():
    optim_state = self.get_zero_state_dict()
    fp32_flat_groups = [
        torch.cat(optim_state[i][FP32_FLAT_GROUPS])
        for i in range(len(optim_state))
    ]
    param_shapes = self.model._get_zero_param_shapes()[0]
    param_shapes = OrderedDict(
        {'module.' + k: v
         for k, v in param_shapes.items()})

    model_state = _get_fp32_state_dict_from_zero3_checkpoint(
        world_size=self._world_size,
        param_shapes=[param_shapes],
        fp32_flat_groups=fp32_flat_groups,
        buffers={})
There should be special logic for ZeRO 3.
https://github.com/microsoft/DeepSpeed/blob/211055216792cbb52ab6d355f698c194f9c55efb/deepspeed/utils/zero_to_fp32.py#L100
In the case of ZeRO3, the `model_state` is saved inside the `optim_state`.
Check deepspeedai/DeepSpeed#2413
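For reference, DeepSpeed also ships a public helper, `get_fp32_state_dict_from_zero_checkpoint`, which consolidates a full fp32 `state_dict` from a saved ZeRO checkpoint directory. It works on an on-disk checkpoint rather than the in-memory `optim_state` used here, so it is an alternative path rather than a drop-in replacement; the directory below is a placeholder:

from deepspeed.utils.zero_to_fp32 import \
    get_fp32_state_dict_from_zero_checkpoint

# Rebuilds a single fp32 state_dict from the partitioned ZeRO shards saved
# under the checkpoint directory ('work_dirs/latest' is a placeholder).
state_dict = get_fp32_state_dict_from_zero_checkpoint('work_dirs/latest')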
# model state is stored after pulling optimizer state to handle ZeRO 3.
checkpoint['state_dict'] = weights_to_cpu(self.get_state_dict(model))
This ordering is important.
if not self.model.zero_optimization():
    checkpoint['optimizer'] = self.optim_wrapper.state_dict()
else:
    self.consolidate_state_dict(self.optim_wrapper.state_dict())
To save the `optimizer_state`, consolidate the `optim_state` from all ranks.
def load_state_dict(self, state_dict: dict, **kwargs) -> None:
    """A wrapper of ``Optimizer.load_state_dict``. load the state dict of
The `**kwargs` argument is newly added.
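A simplified stand-in showing why the extra keyword arguments are needed: DeepSpeed's ZeRO optimizers accept keywords such as `load_from_fp32_weights` (see the call above), which a plain `Optimizer.load_state_dict` would not; everything except that keyword is illustrative:

from typing import Any, Dict


class DeepSpeedOptimWrapperSketch:
    """Simplified stand-in for the PR's optim wrapper."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def load_state_dict(self, state_dict: Dict[str, Any], **kwargs) -> None:
        # Forward extra keywords (e.g. load_from_fp32_weights) to the
        # wrapped DeepSpeed optimizer.
        self.optimizer.load_state_dict(state_dict, **kwargs)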
# Set logging level to remove duplicate training log from DeepSpeed
deepspeed_logger = logging.getLogger('DeepSpeed')
deepspeed_logger.setLevel(logging.WARNING)
Suppress `deepspeed`'s duplicate training logs.
def inject_basemodel_methods(self):
    """inject methods from ``BaseModel`` into ``DeepSpeedEngine`` to make
    ``DeepSpeedEngine`` support the ``train_step`` method appropriately.

    Without injecting, ``DeepSpeedOptimWrapper`` tries ``backward`` from
    ``BaseModel``, which should be in ``DeepSpeedEngine``.
    """

    def _train_step(self, data: Union[dict, tuple, list],
                    optim_wrapper) -> Dict[str, torch.Tensor]:
        with optim_wrapper.optim_context(self):
            data = self.data_preprocessor(data, True)
            losses = self._run_forward(data, mode='loss')  # type: ignore
        parsed_losses, log_vars = self.parse_losses(losses)  # type: ignore
        optim_wrapper.update_params(parsed_losses)
        return log_vars

    self.model.train_step = types.MethodType(_train_step, self.model)
Can we delete this by wrapping `DeepSpeedEngine` with `BaseModel`?
Really awesome work! Actually, I haven't finished my review yet, especially the "messy" logic related to checkpoint saving/loading. I'll refer to the DeepSpeed docs & example code and your comments later. Stay connected.
deepspeed_logger = logging.getLogger('DeepSpeed')
deepspeed_logger.setLevel(logging.WARNING)

self.inject_basemodel_methods()
There are some models that override the `train_step`, `val_step`, and `test_step` methods. Can we support them in `deepspeed_runner`?
Can I ask you to provide some links to those examples?
One of the difficulties in handling this is that `DeepSpeedEngine` fails to get some attributes from `self.module`. https://github.com/microsoft/DeepSpeed/blob/90ae6884424232870154b49967c3e61f0db550d6/deepspeed/runtime/engine.py#L461
I think it will be difficult to support them in the current implementation. Maybe we have to find another way to handle this.
Some low-level tasks need to rewrite `train_step`, such as GANs in mmediting. It is indeed very difficult to support them, so I think the current implementation is acceptable.
@contextmanager
def optim_context(self, model: nn.Module):
    """A Context for gradient accumulation and automatic mix precision
    training.

    Compared to the original method, this saves model information as
    a member variable in order to use in the training step.

    Args:
        model (nn.Module): The training model.
    """
    # During gradient accumulation process, the gradient synchronize
    # should only happen before updating parameters.
    self.model = model
    yield super().optim_context(model)
Seems a little tricky. As you've mentioned, maybe we should provide a `wrap_model_and_optimizer` API.
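A rough sketch of what such a `wrap_model_and_optimizer` step could look like on the runner side, so the optim wrapper never has to stash the model inside `optim_context`; the method name and the `optim_wrapper.model` attribute are assumptions taken from this comment, not an existing MMEngine API:

import deepspeed


def wrap_model_and_optimizer(self):
    """Hypothetical runner method: wrap model and optimizer in one place."""
    engine, optimizer, _, _ = deepspeed.initialize(
        model=self.model,
        optimizer=self.optim_wrapper.optimizer,
        model_parameters=self.model.parameters(),
        config=self.ds_config)
    self.model = engine
    self.optim_wrapper.optimizer = optimizer
    # Give the optim wrapper a handle to the engine so update_params() can
    # call engine.backward()/engine.step() without optim_context tricks.
    self.optim_wrapper.model = engine
    return engine, self.optim_wrapper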
This PR is for the discussion open-mmlab#749.
Co-authored-by: Saeyeol Lee [email protected]
Co-authored-by: Donggeun Yu [email protected]
Co-authored-by: Junhwa Song [email protected]
Co-authored-by: Younghwan Na [email protected]
Signed-off-by: Hakjin Lee [email protected]
Signed-off-by: Saeyeol Lee [email protected]
Signed-off-by: Donggeun Yu [email protected]
Signed-off-by: Junhwa Song [email protected]
Signed-off-by: Younghwan Na [email protected]