OOM issues with loading large model checkpoints w/ FSDP after checkpoint refactor #8043

Closed
mleshen opened this issue Jun 20, 2021 · 10 comments · Fixed by #18379
Labels
3rd party (Related to a 3rd-party) · bug (Something isn't working) · checkpointing (Related to checkpointing) · help wanted (Open to be worked on)

Comments


mleshen commented Jun 20, 2021

🐛 Bug

In #7928 the trainer logic was modified to restore the model state from the checkpoint connector instead of from the training type plugin, and restore_model_state_from_ckpt_path was split into three new modular APIs. For our use case we overrode restore_model_state_from_ckpt_path in the FSDP plugin to prevent CPU OOMs, and now that the functionality for restoring the model state has been moved to the checkpoint connector, we run into OOMs again.

In #7509 it was proposed to solve this problem at the trainer level; a comment suggested offloading responsibility to the training_type_plugin, since this is not widely required outside of DDP and its derivatives, but restoring the model state no longer belongs to the plugin. Could we add some more memory-friendly logic to the checkpoint connector in case of multiple workers?
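
As a rough illustration of what I mean by memory-friendly loading, here is a minimal sketch; the helper name is made up and this is not an existing Lightning API, and it assumes torch.distributed is already initialized:

import torch
import torch.distributed as dist

def load_checkpoint_rank_staggered(ckpt_path: str, local_rank: int, num_processes: int) -> dict:
    # Load the full checkpoint on one local rank at a time so that only a single
    # copy of the state dict sits in CPU memory per node at any given moment.
    checkpoint = {}
    for current_worker in range(num_processes):
        if local_rank == current_worker:
            checkpoint = torch.load(ckpt_path, map_location="cpu")
        dist.barrier()  # everyone waits before the next rank starts loading
    return checkpoint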

Please reproduce using the BoringModel

To Reproduce

Use the following BoringModel and post here

Expected behavior

Environment

Note: Bugs with code are solved faster! The Colab Notebook should be made public!

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@SeanNaren
Contributor

Thanks for the issue @mleshen and the PR :)

Could you pseudocode what you're doing? Is the logic that used to live in restore_model_state_from_ckpt_path unable to fit into the hook load_checkpoint_file?


mleshen commented Jun 21, 2021

hi @SeanNaren! yeah, i'm sorry for the confusion. let me sketch it out

so previously:

def restore_model_state_from_ckpt_path: on TrainingTypePlugin
- load checkpoint
- restore datamodule states
- load state dict
- return tuple bundling optimizer states and checkpoint

this was called in trainer:
[screenshot of the Trainer call site; not preserved]
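
Since the screenshot does not survive here, a rough sketch of that call site, based only on the signature described above (the attribute path on the trainer is an assumption):

# inside the trainer, pre-refactor (sketch, not the exact Lightning source)
ckpt, load_optimizer_states = self.training_type_plugin.restore_model_state_from_ckpt_path(
    ckpt_path, map_location=lambda storage, loc: storage
)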

after the refactor:

def load_checkpoint_file: on TrainingTypePlugin
- load checkpoint
def load_model_state_dict: on TrainingTypePlugin
- load checkpoint["state_dict"]
def load_optimizer_state_dict: on TrainingTypePlugin
- load checkpoint["optimizer_states"]

these TrainingTypePlugin APIs are now called in checkpoint_connector:

[screenshot of the checkpoint_connector call sites; not preserved]

which is now called in trainer:
[screenshot of the Trainer call site; not preserved]
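
The two screenshots do not survive here either; roughly, the new call chain looks like this (the three plugin hooks are the ones listed above, while the checkpoint_connector method names and the exact call points inside the trainer are assumptions):

# inside checkpoint_connector (sketch)
self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)

# inside the trainer, at different points during fit (sketch)
self.checkpoint_connector.restore_model()
self.checkpoint_connector.restore_training_state()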


mleshen commented Jun 21, 2021

so previously, since restore_model_state_from_ckpt_path was in TrainingTypePlugin, it could be overridden. for our case we overrode this method in FSDP to load the model state in a serialized way to prevent CPU OOMs. now that the functionality is in checkpoint_connector, I don't see an obvious way to customize this behavior. perhaps we should move the method that now lives in checkpoint_connector back to TrainingTypePlugin to keep it cleaner? cc: @tchaton @awaelchli

our overridden method looked something like this:

def restore_model_state_from_ckpt_path() -> Tuple[Dict, bool]:
    for current_worker in range(num_processes):
        if self.local_rank == current_worker:
            checkpoint, load_optimizer_states = super().restore_model_state_from_ckpt_path()
            del checkpoint["state_dict"]
        self.barrier()
    return checkpoint, load_optimizer_states


awaelchli commented Jun 22, 2021

Hi, I made the change in #7928 in order to enable loading the weights for a sharded model before it gets wrapped in configure_sharded_model(). The related discussion is in #7535.
I hope this provides some context.

perhaps we should move the method that now lives in checkpoint_connector back to TrainingTypePlugin to keep it cleaner?

I vote against that. IMO the previous approach was anything but clean, since the responsibility of calling hooks at the right time and in the right order was pushed to the plugin. In PL we want to achieve a consistent hook call order.

Can you share the code that you used to override the previous method? How would you restore optimizers, learning rate schedulers, callbacks etc. from that serialized checkpoint?

Could we add some more memory-friendly logic to the checkpoint connector in case of multiple workers?

Do you mean make each worker process load after each other?


mleshen commented Jun 22, 2021

Hi @awaelchli! Thanks for responding. Here's the overridden method we used formerly:

class DDPFullyShardedPlugin(OSSDDPFullyShardedPlugin):
    def restore_model_state_from_ckpt_path(
        self,
        ckpt_path: str,
        map_location: Callable = lambda storage, loc: storage,
    ) -> Tuple[Dict, bool]:
        ckpt = {}
        load_optimizer_states = True
        for current_worker in range(self.num_processes):
            if self.local_rank == current_worker:
                (
                    ckpt,
                    load_optimizer_states,
                ) = super().restore_model_state_from_ckpt_path(
                    ckpt_path, map_location=map_location
                )
                log.info(
                    f"Rank {self.global_rank}: done loading model states from {ckpt_path}."
                )
                # model states are restored
                del ckpt["state_dict"]
            self.barrier()
        return ckpt, load_optimizer_states

Do you mean make each worker process load after each other?

Yes, I mean add an option to serialize loading the model state.


awaelchli commented Jun 22, 2021

Okay, this is what you meant. I originally took what you described as "serialization" to mean something else.

With the latest changes, the loading of checkpoints goes as follows:

Trainer.fit():
    ...
    -> LightningModule.setup()
    -> load checkpoint into memory
    -> restore model weights
    -> LightningModule.configure_sharded_model()
    -> LightningModule.on_load_checkpoint()
    -> accelerator setup etc.
    ...
    -> restore optimizers, schedulers
    -> restore progress
    ...
    -> release memory for cached checkpoint
    ...
    train()
    

As you can see, the checkpoint is kept in memory only as long as needed, and restoring the Trainer and LightningModule state is split across multiple stages. Previously this was all restored in a single function call in a single place, which is why you were able to do the sequential loading and drop the state dict. In the current flow, you could try this in the plugin:

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        for current_worker in range(self.num_processes):
            if self.local_rank == current_worker:
                ckpt = super().load_checkpoint_file(checkpoint_path)
                log.info(
                    f"Rank {self.global_rank}: done loading model states from {ckpt_path}."
                )
                # model states are restored
                self.lightning_module.load_state_dict(ckpt["state_dict"])  # or however fsdp does it
                del ckpt["state_dict"]
            self.barrier()
            
        # our checkpoint connector will still store this, but without the state_dict of the model, no OOM?
        return ckpt 

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, plugin already loaded the weights in `load_checkpoint_file()`
        pass

Would that be an acceptable solution? I may have a typo, apologies.

@awaelchli added the checkpointing label Jun 22, 2021

mleshen commented Jun 30, 2021

Hi @awaelchli!

Sorry for the late response, and thanks so much for your suggestion. I tried this out and we get an error with model.load_state_dict because the checkpoint file to be loaded has already been wrapped in FSDP and thus we get some KeyErrors:
Missing key(s) in state_dict: "model._fsdp_wrapped_module.flat_param", "model._fsdp_wrapped_module._fpw_module.encoder.transformer._fsdp_wrapped_module.flat_param", "model._fsdp_wrapped_module._fpw_module.encoder.transformer._fsdp_wrapped_module._fpw_module.layers.0._fsdp_wrapped_module.flat_param
I'll check in with the FairScale folks to see if there's anything we can do about this.
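
For anyone following along, a plain-PyTorch analogy of why the keys no longer match (no FairScale here; the Wrapper class below is just a stand-in and does not capture FSDP's parameter flattening): wrapping a module prefixes every parameter key, so a state dict saved from the unwrapped model fails to load into the wrapped one.

from torch import nn

class Wrapper(nn.Module):
    # stand-in for an FSDP-style wrapper that nests the original module
    def __init__(self, module: nn.Module):
        super().__init__()
        self.wrapped_module = module

plain = nn.Linear(4, 4)
ckpt = plain.state_dict()           # keys: "weight", "bias"

wrapped = Wrapper(nn.Linear(4, 4))  # keys: "wrapped_module.weight", "wrapped_module.bias"
try:
    wrapped.load_state_dict(ckpt)
except RuntimeError as err:
    print(err)                      # Missing key(s) in state_dict: "wrapped_module.weight", ...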

@edenlightning added the waiting on author label Jul 1, 2021
@edenlightning added the 3rd party label Jul 1, 2021
@awaelchli
Contributor

Thank you @mleshen.

model.load_state_dict because the checkpoint file to be loaded has already been wrapped in FSDP

So are you saying the model has already been wrapped? Can you let me know in which hook of LightningModule you are performing the wrapping? Ideally this should happen in LightningModule.configure_sharded_model.


mleshen commented Jul 21, 2021

hi @awaelchli, thanks for your patience! the problem was that I forgot to call on_load_checkpoint.

i've opened #8515 with a method that serializes loading the model state in the FSDP plugin, trying to keep everything as close to the FSDP plugin as possible instead of changing the logic in checkpoint_connector.

@awaelchli removed the waiting on author label Mar 18, 2023
@carmocca
Contributor

@awaelchli Can we close this?
