Speed up ZeRO-3 generation with DPO #1543
Comments
Passing self.model_wrapped instead to unwrap_model_for_generation gives:
Is it related to the way the model removes/adds hooks? |
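For readers following along, here is a minimal sketch of the pattern being discussed, assuming trl's unwrap_model_for_generation helper; the function name and arguments are illustrative, not the actual trainer code from this thread.

```python
# Illustrative sketch only, assuming trl's unwrap_model_for_generation helper;
# the function name and arguments are not the thread's actual trainer code.
from trl.models.utils import unwrap_model_for_generation

def generate_with_gathered_weights(model, accelerator, input_ids, generation_config):
    # Under ZeRO-3 the context manager gathers the sharded parameters and
    # temporarily removes DeepSpeed's forward hooks so .generate() sees the
    # full weights, then restores the partitioning and hooks on exit.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        return unwrapped_model.generate(input_ids, generation_config=generation_config)

# The variant that triggers the error above would pass the trainer's wrapped
# engine (e.g. self.model_wrapped) instead of the unwrapped self.model.
```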
Hey @sngdng, we've just opened a PR to fix the issue - please let us know if it still gives you an error! |
I just installed trl from source, so I think I have applied the latest fix, but I still get the same error when running example/scripts/ppo.py with deepspeed_zero3. The first two batches ran fine, but the third batch crashed. Maybe the only difference is that I use llama-2-7b-chat. Do you have any suggestions? |
Can you please share the exact command you're running to trigger the error? |
only |
@lewtun I can confirm that the issue still persists even with the fix. Without the context manager it works, but it is super slow. With the context manager it still gives:
|
Hi, a recent PR brought large improvements (10x) to PPO generation with ZeRO-3.
@lewtun, you mention on the PR that it can be adapted for other trainers. I gave it a quick shot, and it seems that naively applying the context manager to trainers like DPO does not work:
There seems to be an inconsistency between the base classes. Is there a reason why DPO is based on Trainer from transformers and PPO on BaseTrainer? What would be the easy way to add this feature to other trainers? Thanks!
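To make the question concrete, the following is a hedged sketch of such a naive adaptation: wrapping a generation call of a Trainer-based class like DPOTrainer in the same context manager. The subclass and method names are illustrative assumptions rather than actual trl code.

```python
# Hedged, illustrative sketch of naively reusing the PPO context manager in a
# Trainer-based class; the subclass and method names below are assumptions.
from trl import DPOTrainer
from trl.models.utils import unwrap_model_for_generation

class GenerationDPOTrainer(DPOTrainer):
    def generate_samples(self, input_ids, generation_config):
        # DPOTrainer inherits from transformers.Trainer, which exposes
        # self.accelerator and handles model wrapping itself, whereas
        # PPOTrainer is built on trl's BaseTrainer and manages wrapping
        # differently; that is the inconsistency raised in this issue.
        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
            return unwrapped_model.generate(
                input_ids, generation_config=generation_config
            )
```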