[GRPO] `use_peft: true` together with `use_vllm: true` not working as intended (#2856)
Thanks, can you please share an MRE and system info (check the issue template for guidance)?
I am having the same issue: the reward is flat with vLLM + PEFT, while the same settings had the reward increasing on Unsloth. This is with paged_adamw_8bit, gradient checkpointing, and a 32/32 LoRA, no QLoRA. Unsloth learns; accelerate + DeepSpeed ZeRO-3 + PEFT + vLLM does not. The model was the R1 1.5B distill.
Hi @qgallouedec, thanks for your prompt reply and for the hard work! You are right, it's not super descriptive as an issue. Hopefully the details below will provide more insights. Let me know if you need any further info. I am currently using: [environment details elided in the original].

I am running everything on an 8xH100 machine. The config, the training script, and the launch command were attached to the original comment and are omitted here; a rough sketch of a comparable setup is given below.
If you switch either flag to `false`, training works as expected. Below you can find the plots as well. [reward plots omitted]
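For orientation, here is a rough, hedged sketch of a comparable GRPO + LoRA + vLLM setup. The model, dataset, reward function, and hyperparameters are placeholders, not the values from the attached configs:

```python
# Minimal GRPO + LoRA + vLLM sketch; names and hyperparameters are assumptions,
# not the exact configuration attached in the original comment.
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset

def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 50 characters long.
    return [-abs(50 - len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo-lora-vllm",
    use_vllm=True,               # generation is offloaded to vLLM
    learning_rate=1e-5,          # LoRA usually needs a larger LR than full fine-tuning
    gradient_checkpointing=True,
    bf16=True,
)

trainer = GRPOTrainer(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",  # placeholder model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(r=32, lora_alpha=32),        # this is the use_peft path
)
trainer.train()
```

The run would then be launched with `accelerate launch`, optionally with a DeepSpeed config.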
Compared to full-parameter training, LoRA needs a bigger learning rate; for example, change 5e-7 to 1e-5. This is my experience, please see the change in the figures below. [figures omitted]
In my case I'm using [...]. UPDATE: I increased the learning rate as suggested, but it did not fix the issue.
Another guess is that when your model is too small, LoRA might leave too few trainable parameters for convergence.
Are you using Unsloth in your screenshots of it working? I also find it working with Unsloth with a 32/32 LoRA and 1e-5 LR. However, as I mentioned above, it is not working for me with the exact same experiment using TRL directly (via accelerate launch).
Yeah, I used Unsloth. I can't even start the training with only TRL on 2 GPUs: #2864
I've run a couple of experiments now as well -- I also concur that PEFT + vLLM is not working. I have tried with/without DeepSpeed; it doesn't seem to affect the results. From looking at the code, the obvious place to look is the function that syncs the trained weights over to vLLM.

It's very odd -- it definitely looks like it should be working, but if it's not that function, where else can it be?
Thanks @matt23654 for pointing at the right thing! I've followed the example of @tgaddair in #2725 and wrote a quick, albeit somewhat hacky fix. Here is the monkey patch for the MRE:
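(The exact patch is not reproduced here; below is a rough sketch of the approach described -- merge the LoRA adapter into the base weights, clean up the key names, load the weights into vLLM, then unmerge -- assuming that `GRPOTrainer` syncs weights through a `_move_model_to_vllm` method and that `self.llm` is the in-process vLLM engine.)

```python
# Rough sketch of the kind of monkey patch being described; not the exact code
# from this thread. Attribute paths into the trainer and vLLM internals are
# assumptions about the TRL/vLLM versions in use.
from trl import GRPOTrainer


def _move_model_to_vllm_patched(self):
    unwrapped_model = self.accelerator.unwrap_model(self.model)
    # Merge the LoRA adapter into the base weights so the exported state dict
    # contains the *updated* weights instead of the frozen base model.
    unwrapped_model.merge_adapter()
    state_dict = unwrapped_model.state_dict()
    # PEFT prefixes keys with "base_model.model." and wraps target modules in
    # ".base_layer"; strip both and drop adapter-only tensors so the key names
    # match what vLLM's load_weights expects.
    state_dict = {
        k.removeprefix("base_model.model.").replace(".base_layer", ""): v
        for k, v in state_dict.items()
        if "lora_" not in k
    }
    if self.accelerator.is_main_process:
        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
        llm_model.load_weights(state_dict.items())
    # Undo the merge so training continues on the plain adapter.
    unwrapped_model.unmerge_adapter()


GRPOTrainer._move_model_to_vllm = _move_model_to_vllm_patched
```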
I'm not particularly fond of this approach, though. As always, here are the plots that confirm that this is a viable patch. [reward plots omitted]
@AndreiCComan Thank you for the proposed fix! Would you kindly explain how the change fixes the issue? (What was the issue in the previous implementation?) Thanks!
The issue in the original implementation is in the lines that build the state dict that gets loaded into vLLM. When `use_peft: true` is set, that state dict does not reflect the LoRA updates, so vLLM keeps generating with what is effectively the frozen base model and the reward stays flat. With the adapter merged into the base weights before they are handed to vLLM, generation actually follows the current policy, and this will work as expected :)

Now what I'm interested in is the following: is there a way to avoid [...]?
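For reference (a general note about PEFT naming, not a quote from this thread), a LoRA-wrapped model renames its parameters, which is why a patch like the one above also has to rewrite the state-dict keys before calling vLLM's `load_weights`:

```python
# Typical key renaming for a LoRA-wrapped linear layer (illustrative names):
#   PEFT model : base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight
#   base model : model.layers.0.self_attn.q_proj.weight
def clean_peft_key(key: str) -> str:
    """Map a PEFT state-dict key back to the base model's naming scheme."""
    return key.removeprefix("base_model.model.").replace(".base_layer", "")
```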
Thanks for the explanation. I dumped the state dicts obtained from both versions and it validated your thoughts. It seems that the PEFT model's state dict [...]
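A minimal sketch of that kind of check (the file names are assumptions):

```python
# Compare state dicts dumped from the original and the patched implementations.
import torch

old = torch.load("state_dict_original.pt")  # assumed dump from the old code path
new = torch.load("state_dict_patched.pt")   # assumed dump from the fixed code path

# Keys present in only one of the dumps expose naming mismatches.
print(sorted(set(old) ^ set(new))[:10])

# For shared keys, identical tensors would mean the LoRA updates never reached
# the weights that get loaded into vLLM.
shared = sorted(set(old) & set(new))
changed = sum(not torch.equal(old[k], new[k]) for k in shared)
print(f"{changed} of {len(shared)} shared tensors differ")
```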
Thanks @XZ-X for the PR! You anticipated me :) Well done! I've tested your fix-vllm-peft branch and I confirm it's working as intended! Here are the plots that validate it. [reward plots omitted]
I ran into some issues when I used the fix-vllm-peft branch on a multi-GPU setup with [...].
Hi @Maghoumi, how are you running the script? I tried to set [...].
My code is very similar to this one, and I'm using [...].
I think the issue might be the config you're using; try substituting the one I proposed (attached in the original comment). Note that I'm assuming you're reserving 1 of your 4 GPUs for vLLM.
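A hedged illustration of that kind of GPU reservation (this assumes a TRL version where `GRPOConfig` still exposes `vllm_device` and `vllm_gpu_memory_utilization`, which may differ from the version used here):

```python
# On a 4-GPU machine: train with 3 processes and point vLLM at the spare GPU,
# e.g. launching with `accelerate launch --num_processes 3 ...`.
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-lora-vllm",
    use_vllm=True,
    vllm_device="cuda:3",             # the GPU reserved for generation
    vllm_gpu_memory_utilization=0.7,  # leave some headroom on that device
)
```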
Thanks @Maghoumi for reporting the error, and thanks @AndreiCComan for debugging the script! I further tried gradient checkpointing in my training script (different from the above config, using ZeRO-2; here's my launch script and my DeepSpeed config) and I did not encounter the error. @Maghoumi, would you kindly share the specific configuration that triggers the problem? I can try further if @AndreiCComan's suggestion on modifying the DS config does not fix the error.
@AndreiCComan Could you clarify something for me please? For the config you proposed: do you mean save these settings to a new config file, or do you mean modify my existing one? In any case, to avoid issues on my end, it might be easier if you just paste the entire YAML config content that you want me to try, so I can try with the exact config you have in mind.
Yes, exactly. Substitute the config altogether.
@AndreiCComan @XZ-X I tried the new substituted config above and I confirm training with PEFT is progressing without errors. I will let this run for a while to ensure it's actually learning. But at least I no longer see those errors I reported before.
@Maghoumi not the case for me. [reward plot omitted] Purple is with PEFT and vLLM enabled.
I've come across a potential issue and was wondering if anyone else has experienced it. When `use_peft: true` is set, the behavior differs depending on `use_vllm`. Here is a list of combinations I tried:

- `use_peft: true` and `use_vllm: false` then ✅
- `use_peft: true` and `use_vllm: true` then ❌
- `use_peft: false` and `use_vllm: true` then ✅
- `use_peft: false` and `use_vllm: false` then ✅

where ✅ means the reward increases during training and ❌ means it stays flat.

Any insights into what could cause this behavior?
Reference issue: #2725
UPDATE: I also added an MRE and plots below in the comments for clarification.