Support user-defined parallelization in the LightningModule #11922
Comments
Should the model be wrapped instead in […]?

```python
import os

import torch
import torch.distributed as dist

rank = int(os.environ["LOCAL_RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"

if not dist.is_initialized():
    dist.init_process_group(backend=backend)
```

The code above (wrapping in […]) […]. Should the creation of the process group be completely customizable, to support strategies like DeepSpeed or Bagua?
> Initializing the model in […]

The LightningModule code is determined by the user, so they would have to determine whether they need their code to work with DDP spawn or not. I think it's going to be hard to support all of custom parallelization + checkpoint loading + spawning simultaneously.
Users can already initialize the process group themselves if they create the processes externally. The only instances where Lightning has to create the process group are for spawn and subprocess script launch. A lighter form of customization is being worked on in #11745. Note: I don't want to make this an issue about supporting spawning vs. not; it's only to point out that relying on the Lightning Trainer to do the process creation imposes restrictions on how users author their training programs. For the use cases I've seen, especially ones that would benefit from this strategy, we have been using torchx to great effect.
> The […]

This is a fair point. However, at the very least the docs for this feature should present both options and explain their differences.
🚀 Feature
Support a manual parallelization option
Motivation
Now that the Strategy refactor is complete, a step change in research flexibility is unlocked. Users no longer have to override two different classes (TrainingTypePlugin & Accelerator) to implement custom parallelism handling, which widens the set of use cases Lightning can support as a training loop framework.
There are users who have highly customized parallelization requirements.
For instance:
- Model parallelism variants: Users of plain PyTorch can partition parameters across devices. However, Lightning so far has not allowed this: distributed training forces some sort of data-parallel variant. Lightning natively doesn't support open-ended model parallelism because the nn.Modules inside the LightningModule are a black box to the Trainer (this is necessary for the Trainer's generality). Example use case: recommendation models often have large embedding tables that cannot fit on a single device. The sharding of these tables is highly customized, and there is no single-device alternative for training these models, so the modeling logic is written in a way that assumes distributed training.
- Some users may want to combine different module wrappers, for example wrapping parts of their model in DDP while partitioning other parts with techniques like FSDP.
- Variable batch sizes that require normalizing the reduced gradient by the number of samples across all batches used in that step, rather than the DDP default of normalizing by world size (a sketch of this follows the list).
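To make the last point concrete, here is a minimal sketch, not taken from the issue, of weighting the gradient all-reduce by per-rank sample counts; the helper name is hypothetical, and it assumes each rank's `param.grad` already holds the mean over its local batch:

```python
import torch
import torch.distributed as dist


def all_reduce_grads_by_sample_count(model: torch.nn.Module, local_batch_size: int) -> None:
    """Average gradients across ranks weighted by each rank's batch size.

    Assumes every rank's ``param.grad`` is already the mean over its local batch.
    """
    device = next(model.parameters()).device
    n_local = torch.tensor([float(local_batch_size)], device=device)
    n_total = n_local.clone()
    # Total number of samples contributing to this optimization step.
    dist.all_reduce(n_total, op=dist.ReduceOp.SUM)

    for param in model.parameters():
        if param.grad is None:
            continue
        # Scale the local mean gradient by its share of the global batch ...
        param.grad.mul_(n_local / n_total)
        # ... then sum across ranks to obtain the mean over all samples.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
```

This yields the per-sample mean gradient over the global batch, instead of DDP's uniform average over ranks.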
Rather than require each of these users to learn all about the Strategy codebase to be able to customize this, I propose a "manual" parallel strategy which delegates this logic back to the LightningModule.
This way, all of the modeling logic sits in one place. This is easier for researchers to get started without needing to learn another abstraction. If these techniques pan out to be more general, they can be abstracted out to fit into the Strategy interface, which makes them shareable across projects.
In this setting, the user assumes responsibility for the following: […]
The Trainer/Strategy will still handle: […]
This is intended for power users who know exactly what they're doing.
The terminology manual parallel follows the precedent of manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization
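For comparison, manual optimization already hands this kind of control back to the LightningModule via the documented `automatic_optimization = False` / `self.manual_backward(...)` hooks; a manual parallel strategy would do the same for wrapping and device placement. A bare-bones example of that precedent (the toy model here is made up):

```python
import torch
import pytorch_lightning as pl


class ManualOptimizationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Opt out of Lightning's automatic optimization loop.
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).mean()
        opt.zero_grad()
        self.manual_backward(loss)  # replaces loss.backward()
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```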
This is also the motivation for the PRs removing dependencies on `LightningModule.device` within the Trainer:
- `root_device` in XLAStatsMonitor callback #11749
- `root_device` in DeviceStatsMonitor callback #11748
- `root_device` instead of `lightning_module.device` #11734

`LightningModule.device` is not properly defined for use cases where the LightningModule's parameters sit on multiple devices. This proposal aims to remove the requirement for users of these LightningModules to call `LightningModule.to(...)` before executing a Trainer function.
Pitch
Define a new strategy class like this:
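As a rough illustration only, a minimal sketch of what such a strategy could look like against the post-refactor `Strategy` interface (around Lightning 1.6); the class name, the no-op `model_to_device`, and the plain `torch.distributed` collectives are illustrative choices, not the exact code from the proposal:

```python
import os
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
from pytorch_lightning.strategies import Strategy


class ManualParallelStrategy(Strategy):
    """Sketch of a strategy that leaves all parallelization to the LightningModule.

    The Trainer still gets the collectives it relies on (barrier, broadcast, ...),
    but the strategy never wraps or moves the user's modules.
    """

    @property
    def root_device(self) -> torch.device:
        # One device per process, derived from the launcher-provided local rank.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if torch.cuda.is_available():
            return torch.device(f"cuda:{local_rank}")
        return torch.device("cpu")

    @property
    def is_global_zero(self) -> bool:
        return dist.get_rank() == 0 if dist.is_initialized() else True

    def model_to_device(self) -> None:
        # Intentionally a no-op: the LightningModule owns device placement.
        pass

    def barrier(self, name: Optional[str] = None) -> None:
        if dist.is_initialized():
            dist.barrier()

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        if not dist.is_initialized():
            return obj
        obj_list = [obj]
        dist.broadcast_object_list(obj_list, src=src)
        return obj_list[0]

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        if not dist.is_initialized():
            return tensor.unsqueeze(0)
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor, group=group)
        return torch.stack(gathered)

    def reduce(self, tensor: Union[torch.Tensor, Any], group: Optional[Any] = None,
               reduce_op: Union[str, Any] = "mean") -> Union[torch.Tensor, Any]:
        if isinstance(tensor, torch.Tensor) and dist.is_initialized():
            dist.all_reduce(tensor, group=group)
            if reduce_op in ("mean", "avg"):
                tensor = tensor / dist.get_world_size()
        return tensor
```

Assuming the Trainer accepts strategy instances (it does since the refactor), this could be used as `Trainer(strategy=ManualParallelStrategy())`, with the LightningModule doing its own wrapping and placement.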
Example of a LightningModule which is inherently distributed aware
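Again purely illustrative (the module, its row-wise sharding scheme, and the toy training step below are hypothetical, and assume the process group exists before the module is constructed):

```python
import torch
import torch.distributed as dist
import pytorch_lightning as pl
from torch.nn.parallel import DistributedDataParallel as DDP


class RowShardedRecModel(pl.LightningModule):
    """Hypothetical LightningModule that manages its own parallelism."""

    def __init__(self, num_embeddings: int = 1_000_000, dim: int = 64):
        super().__init__()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        else:
            device = torch.device("cpu")

        # Row-wise shard of a large embedding table: each rank owns one slice,
        # so the full table never has to fit on a single device.
        self.shard_size = num_embeddings // world_size
        self.embedding_shard = torch.nn.Embedding(self.shard_size, dim).to(device)

        # The small dense part is replicated; the module wraps it in DDP itself.
        dense = torch.nn.Linear(dim, 1).to(device)
        self.dense = DDP(dense, device_ids=[device.index] if device.type == "cuda" else None)

    def training_step(self, batch, batch_idx):
        ids, labels = batch
        ids = ids.to(self.embedding_shard.weight.device)
        labels = labels.to(self.embedding_shard.weight.device)
        # Toy lookup against the local shard only; a real model would route ids
        # to the owning rank (e.g. with all_to_all) before the lookup.
        emb = self.embedding_shard(ids % self.shard_size)
        logits = self.dense(emb).squeeze(-1)
        return torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```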
Alternatives
Additional context
The idea of manual parallelization was also raised here: #8722 (comment)
If you enjoy Lightning, check out our other projects! ⚡
Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @Borda @awaelchli @rohitgr7 @akihironitta