
Cleanup FSDP integration to not require boilerplate logic #8722

Closed
SeanNaren opened this issue Aug 4, 2021 · 9 comments
Assignees: SeanNaren
Labels: 3rd party (Related to a 3rd-party) · distributed (Generic distributed-related topic) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on)

Comments

@SeanNaren
Contributor

🚀 Feature

Motivated by debugging FSDP in a recent PR made by @carmocca, I think we should try to clean up the interface for FSDP.

Currently, FSDP supports a case where we wrap layers inside the configure_sharded_model hook, with the assumption that these layers are defined outside the hook. This is probably because in most cases the model has been defined in setup or __init__.

This can be seen here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py#L58-L74

class TestFSDPModel(BoringModel):
    def setup(self, stage: str) -> None:
        if stage != "fit":
            # when running stages like test, validate, and predict, we will skip setting up,
            # will directly use the module itself unless we load from checkpoint
            return
        # resetting call_configure_sharded_model_hook attribute so that we could call
        # configure sharded model
        self.call_configure_sharded_model_hook = False
        # for loading full state dict, we first need to create a new unwrapped model
        # to load state dict and then wrapping
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_sharded_model(self) -> None:
        for i, layer in enumerate(self.layer):
            if i % 2 == 0:
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # when loading full state dict, we first need to create a new unwrapped model
        self.setup("fit")

Also included is a lot of boilerplate logic to handle the case where a user wants to load weights back into the model: we re-create the model -> load weights -> call configure_sharded_model again.

This is a bit unclean, as we see an internal variable needing to be reset (call_configure_sharded_model_hook) and, more importantly, we assume that the model state hasn't been altered (which it has: FSDP permanently flattens the parameters).

Imo we should move towards this API:

class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_sharded_model(self) -> None:
        for i, layer in enumerate(self.layer):
            if i % 2 == 0:
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)

and allow this for large models that take time to load into memory (it is quicker to construct and wrap one module at a time):

class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()

    def configure_sharded_model(self) -> None:
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2))
        )
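For context on why wrap can live inside configure_sharded_model yet be harmless elsewhere: fairscale's wrap only applies FSDP when called inside an enable_wrap context, which the plugin would provide around the hook; outside of that context, wrap simply returns the module unchanged. A rough sketch of that context (the function name here is illustrative, not the actual plugin code):

from fairscale.nn import enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel


def call_configure_sharded_model(lightning_module) -> None:
    # inside this context, wrap(module) returns an FSDP-wrapped module;
    # outside of it, wrap(module) is a no-op and returns the module as-is
    with enable_wrap(wrapper_cls=FullyShardedDataParallel):
        lightning_module.configure_sharded_model()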

How to actually implement this?

Once the model has been set up, ideally we should never need to set it up again unless the model has changed (covered in RFC #8593). This would allow the model to remain the same across stages.

Given the above, I think we'll then be able to rely on primitive state dict functions of the wrapped model via FSDP: https://fairscale.readthedocs.io/en/stable/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.state_dict
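For illustration, once the wrapped model is kept across stages, checkpointing could defer directly to FSDP. A minimal sketch (not the actual plugin code; self.model is assumed to be the FSDP-wrapped module held by the plugin, and the method names are placeholders):

def lightning_module_state_dict(self):
    # FullyShardedDataParallel.state_dict gathers the full, unflattened parameters,
    # so there is no need to unwrap the model or re-run setup before saving
    return self.model.state_dict()


def load_model_state_dict(self, checkpoint):
    # symmetrically, loading goes through FSDP's own load_state_dict
    self.model.load_state_dict(checkpoint["state_dict"])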

cc @ananthsub

SeanNaren added the feature, help wanted, distributed and 3rd party labels on Aug 4, 2021
SeanNaren self-assigned this on Aug 4, 2021
@ananthsub
Contributor

ananthsub commented Aug 4, 2021

I agree. call_configure_sharded_model as a property is fragile. It reminds me of #7301, where either the user can check inside configure_sharded_model whether the model is already sharded, or the framework avoids rewrapping.

@SeanNaren both of the new TestFSDPModel examples look far cleaner.

Regarding the state dict: would the plugin now wrap the LightningModule as a whole with FSDP?

@SeanNaren
Contributor Author

@ananthsub this is a really good point. I realised that once we support #8593, there is no reason FSDP cannot wrap the entire module!

I am a bit unsure exactly how the logic would proceed currently; it will need some investigation!

@ananthsub
Contributor

@SeanNaren regarding the current test example, I think this is a specific choice by the use case. If the guiding principle is to deprecate call_configure_sharded_model_hook, then I think the initialization structure you've provided is the natural follow-up. Regardless of the property, it's a much clearer approach.

@fcampagnexandr

fcampagnexandr commented Aug 26, 2021

Seems to me there are some issues with the code snippets as written. I stumbled on this issue looking for information about whether I should still init the model in the constructor, or only in the hook.
I think one model could be rewritten as:

class TestFSDPModel(BoringModel):
    def __init__(self, lazy_init: bool = False):
        super().__init__()
        if lazy_init:
            # need to define the fields in the constructor, or the hook will fail.
            self.layer = None
        else:
            # Create the layers right away: not super efficient with large models,
            # but convenient for testing in isolation from the trainer.
            self.configure_sharded_model()

    def configure_sharded_model(self) -> None:
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2))
        )

@tchaton
Contributor

tchaton commented Aug 26, 2021

@SeanNaren Do you recommend not calling setup with FSDP?

@SeanNaren
Contributor Author

SeanNaren commented Aug 27, 2021

@fcampagnexandr TLDR: this works:

from typing import Any, Dict

import torch
from fairscale.nn import wrap

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel


class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._setup_model()

    def _setup_model(self):
        self.model = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2))
        )

    def configure_sharded_model(self) -> None:
        self.model[0] = wrap(self.model[0])
        self.model[1] = wrap(self.model[1])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # re-create the un-wrapped model before loading, since the saved
        # state dict doesn't contain FSDP references.
        self._setup_model()


model = TestFSDPModel()
trainer = Trainer(plugins='fsdp', gpus=1, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint('model.pt')
trainer.test(model, ckpt_path='model.pt')

More details and why this is wrong (especially important to @ananthsub):

The reason we have to restore the model in the on_load_checkpoint hook can be described as follows:

import os

import torch
from fairscale.nn import FullyShardedDataParallel
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_network_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
torch.distributed.init_process_group("nccl")


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FullyShardedDataParallel(
            torch.nn.Sequential(
                FullyShardedDataParallel(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                FullyShardedDataParallel(torch.nn.Linear(32, 2))
            )
        )


model = MyModel()
state_dict = model.state_dict()
# crashes because `load_state_dict` hasn't been called on FSDP model!
model.load_state_dict(state_dict)

When this issue has been resolved, we will always wrap the entire module in FSDP and the plugin will keep the same reference. This is closer to the intended behaviour and solves a plethora of the issues described.

@ananthsub
Contributor

ananthsub commented Sep 20, 2021

@SeanNaren @tchaton this is on @jjenniferdai's mind and mine, as some of our large text model cases are having issues with CPU OOMs, which relates to model initialization and checkpoint loading (#9406).

  1. I am totally onboard with deprecating call_configure_sharded_model and all the lifecycle checks it comes with. There are 2 issues I see with it right now:
  • The Trainer has an inconsistent call order across successive functions. Sometimes we call the hook, sometimes we don't. This makes debugging really challenging. The example integration assumes fit is called before test, but that's not always true.
  • It's the LightningModule which is actually wrapping the layers inside of the LightningModule, not the plugin. Therefore, the user is already the one responsible for determining whether to apply the wrap, not the plugin. This is a departure from existing plugins in Lightning and starts us down the path of supporting generic manual parallelization. Instead of relying on call_configure_sharded_model in the model and the training type plugin, users could just as well check for isinstance(self.model, FullyShardedDataParallel) and return early if it's already wrapped. It's the same spirit as Avoid rewrapping LightningModules in plugins #8593, but controlled by the user. TLDR: users should implement configure_sharded_model as idempotent (see the sketch below).
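A rough sketch of that idempotent check (the model class here is hypothetical):

import torch
from fairscale.nn import wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel
from tests.helpers.boring_model import BoringModel


class IdempotentFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_sharded_model(self) -> None:
        if isinstance(self.layer, FullyShardedDataParallel):
            # already wrapped during a previous stage: do nothing
            return
        self.layer = wrap(self.layer)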

Proposal: given that LightningModule.call_configure_sharded_model_hook is not documented anywhere in the public API of the LightningModule, can we remove all the properties associated with this check without a deprecation process?

  2. configure_sharded_model is not implementation agnostic: the user needs to know whether they're using FSDP/DeepSpeed/some other technique to apply the appropriate wrapping. This makes sense, since the libraries involved here can be so different.

  3. Expectations around delayed initialization & checkpoint loading.
    Right now we restore model states from the checkpoint into the model after setup: https://github.com/PyTorchLightning/pytorch-lightning/blob/381343a79c703f2ccf1ab7c1d87400ad6e31fdf4/pytorch_lightning/trainer/trainer.py#L987-L993

For the checkpoint state to be loaded, all layers must be initialized by the time setup completes. However, configure_sharded_model runs after setup.

This means if the model state dict contains FSDP weights, the LightningModule needs to initialize FSDP before loading the checkpoint. And if the LightningModule wants to load a model state dict without FSDP weights and then configure FSDP, it needs to apply the wrapper only in configure_sharded_model.
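Purely as an illustration of the branching this forces on the user (the constructor flag and helper below are hypothetical, not Lightning API):

class BranchingFSDPModel(BoringModel):
    def __init__(self, ckpt_has_fsdp_weights: bool = False):  # hypothetical flag
        super().__init__()
        self.ckpt_has_fsdp_weights = ckpt_has_fsdp_weights

    def _build_layers(self) -> torch.nn.Module:  # hypothetical helper
        return torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def setup(self, stage: str) -> None:
        if self.ckpt_has_fsdp_weights:
            # the checkpoint references FSDP, so wrap before the Trainer restores state
            # (assumes the process group is already initialized at this point)
            self.layer = FullyShardedDataParallel(self._build_layers())
        else:
            # plain checkpoint: build unwrapped here and defer wrapping to configure_sharded_model
            self.layer = self._build_layers()

    def configure_sharded_model(self) -> None:
        if not isinstance(self.layer, FullyShardedDataParallel):
            self.layer = wrap(self.layer)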

This is confusing since:

  • The Trainer applies the model_sharded_context only around configure_sharded_model and not around setup. I'm not sure if applying the context around both hooks is viable because we might want to do one without the other.
  • The LightningModule needs to know about the contents of the checkpoint state dict and branch accordingly. The two branches aren't equal because of the trainer differences listed above
  • Down the line, we could expect initialization with meta-devices to be the only initialization option

One potential mitigation is wrapping the entire LightningModule with FSDP and then avoiding rewrapping it later. However, I'm not sure how that will play with:

  • Other modules inside the LightningModule, like metrics or losses with parameters. Do we lose perf because those modules now get sharded too? Are there correctness issues around the state dicts of those other submodules afterward?
  • Does this break calling MyLightningModule.load_from_checkpoint because now we'd need to call the FSDP(LightningModule).load_state_dict instead? But this FSDP-wrapped LightningModule isn't visible to the user because it's an implementation detail of the trainer. DP/DDP don't face this issue since they're not destructive.

Do you think a formalization of manual parallelization is an option we could pursue here? In this case:

  • the user provides the wrapped FSDP module either at initialization (if they launched the processes themselves, the torch distributed init could happen prior to the trainer invocation) or in setup
  • if resuming from checkpoint:
    - if it contains FSDP weights: init the model in setup, and then in the LightningModule checkpoint-loading path we delegate to FSDP's load_state_dict automatically
    - else: init the regular model in setup, load the checkpoint, and then shard it in configure_sharded_model

This latter option might be pretty niche, since it'll be more likely that all params cannot fit on a single device; otherwise users could opt for DDP Sharded + a ZeRO redundancy optimizer.

Looking at the FSDP plugin, it's pretty minimal (some of that is due to it currently extending from DDP): https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/fully_sharded.py

But as long as we call configure_optimizers in the right spot and save the checkpoints in the right way (either the full state dict from rank 0, or shards from all ranks), we open up a lot of flexibility for users.
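For reference, those two saving modes map onto fairscale FSDP's state-dict APIs roughly as follows (a sketch, assuming torch.distributed is already initialized and fsdp_model is the wrapped module):

import torch

rank = torch.distributed.get_rank()

# full state dict: every rank participates in gathering the unflattened parameters,
# but typically only rank 0 writes the file
full_state_dict = fsdp_model.state_dict()
if rank == 0:
    torch.save(full_state_dict, "full_checkpoint.pt")

# sharded state dict: each rank saves only its own flattened shard
torch.save(fsdp_model.local_state_dict(), f"shard_rank{rank}.pt")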

@carmocca
Contributor

@awaelchli Can we close this?

@awaelchli
Contributor

Yes, I believe all the main concerns from the issue description are resolved today.
