Cleanup FSDP integration to not require boilerplate logic #8722
Comments
I agree. Regarding the state dict: would the plugin now wrap the LightningModule as a whole with FSDP? |
@ananthsub this is a really good point. I realised that once we support #8593 there is no reason FSDP cannot wrap the entire module! I am a bit unsure exactly how the logic would proceed currently; it will need some investigation! |
@SeanNaren regarding the current test example, I think this is a specific choice driven by that use case. If the guiding principle is to deprecate |
Seems to me there are some issues with the code snippets as written. I stumbled on this issue looking for information about whether I should still init the model in the constructor, or only in the hook.
|
@SeanNaren Do you recommend not calling `setup` with FSDP? |
@fcampagnexandr TLDR: this works:

```python
from typing import Dict, Any

import torch
from fairscale.nn import wrap
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel


class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._setup_model()

    def _setup_model(self):
        # `wrap` is a no-op outside the plugin's `enable_wrap` context,
        # so this builds a plain, un-sharded module
        self.model = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )

    def configure_sharded_model(self) -> None:
        # inside this hook `wrap` actually applies FSDP to the linear layers
        self.model[0] = wrap(self.model[0])
        self.model[2] = wrap(self.model[2])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # restores the model before FSDP wraps it, so we can
        # load the state dict, which doesn't have FSDP references
        self._setup_model()


model = TestFSDPModel()
trainer = Trainer(plugins='fsdp', gpus=1, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint('model.pt')
trainer.test(model, ckpt_path='model.pt')
```

More details, and why this is wrong (especially important to @ananthsub): the reason we have to restore the model in the `on_load_checkpoint` hook is that the saved state dict has no FSDP references, so it cannot simply be loaded into a model that has already been wrapped. The same problem can be reproduced with fairscale directly:

```python
import os
import torch
from fairscale.nn import FullyShardedDataParallel
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_network_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
torch.distributed.init_process_group("nccl")


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FullyShardedDataParallel(
            torch.nn.Sequential(
                FullyShardedDataParallel(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                FullyShardedDataParallel(torch.nn.Linear(32, 2)),
            )
        )


model = MyModel()
state_dict = model.state_dict()

# crashes: `load_state_dict` can't simply be called on the FSDP-wrapped model like this!
model.load_state_dict(state_dict)
```

When this issue has been resolved, we will always wrap the entire module in FSDP, and the plugin will keep the same reference. This is closer to the intended behaviour and solves a plethora of issues as described. |
@SeanNaren @tchaton this is on @jjenniferdai's mind and mine, as some of our large text model cases are hitting CPU OOMs, which relates to model initialization and checkpoint loading (#9406)
Proposal: Given that
For the checkpoint state to be loaded, all layers must be initialized by the time the state dict is restored. This means that if the model state dict contains FSDP weights, the LightningModule needs to initialize FSDP before loading the checkpoint. And if the LightningModule wants to load a model state dict without FSDP weights and then configure FSDP, it needs to apply the wrapper only in `configure_sharded_model`. This is confusing since:
One potential mitigation is wrapping the entire LightningModule with FSDP and then avoiding rewrapping it later. However, I'm not sure how that will play with:
Do you think a formalization of manual parallelization is an option we could pursue here? In this case:
This latter option might be pretty niche, since it'll be more likely that all params cannot fit on a single device. Otherwise users could opt for DDP Sharded + a zero-redundancy optimizer. Looking at the FSDP plugin, it's pretty minimal (some of that is due to it currently extending from DDP): https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/fully_sharded.py but as long as we call |
@awaelchli Can we close this? |
Yes, I believe all the main concerns from the issue description are resolved today. |
🚀 Feature
Motivated by debugging FSDP in a recent PR made by @carmocca, I think we should try to clean up the interface for FSDP.
Currently FSDP supports a case where we wrap layers inside the `configure_sharded_model` hook, with the assumption that these layers are defined outside the hook. This is probably because in most cases the model has been defined in `setup` or `__init__`. This can be seen here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py#L58-L74
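For illustration, the current pattern looks roughly like this (only a sketch loosely based on that test, with a made-up class name, not its exact code):

```python
# Rough sketch of the current pattern: layers are created up front in
# __init__/setup and only FSDP-wrapped later inside configure_sharded_model.
import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule


class CurrentPatternModel(LightningModule):  # hypothetical name
    def __init__(self):
        super().__init__()
        # plain layers, defined before any sharding happens
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def configure_sharded_model(self) -> None:
        # the plugin activates fairscale's enable_wrap context around this
        # hook, so `wrap` actually shards here (it is a no-op elsewhere)
        for i, layer in enumerate(self.layer):
            if isinstance(layer, torch.nn.Linear):
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)
```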
Also included is a lot of boilerplate logic to handle the case where a user wants to load weights back into the model: we re-create the model -> load weights -> call configure sharded model again.
This is a bit unclean, as we see an internal variable needing to be reset (`call_configure_sharded_model_hook`) and, more importantly, we assume that the model state hasn't been altered (which it has been, by FSDP, which permanently flattens the parameters). Imo we should move towards this API:
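Something along these lines (only a sketch with a made-up class name, assuming the whole model is created inside `configure_sharded_model` so there is nothing to re-create when loading weights):

```python
# Sketch of the idea: the model is built entirely inside
# configure_sharded_model, so no boilerplate is needed to rebuild it
# before loading a checkpoint.
import torch
from pytorch_lightning import LightningModule


class ProposedModel(LightningModule):  # hypothetical name
    def configure_sharded_model(self) -> None:
        self.model = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)
```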
and allow this to happen for large models that take time to load into memory (and are quicker to wrap one module at a time):
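Again only a sketch (made-up name, assuming fairscale's `wrap` is active inside the hook): each submodule is wrapped as it is created, so the full unsharded model never needs to exist in memory at once.

```python
# Sketch: wrap each layer as it is created inside configure_sharded_model,
# so only one unsharded submodule is materialised at a time.
import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule


class ProposedLargeModel(LightningModule):  # hypothetical name
    def configure_sharded_model(self) -> None:
        self.model = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)
```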
How to actually implement this?
Once the model has been set up, ideally we should never need to set it up again unless the model has changed (covered in RFC #8593). This would allow the model to remain the same across stages.
Given the above, I think we'll then be able to rely on primitive state dict functions of the wrapped model via FSDP: https://fairscale.readthedocs.io/en/stable/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.state_dict
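As a rough sketch of what relying on those primitives could look like (assuming the plugin keeps one long-lived wrapped instance; single-process setup copied from the snippet earlier in the thread):

```python
# Sketch only: defer checkpointing to fairscale FSDP's own state-dict methods.
import os

import torch
from fairscale.nn import FullyShardedDataParallel
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", str(find_free_network_port()))
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group("nccl")

sharded = FullyShardedDataParallel(torch.nn.Linear(32, 2))
full_sd = sharded.state_dict()         # consolidated, unflattened parameters
local_sd = sharded.local_state_dict()  # only this rank's flattened shard
```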
cc @ananthsub