
Support user-defined parallelization in the LightningModule #11922

Closed
ananthsub opened this issue Feb 15, 2022 · 3 comments
Labels: distributed, feature, strategy

ananthsub commented Feb 15, 2022

🚀 Feature

Support a manual parallelization option

Motivation

The completion of the Strategy refactor unlocks a step change in research flexibility. Users no longer have to override two different classes (TrainingTypePlugin and Accelerator) to implement custom parallelism handling, which widens the set of use cases Lightning can support as a training-loop framework.

Some users have highly customized parallelization requirements.

For instance:

  • Model parallelism variants: Users of plain PyTorch can partition parameters across devices, but Lightning has so far not allowed this: distributed training forces some sort of data-parallel variant. Lightning doesn't natively support open-ended model parallelism because the nn.Modules inside the LightningModule are a black box to the Trainer (which is necessary for the Trainer's generality). Example use case: recommendation models often have large embedding tables that cannot fit on a single device. The sharding of these tables is highly customized, and there is no single-device alternative for training these models, so the modeling logic is written in a way that assumes distributed training.

  • Some users may want to combine different module wrappers, for example wrapping parts of their model in DDP while partitioning other parts with techniques like FSDP.

  • Variable batch sizes that require normalizing the reduced gradient by the total number of samples contributed by all ranks in that step, rather than the DDP default of averaging over world size (see the sketch after this list).
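
As a concrete illustration of the last point, here is a minimal sketch of normalizing gradients by the global sample count. It assumes manual optimization, an unwrapped (non-DDP) model, and a per-sample sum loss on each rank; the helper name and the local_num_samples argument are hypothetical, not part of any Lightning API:

import torch
import torch.distributed as dist

def normalize_grads_by_global_samples(model: torch.nn.Module, local_num_samples: int) -> None:
    """Sum-reduce gradients across ranks, then divide by the total number of
    samples contributed by all ranks this step, instead of DDP's average over
    world size."""
    device = next(model.parameters()).device
    total_samples = torch.tensor(float(local_num_samples), device=device)
    dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
    for param in model.parameters():
        if param.grad is not None:
            # each rank holds the gradient of its local sum loss
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(total_samples)

This could be called right after the backward pass and before the optimizer step, e.g. from training_step under manual optimization.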

Rather than requiring each of these users to learn the Strategy codebase in order to customize this behavior, I propose a "manual" parallel strategy that delegates this logic back to the LightningModule.

This way, all of the modeling logic sits in one place, which makes it easier for researchers to get started without having to learn another abstraction. If these techniques turn out to be more general, they can later be abstracted into the Strategy interface, making them shareable across projects.

In this setting, the user assumes responsibility for the following:

  • Moving modules to the corresponding devices
  • Process group initialization (if any)
  • Handling any distributed module wrappers themselves

The Trainer/Strategy will still handle:

  • The rank information (received from the cluster environment)
  • Collectives: these are common torch.distributed collective calls required elsewhere within the Trainer

This is intended for power users who know exactly what they're doing.
The terminology manual parallel follows the precedent of manual optimization: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#manual-optimization
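
For reference, this is what manual optimization looks like in a LightningModule today (standard Lightning API; the compute_loss call is a hypothetical placeholder):

class ManualOptimizationModule(LightningModule):
    def __init__(self):
        super().__init__()
        # opt out of Lightning's automatic optimization loop
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)  # hypothetical loss computation
        self.manual_backward(loss)
        opt.step()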

This is also the motivation for the PRs removing dependencies on LightningModule.device within the Trainer:

LightningModule.device is not properly defined for use cases where the LightningModule's parameters sit on multiple devices. This proposal aims to remove the requirement for users of these LightningModules to call LightningModule.to(...) before executing a Trainer function.

Pitch

Define a new strategy class like this:

class ManualParallelStrategy(ParallelStrategy):
    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self.parallel_devices = parallel_devices
        self.cluster_environment = cluster_environment

    def setup_environment(self) -> None:
        # start the other scripts
        if not self.cluster_environment.creates_processes_externally:
            self._call_children_scripts()

        self.setup_distributed()
        super().setup_environment()

    def setup_distributed(self) -> None:
        # initialize the process group if not already available
        # (one possible implementation; mirrors the backend choice in the example module below)
        if not torch.distributed.is_initialized():
            backend = "nccl" if self.root_device.type == "cuda" else "gloo"
            torch.distributed.init_process_group(backend=backend)

    @property
    def root_device(self) -> torch.device:
        """ The device where data is loaded to """
        return self.parallel_devices[self.local_rank]

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.
        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """
        # By default, enable saving on all ranks for distributed checkpointing
        self.checkpoint_io.save_checkpoint(checkpoint, filepath)

    def barrier(self, *args, **kwargs) -> None:
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self.determine_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        obj = [obj]
        if self.global_rank != src:
            obj = [None]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
        """Reduces a tensor from several distributed processes to one aggregated tensor.
        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.
        Return:
            reduced value, except when the input was not a tensor, in which case it is returned unchanged
        """
        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def teardown(self) -> None:
        """This method is called to teardown the training process.
        It is the right place to release memory and free other resources.
        """
        self.precision_plugin.teardown()
        self.cluster_environment.teardown()

Example of a LightningModule that is inherently distributed-aware:

class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            backend = "nccl"
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")
            backend = "gloo"

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=backend)

        model = MyHugeModel(....)  # might require the process group to already be available
        self.model = shard_huge_model(model, device)  # might require the process group to already be available
        # the sharded model is already on the correct device(s), so the optimizer is safe to initialize now
        self.optimizer = MyOptimizer(self.model.parameters())


    def configure_optimizers(self):
        return self.optimizer


trainer = Trainer(strategy="manual", accelerator="gpu", devices=8)
lit_model = MyLightningModule()
trainer.fit(lit_model)

Alternatives

Additional context

The idea of manual parallelization was also raised in #8722 (comment)


cc @Borda @awaelchli @rohitgr7 @akihironitta

ananthsub added the distributed, feature, and strategy labels on Feb 15, 2022
ananthsub changed the title from "Support user-provided parallelization in the LightningModule" to "Support user-defined parallelization in the LightningModule" on Feb 15, 2022
carmocca moved this to Todo in Frameworks Planning on Feb 16, 2022
carmocca added this to the 1.6 milestone on Feb 16, 2022
carmocca moved this from Todo to In Progress in Frameworks Planning on Feb 16, 2022
carmocca moved this from In Progress to Todo in Frameworks Planning on Feb 16, 2022
@carmocca (Contributor) commented:

Should the model be wrapped in setup instead? It would avoid the following, right?

        rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            backend = "nccl"
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")
            backend = "gloo"

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=backend)

The code above (wrapping in __init__) won't work with DDP spawn. It would also be cleaner to let Lightning create the process group etc., so users would just need to wrap the model and create the optimizers.
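
For illustration, a rough sketch of the setup-based variant being suggested here (it reuses the hypothetical MyHugeModel / shard_huge_model / MyOptimizer helpers from the issue, and assumes the strategy's root_device is reachable through the trainer):

class MyLightningModule(LightningModule):
    def setup(self, stage=None):
        # by this point the Trainer/Strategy has created the process group
        # and assigned a root device, so no manual init_process_group is needed
        device = self.trainer.strategy.root_device
        model = MyHugeModel(...)
        self.model = shard_huge_model(model, device)
        self.optimizer = MyOptimizer(self.model.parameters())

    def configure_optimizers(self):
        return self.optimizer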

Should the creation of the process group be completely customizable for support with strategies like DeepSpeed or Bagua?

@ananthsub (Author) commented:

Initializing the model in setup has a few downsides:

  1. The model initialization and training logic become coupled together. Ideally the model initialization would live outside the LightningModule, so that the LightningModule is used as a system, as recommended by the docs. Otherwise, the LightningModule needs to know how to initialize and shard the provided models; different models may have different APIs and behaviors, all of which end up inside setup (e.g. the MyHugeModel and shard_huge_model calls above).
  2. Initializing the model inside setup also runs into complications when loading checkpoints through load_from_checkpoint. This is a similar problem to the one FSDP faces when using configure_sharded_model to do the sharding.

The code above (wrapping in __init__) won't work with DDP spawn.

The LightningModule code is written by the user, so they would have to decide whether their code needs to work with DDP spawn. I think it's going to be hard to support custom parallelization, checkpoint loading, and spawning all at the same time.

it would also be cleaner to let Lightning create the process group etc so the users just need to wrap and create the optimizers. Should the creation of the process group be completely customizable for support with strategies like DeepSpeed or Bagua?

Users can already initialize the process group themselves if they create the processes externally. The only cases where Lightning has to create the process group are spawn and the subprocess script launch. A lighter form of customization is being worked on in #11745.

Note: I don't want to turn this into an issue about supporting spawning or not. The point is only that relying on the Lightning Trainer to do the process creation imposes restrictions on how users author their training programs. In the use cases I've seen, especially ones that would benefit from this strategy, we have been using torchx to great effect.

@carmocca (Contributor) commented:

Ideally we would have the model initialization external to the lightning module. This way, we use the lightning module as a system, as recommended by the docs.

The LightningModule.setup hook could call an nn.Module.setup method defined by the user to avoid this, for example:
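
A rough sketch of that delegation pattern (the setup method on the inner module is a user-defined convention, not an existing nn.Module API, and the root_device access is an assumption):

class MyLightningModule(LightningModule):
    def setup(self, stage=None):
        # self.model is a plain nn.Module assigned in __init__;
        # device placement / sharding is delegated to its own user-defined setup hook
        self.model.setup(device=self.trainer.strategy.root_device)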

also runs into complications with loading checkpoints

This is a fair point. However, at least the docs for this feature should display both options and mention their differences.

carmocca moved this from Todo to In Progress in Frameworks Planning on Feb 28, 2022
Borda modified the milestones: 1.6, 1.7 on Mar 21, 2022
carmocca modified the milestones: pl:1.7, pl:future on Jul 19, 2022
carmocca moved this from In Progress to Todo in Frameworks Planning on Sep 1, 2022
carmocca modified the milestones: future, 2.3 on May 23, 2024