OOM issues with loading large model checkpoints w/ FSDP after checkpoint refactor #8043
Comments
Thanks for the issue @mleshen and the PR :) Could you pseudocode what you're doing? Is the logic that used to live in
Hi @SeanNaren! Yeah, I'm sorry for the confusion. Let me sketch it out. So previously:
After the refactor:
These TrainingTypePlugin APIs are now called in
So previously, since our overridden method looked something like this:
Hi, I made the change in #7928 in order to enable loading the weights for a sharded model before it gets wrapped in
I vote against that. IMO the previous approach was anything but clean, since the responsibility of calling hooks at the right time and in the right order was pushed to the plugin. In PL we want to achieve a consistent hook call order. Can you share the code that you used to override the previous method? How would you restore optimizers, learning rate schedulers, callbacks, etc. from that serialized checkpoint?
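(For reference, a rough sketch of the non-model state a Lightning checkpoint of that era carries, which is what the question above is getting at; the helper name and the single-optimizer/single-scheduler assumption are illustrative, not part of the original discussion.)

```python
import torch

def restore_non_model_state(checkpoint_path, optimizer, lr_scheduler):
    # Lightning checkpoints bundle more than the model weights: optimizer and
    # LR scheduler states, loop counters, callback states, and so on.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    optimizer.load_state_dict(ckpt["optimizer_states"][0])  # assumes a single optimizer
    lr_scheduler.load_state_dict(ckpt["lr_schedulers"][0])  # assumes a single scheduler
    return ckpt["epoch"], ckpt["global_step"]
```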
Do you mean making each worker process load one after the other?
Hi @awaelchli! Thanks for responding. Here's the overridden method we used previously:
Yes, I mean adding an option to serialize loading the model state.
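(To make the idea concrete: a minimal sketch in plain PyTorch of loading the model state one worker at a time, assuming an already-initialized process group and a checkpoint with a `state_dict` key; the function name is illustrative.)

```python
import torch
import torch.distributed as dist

def load_state_dict_sequentially(checkpoint_path, model):
    # Each rank takes its turn, so only one full CPU copy of the checkpoint
    # exists at any moment instead of one copy per worker.
    for turn in range(dist.get_world_size()):
        if dist.get_rank() == turn:
            ckpt = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(ckpt["state_dict"])
            del ckpt  # drop the CPU copy before the next rank starts loading
        dist.barrier()
```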
Okay, so this is what you meant. I originally misunderstood "serialization" to mean something else. With the latest changes, the loading of checkpoints goes as follows:
As you can see, the checkpoint is kept in memory as long as needed, and restoring the Trainer and the LightningModule is split across multiple stages. Before, this was all restored in a single function call in a single place; that's why you were able to do the sequential loading and drop the state dict. In the current flow, you could try this in the plugin:

```python
from pathlib import Path
from typing import Any, Dict, Mapping, Union

# methods to override on your FSDP training type plugin subclass
def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
    # let the workers take turns so that only one full CPU copy of the
    # checkpoint exists at a time
    for current_worker in range(self.num_processes):
        if self.local_rank == current_worker:
            ckpt = super().load_checkpoint_file(checkpoint_path)
            log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
            # model states are restored here
            self.lightning_module.load_state_dict(ckpt["state_dict"])  # or however FSDP does it
            del ckpt["state_dict"]
        self.barrier()
    # our checkpoint connector will still store this, but without the state_dict of the model, no OOM?
    return ckpt

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
    # override to do nothing; the plugin already loaded the weights in `load_checkpoint_file()`
    pass
```

Would that be an acceptable solution? I may have a typo, apologies.
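(For context, a sketch of how such a plugin override might be wired into the Trainer, assuming a Lightning version that ships `DDPFullyShardedPlugin`; the subclass name and checkpoint path are illustrative, not part of the original discussion.)

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPFullyShardedPlugin

class SequentialLoadFSDPPlugin(DDPFullyShardedPlugin):
    """Hypothetical subclass that would carry the two overrides sketched above
    (`load_checkpoint_file` and `load_model_state_dict`)."""

trainer = Trainer(
    gpus=8,
    plugins=[SequentialLoadFSDPPlugin()],
    resume_from_checkpoint="path/to/large_checkpoint.ckpt",  # illustrative path
)
```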
Hi @awaelchli! Sorry for the late response, and thanks so much for your suggestion. I tried this out and we get an error from model.load_state_dict because the model has already been wrapped in FSDP by the time the checkpoint is loaded, and so we get some KeyErrors:
Thank you @mleshen.
So are you saying the model has already been wrapped? Can you let me know in which hook of LightningModule you are performing the wrapping? Ideally this should happen in `LightningModule.configure_sharded_model`.
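(For readers following along: a minimal sketch of wrapping inside `LightningModule.configure_sharded_model`, using the fairscale `wrap` utility that Lightning's FSDP integration builds on; the class name and layer sizes are illustrative.)

```python
import torch.nn as nn
from fairscale.nn import wrap
from pytorch_lightning import LightningModule

class MyShardedModel(LightningModule):
    def configure_sharded_model(self):
        # The FSDP plugin calls this hook inside its sharding context, so the
        # submodule created here is wrapped (and sharded) as it is defined.
        self.layer = wrap(nn.Linear(32, 2))

    def forward(self, x):
        return self.layer(x)
```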
Hi @awaelchli, thanks for your patience! The problem was that I forgot to call on_load_checkpoint. I've updated #8515 with a method that serializes loading the model state in the FSDP plugin, trying to keep everything as close to the FSDP plugin as possible instead of messing with the logic in checkpoint_connector.
@awaelchli Can we close this?
🐛 Bug
In #7928 the trainer logic was modified to restore the model state from the checkpoint connector instead of from the training type plugin, and `restore_model_from_ckpt_path` was split into three new modular APIs. For our use case we overrode `restore_model_from_ckpt_path` in the FSDP plugin to prevent CPU OOMs, and now that the functionality for restoring the model state has been moved to the checkpoint connector, we run into OOMs again.

In #7509 it was proposed to solve this problem at the level of the `trainer`; a comment suggests offloading the responsibility to the `training_type_plugin`, since this is not widely required outside of DDP and its derivatives, but restoring the model state no longer belongs to the plugin. Could we add some more memory-friendly logic to the checkpoint connector for the case of multiple workers?

Please reproduce using the BoringModel
To Reproduce
Use following BoringModel and post here
Expected behavior
Environment
Note: Bugs with code are solved faster! Colab Notebook should be made public!

- IDE: Please, use our python bug_report_model.py template.
- Colab Notebook: Please copy and paste the output from our environment collection script (or fill out the checklist below manually). You can get the script and run it with:
- How you installed PyTorch (`conda`, `pip`, source):

Additional context