Add DeepSpeed Baseline #1

Open · wants to merge 1 commit into main

Conversation

nijkah (Owner) commented Nov 22, 2022

This PR is for the discussion in open-mmlab#749.

Co-authored-by: Saeyeol Lee [email protected]
Co-authored-by: Donggeun Yu [email protected]
Co-authored-by: Junhwa Song [email protected]
Co-authored-by: Younghwan Na [email protected]

Signed-off-by: Hakjin Lee [email protected]
Signed-off-by: Saeyeol Lee [email protected]
Signed-off-by: Donggeun Yu [email protected]
Signed-off-by: Junhwa Song [email protected]
Signed-off-by: Younghwan Na [email protected]

  • Reproduce the performance
  • Reproduce the performance by resuming
  • FP16

Comment on lines +231 to +233
# load DeepSpeed configuration file
self.ds_config = self.cfg.get('ds_config', None)
assert self.ds_config is not None, 'ds_config should be specified.'
Owner Author

load deepspeed_config

self.ds_config = self.cfg.get('ds_config', None)
assert self.ds_config is not None, 'ds_config should be specified.'

self.check_ds_config(self.ds_config)
Owner Author

Some ds_config options should be ignored since MMEngine already supports them.
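
For illustration, a rough sketch of what such a filtering check could look like; the specific keys listed here (gradient accumulation, clipping, scheduler) are assumptions for the example, not this PR's actual list:

```python
import warnings


def check_ds_config(self, ds_config: dict) -> None:
    """Drop ds_config options that MMEngine already manages (illustrative)."""
    # Assumed conflicting keys; the real implementation may differ.
    for key in ('gradient_accumulation_steps', 'gradient_clipping',
                'scheduler'):
        if key in ds_config:
            warnings.warn(f'`{key}` in ds_config is ignored because '
                          'MMEngine already handles it.')
            ds_config.pop(key)
```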

Comment on lines +243 to +245
# initialize the model weights before wrapping it with deepspeed
self._weights_initialized = False
self._init_model_weights()
nijkah (Owner Author) · Nov 22, 2022

Model weights should be initialized before wrapping the model with DeepSpeedEngine.

Collaborator

Is there any documentation on the reason?

nijkah (Owner Author) · Nov 24, 2022

I couldn't find any documentation yet, but we couldn't reproduce the performance before this change (the results were similar to training from scratch). After fixing it, we could reproduce the performance.

Maybe there is another way to handle this.
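
For reference, the ordering being discussed, as a minimal sketch (build_model, cfg, and ds_config are placeholders, not this PR's code):

```python
import deepspeed

model = build_model(cfg.model)  # plain nn.Module; hypothetical helper
model.init_weights()            # initialize weights BEFORE DeepSpeed wraps the model

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config)
```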

Comment on lines +322 to +324
if model_wrapper_cfg is None:
# Model will be wrapped in `deepspeed.initialize`.
pass
nijkah (Owner Author) · Nov 22, 2022

Wrapping the model with DeepSpeedEngine is done in deepspeed.initialize.
We may be able to factor this logic out by wrapping DeepSpeedEngine as a model_wrapper.

Comment on lines +330 to +336

# TODO: Model Sequentializing
# sequential_model = convert_to_sequential_model(model)
# model = PipelineModule(
# layers=[model], num_stages=int(os.environ['WORLD_SIZE']))
raise NotImplementedError(
'Pipeline Parallel is not implemented yet.')
Owner Author

Pipeline parallelism (PP) cannot be supported yet.
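
For context, a speculative sketch of what the commented-out TODO above might eventually become; convert_to_sequential_model does not exist and is purely hypothetical:

```python
import os

from deepspeed.pipe import PipelineModule


def wrap_for_pipeline_parallel(model):
    # Hypothetical: flatten the model into an ordered list of layers first.
    layers = convert_to_sequential_model(model)
    return PipelineModule(
        layers=layers, num_stages=int(os.environ['WORLD_SIZE']))
```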

Comment on lines +641 to +658
def consolidate_state_dict(self,
state_dict: Dict[str, Any],
to: int = 0) -> None:
r"""
Consolidate a list of ``state_dict`` s (one per rank) on the target
rank.
Arguments:
to (int): the rank that receives the optimizer states (default: 0).
Raises:
RuntimeError: if ``overlap_with_ddp=True`` and this method is
called before this :class:`ZeroRedundancyOptimizer` instance
has been fully initialized, which happens once
:class:`DistributedDataParallel` gradient buckets have been
rebuilt.
.. warning:: This needs to be called on all ranks.
"""
from torch.distributed.optim.zero_redundancy_optimizer import (
_broadcast_object, _recursive_copy_to_device)
Collaborator

Does DeepSpeed provide APIs to deal with this? Borrowing from another library may cause incompatibilities in the future.

Owner Author

I'll try to check it!

Owner Author

I checked the above link. It provides similar functions but has limitations:

  1. It only supports ZeRO3.
  2. It only supports the model state.

I think this consolidation logic is general and can be used for other similar purposes.
How about adding it to mmengine/dist/utils?
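
As a rough sketch of the kind of general helper this could become in mmengine/dist/utils (the function name and the use of torch.distributed.gather_object are my assumptions):

```python
from typing import Any, Dict, List, Optional

import torch.distributed as dist


def gather_state_dicts(state_dict: Dict[str, Any],
                       dst: int = 0) -> Optional[List[Dict[str, Any]]]:
    """Collect one ``state_dict`` per rank onto ``dst``.

    Must be called on all ranks; returns the list of per-rank state dicts
    on ``dst`` and ``None`` on the other ranks.
    """
    world_size = dist.get_world_size()
    gathered: Optional[List[Any]] = (
        [None] * world_size if dist.get_rank() == dst else None)
    dist.gather_object(state_dict, gathered, dst=dst)
    return gathered
```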

Comment on lines +257 to +263
# initialize DeepSpeed Engine
self.model, optimizer, _, _ = deepspeed.initialize(
model=self.model,
optimizer=self.optim_wrapper.optimizer,
model_parameters=self.model.parameters(),
config=self.ds_config)
self.optim_wrapper.optimizer = optimizer
Owner Author

This is the problematic line.

Owner Author

Since the optimizer is wrapped in DeepSpeedEngine, update the optimizer held by optim_wrapper.

Collaborator

I suspect this update operation might be buggy in some special situations... Maybe it's better to build the optimizer first, add it to the dict, and then call build_optim_wrapper?

Owner Author

I think your idea is desirable. I'll check it.
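
For illustration, a minimal sketch of the suggested flow, under the assumption that the optim wrapper can be constructed around an already-built optimizer; build_optimizer and the DeepSpeedOptimWrapper call below are placeholders, not MMEngine's actual builder API:

```python
import deepspeed

# 1. Build the raw optimizer and let DeepSpeed wrap model and optimizer together.
raw_optimizer = build_optimizer(model, cfg.optimizer)  # hypothetical helper
engine, ds_optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=raw_optimizer,
    model_parameters=model.parameters(),
    config=ds_config)

# 2. Only then build the optim wrapper around the final (DeepSpeed) optimizer,
#    instead of patching `optim_wrapper.optimizer` after the fact.
optim_wrapper = DeepSpeedOptimWrapper(optimizer=ds_optimizer)
```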

Comment on lines +410 to +412
if self.model.zero_optimization_partition_weights():
device = get_device()
checkpoint = _load_checkpoint(filename, map_location=device)
nijkah (Owner Author) · Nov 22, 2022

When using ZeRO, first load the state_dict without changing it;
e.g., module must not be deleted in order to resume ZeRO.

Comment on lines +591 to +614
if self.model.zero_optimization_partition_weights():
# Prepare for checkpoint save by
# ensuring all parameters are partitioned
self.model.optimizer.checkpoint_event_prologue()

checkpoint = {
'meta': meta,
'message_hub': self.message_hub.state_dict(),
}
# save optimizer state dict to checkpoint
if save_optimizer:
if not self.model.zero_optimization():
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
else:
self.consolidate_state_dict(self.optim_wrapper.state_dict())
# Only the main process needs to load the optimizer's state.
optim_state = self.get_zero_state_dict()
checkpoint['optimizer'] = optim_state

# model state is stored after pulling optimizer state to handle ZeRO 3.
checkpoint['state_dict'] = weights_to_cpu(self.get_state_dict(model))

if self.model.zero_optimization_partition_weights():
self.model.optimizer.checkpoint_event_epilogue()

Comment on lines +481 to +485
if self.model.zero_optimization():
self.optim_wrapper.load_state_dict( # type: ignore
checkpoint['optimizer'],
load_from_fp32_weights=self.model.
zero_load_from_fp32_weights())

Comment on lines +752 to +767
if self.model.zero_optimization_partition_weights():
optim_state = self.get_zero_state_dict()
fp32_flat_groups = [
torch.cat(optim_state[i][FP32_FLAT_GROUPS])
for i in range(len(optim_state))
]
param_shapes = self.model._get_zero_param_shapes()[0]
param_shapes = OrderedDict(
{'module.' + k: v
for k, v in param_shapes.items()})

model_state = _get_fp32_state_dict_from_zero3_checkpoint(
world_size=self._world_size,
param_shapes=[param_shapes],
fp32_flat_groups=fp32_flat_groups,
buffers={})
nijkah (Owner Author) · Nov 22, 2022

In the case of ZeRO3, the model state is saved in the optimizer state.
Check deepspeedai/DeepSpeed#2413.
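
For reference, a minimal usage sketch of DeepSpeed's zero_to_fp32 utility, which reconstructs a consolidated fp32 model state_dict from a saved ZeRO checkpoint directory (it does not recover the optimizer state; the path below is a placeholder):

```python
from deepspeed.utils.zero_to_fp32 import \
    get_fp32_state_dict_from_zero_checkpoint

# Rebuild the full-precision model weights from the partitioned ZeRO
# checkpoint files written by engine.save_checkpoint().
state_dict = get_fp32_state_dict_from_zero_checkpoint('work_dirs/latest')
```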

Comment on lines +610 to +611
# model state is stored after pulling optimizer state to handle ZeRO 3.
checkpoint['state_dict'] = weights_to_cpu(self.get_state_dict(model))
Owner Author

This is important.

if not self.model.zero_optimization():
checkpoint['optimizer'] = self.optim_wrapper.state_dict()
else:
self.consolidate_state_dict(self.optim_wrapper.state_dict())
nijkah (Owner Author) · Nov 22, 2022

To save the optimizer state, consolidate optim_state from all ranks.

Comment on lines +88 to +89
def load_state_dict(self, state_dict: dict, **kwargs) -> None:
"""A wrapper of ``Optimizer.load_state_dict``. load the state dict of
Owner Author

The **kwargs argument is newly added.
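
A minimal sketch of the intent: DeepSpeed's ZeRO optimizer accepts extra options such as load_from_fp32_weights (used in this PR's resume path) that the plain torch.optim.Optimizer.load_state_dict signature does not, so the wrapper forwards them:

```python
def load_state_dict(self, state_dict: dict, **kwargs) -> None:
    """Forward extra options (e.g. load_from_fp32_weights) to the wrapped
    DeepSpeed optimizer."""
    self.optimizer.load_state_dict(state_dict, **kwargs)
```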

Comment on lines +265 to +267
# Set logging level to remove duplicate training log from DeepSpeed
deepspeed_logger = logging.getLogger('DeepSpeed')
deepspeed_logger.setLevel(logging.WARNING)
Owner Author

Suppress DeepSpeed's duplicate training logs.

Comment on lines +370 to +387
def inject_basemodel_methods(self):
"""inject methods from ``BaseModel`` into ``DeepSpeedEngine`` to make
``DeepSpeedEngine`` support the ``train_step`` method appropriately.

Without injecting, ``DeepSpeedOptimWrapper`` tries ``backward`` from
``BaseModel``, which should be in ``DeepSpeedEngine``.
"""

def _train_step(self, data: Union[dict, tuple, list],
optim_wrapper) -> Dict[str, torch.Tensor]:
with optim_wrapper.optim_context(self):
data = self.data_preprocessor(data, True)
losses = self._run_forward(data, mode='loss') # type: ignore
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
optim_wrapper.update_params(parsed_losses)
return log_vars

self.model.train_step = types.MethodType(_train_step, self.model)
Owner Author

Can we delete this by wrapping DeepSpeedEngine with BaseModel?

C1rN09 (Collaborator) left a comment

Really awesome work! Actually, I haven't finished my review yet, especially the "messy" logic related to checkpoint saving/loading. I'll go through the DeepSpeed docs & example code and your comments later. Stay connected.

deepspeed_logger = logging.getLogger('DeepSpeed')
deepspeed_logger.setLevel(logging.WARNING)

self.inject_basemodel_methods()
Collaborator

There are some models that override the train_step, val_step, and test_step methods. Can we support them in deepspeed_runner?

nijkah (Owner Author) · Nov 24, 2022

Could you provide some links to those examples?

One of the difficulties in handling this is that DeepSpeedEngine fails to get some attributes from self.module. https://github.com/microsoft/DeepSpeed/blob/90ae6884424232870154b49967c3e61f0db550d6/deepspeed/runtime/engine.py#L461

I think it will be difficult to support them in the current implementation. Maybe we have to find another way to handle this.

Collaborator

Some low-level tasks need to rewrite train_step, such as GANs in mmediting. It is indeed very difficult to support them, so I think the current implementation is acceptable.

Comment on lines +72 to +86
@contextmanager
def optim_context(self, model: nn.Module):
"""A Context for gradient accumulation and automatic mix precision
training.

Compared to the original method, this saves model information as
a member variable in order to use in the training step.

Args:
model (nn.Module): The training model.
"""
# During gradient accumulation process, the gradient synchronize
# should only happen before updating parameters.
self.model = model
yield super().optim_context(model)
Collaborator

Seems a little tricky. As you've mentioned, maybe we should provide wrap_model_and_optimizer.
