Speed up PPO with ZeRO-3 by 10x 🔥 #1483
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
if generate_ref_response:
    with self.optional_peft_ctx():
I've moved the logic for disabling the adapter to unwrap_model_for_generation().
Also, why did we do the adapter disabling for the reference model but not for the active model above?
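For context, `optional_peft_ctx` is roughly the following (a sketch, not the exact trainer code): a no-op unless the model is a PEFT model, in which case peft's `disable_adapter()` context manager temporarily switches off the LoRA adapter so the base weights act as the frozen reference model:

```python
from contextlib import nullcontext


def optional_peft_ctx(self):
    # With a LoRA model, disabling the adapter recovers the frozen base
    # weights, which double as the reference model (no second copy needed).
    if self.is_peft_model:
        # peft's disable_adapter() is itself a context manager
        return self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter()
    # A plain model keeps a separate ref model, so there is nothing to disable.
    return nullcontext()
```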
@contextmanager
def unwrap_model_for_generation(
I wasn't sure whether this belongs in modeling_base.py or as a utility method here; let me know what you prefer!
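For reference, here's a minimal sketch of the shape such a helper can take (not the exact code from this PR; `remove_hooks` and `add_hooks` are hypothetical helpers that detach and restore ZeRO-3's parameter-fetch forward hooks):

```python
from contextlib import contextmanager

import deepspeed


@contextmanager
def unwrap_model_for_generation(model, accelerator, is_peft_model=False):
    """Gather the ZeRO-3 sharded weights once so generate() avoids an
    all-gather on every forward pass."""
    unwrapped_model = accelerator.unwrap_model(model)
    if is_peft_model:
        # Adapter disabling now lives here (see the diff above).
        unwrapped_model.pretrained_model.disable_adapter()
    ds_plugin = accelerator.state.deepspeed_plugin
    if ds_plugin is not None and ds_plugin.zero_stage == 3:
        # Materialize the full (unsharded) weights for the whole block.
        with deepspeed.zero.GatheredParameters(model.parameters()):
            remove_hooks(model)  # hypothetical: strip ZeRO-3's fetch hooks
            yield model
            add_hooks(model)  # hypothetical: restore them for training
    else:
        yield unwrapped_model
```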
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for 10x-ing PPO ZeRO-3 🚀
* Speed up PPO by 10x 🔥
* Revert
* Clean up
* Use relative import
* Clean
* Fix typing for docs
Gives the ability to add and remove the forward hooks in ZeRO-3 by using a context manager. These code changes were taken from a Hugging Face [PR](huggingface/trl#1617) and integrated for direct support in DeepSpeed. This is useful in the inference case, and the speedup can be observed [here](huggingface/trl#1483).

---------

Co-authored-by: root <root@deepspeed-c000004.2d1icxc5dsxehnpuwt3ifc34ph.gvxx.internal.cloudapp.net>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
In PPO, text generation is the main bottleneck, and especially so with ZeRO-3, where weights are sharded across N devices and need to be gathered for each forward pass.
This PR introduces a new context manager called `unwrap_model_for_generation()`, which does a single gather of the model weights to speed up the `ppo.py` example script by ~10x relative to naive ZeRO-3 inference. Thank you to @pacman100 for showing me this feature of `deepspeed` 🙏!

Note: this context manager is entirely general and can be used in other trainers. For now I've focused on PPO, but I'm happy to roll it out to the other parts of the codebase in follow-up PRs.
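In the trainer, generation can then be wrapped like this (a minimal usage sketch; the tensor names and generation kwargs are illustrative, not taken from the PR):

```python
# Weights are gathered once for the whole generate() call, rather than
# once per forward pass as in naive ZeRO-3 inference.
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
    responses = unwrapped_model.generate(
        queries, max_new_tokens=48, pad_token_id=self.tokenizer.pad_token_id
    )
```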
As they say, a picture is worth a thousand words, and here are the comparisons against DDP / ZeRO-2 and naive ZeRO-3:
Code to test with
I've checked that the script below works with DDP, DDP + LoRA, ZeRO-2, and ZeRO-3:
Inference script
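The actual script is collapsed in the PR and not reproduced here. As a rough stand-in, a minimal generation test of that era could look like the following (model name, hyperparameters, and prompt are illustrative; launch with `accelerate launch` plus the relevant DDP/ZeRO config):

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(model_name="gpt2", batch_size=2, mini_batch_size=1)
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(config, model, tokenizer=tokenizer)

query = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0]
query = query.to(ppo_trainer.accelerator.device)

# Under ZeRO-3, generate() now routes through unwrap_model_for_generation(),
# gathering the sharded weights a single time for the whole call.
response = ppo_trainer.generate(query, max_new_tokens=32, return_prompt=False)
print(tokenizer.decode(response[0]))
```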
Addresses the speed issue discussed in #1051