[GRPO] use_peft: true together with use_vllm: true not working as intended #2856

Closed
AndreiCComan opened this issue Feb 13, 2025 · 26 comments · Fixed by #2873
Labels
🐛 bug (Something isn't working) · ⚡ PEFT (Related to PEFT)

Comments

@AndreiCComan

AndreiCComan commented Feb 13, 2025

I’ve come across a potential issue and was wondering if anyone else has experienced it. When use_peft: true is set, the behavior differs depending on use_vllm. Here is a list of combinations I tried.

  • if use_peft: true and use_vllm: false then ✅
  • if use_peft: true and use_vllm: true then ❌
  • if use_peft: false and use_vllm: true then ✅
  • if use_peft: false and use_vllm: false then ✅

Where:

  • ✅ the reward curve displays the expected pattern.
  • ❌ the reward curve remains mostly flat.

Any insights into what could cause this behavior?


Reference issue: #2725


UPDATE: I also added an MRE and plots below in the comments for clarification.

@AndreiCComan AndreiCComan changed the title use_peft: true together with use_vllm: true not working as intended [GRPO] use_peft: true together with use_vllm: true not working as intended Feb 13, 2025
@qgallouedec
Member

Thanks, can you please share an MRE and your system info? (Check the issue template for guidance.)

@matt23654

matt23654 commented Feb 14, 2025

I am having the same issue: the reward is flat with vLLM + PEFT, while the same settings had the reward increasing with Unsloth.

This is with paged_adamw_8bit, gradient checkpointing, and a 32/32 LoRA (no QLoRA). Unsloth learns; accelerate + DeepSpeed ZeRO-3 + PEFT + vLLM does not. The model was the R1 1.5B distill.

@AndreiCComan
Author

AndreiCComan commented Feb 14, 2025

Hi @qgallouedec, thanks for your prompt reply and for the hard work! You are right, it's not super descriptive as an issue. Hopefully the details below will provide more insights. Let me know if you need any further info.

I am currently using:

  • trl==0.15.0
  • vllm==0.7.2
  • peft==0.14.0
  • accelerate==1.1.1
  • datasets==3.2.0
  • transformers==4.48.3

I am running everything on an 8xH100 machine.

Here is the accelerate.yaml configuration

num_machines: 1
num_processes: 7
mixed_precision: bf16
distributed_type: MULTI_GPU

Here is the mre.py script

#----------------------------------------------------------------------------------------------------
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, TrlParser, ModelConfig, get_peft_config
#----------------------------------------------------------------------------------------------------

def main(
    grpo_config,
    model_config
):
    def reward_len(
        completions,
        **kwargs
    ):
        return [-abs(20 - len(completion)) for completion in completions]

    trainer = GRPOTrainer(
        args=grpo_config,
        model=model_config.model_name_or_path,
        peft_config=get_peft_config(model_config),
        reward_funcs=reward_len,
        train_dataset=load_dataset(
            "trl-lib/tldr",
            split="train"
        ),
    )
    trainer.train()
#----------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    parser = TrlParser(
        (
            GRPOConfig,
            ModelConfig
        )
    )
    grpo_config, model_config = parser.parse_args_and_config()
    main(
        grpo_config=grpo_config,
        model_config=model_config
    )
#----------------------------------------------------------------------------------------------------

Here is the mre.yaml configuration

#----------------------------------------------------------------------------------------------------
# ModelConfig
use_peft: false
torch_dtype: bfloat16
load_in_4bit: true
model_name_or_path: Qwen/Qwen2-0.5B-Instruct
lora_target_modules: all-linear
attn_implementation: flash_attention_2
#----------------------------------------------------------------------------------------------------
# GRPOConfig
seed: 42
bf16: true

sync_ref_model: True
ref_model_sync_steps: 64

optim: adamw_torch
adam_beta1: 0.9
adam_beta2: 0.99
weight_decay: 0.1
max_grad_norm: 0.1

save_steps: 64
save_strategy: steps

use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9

warmup_ratio: 0.1
logging_steps: 1
learning_rate: 2e-5
lr_scheduler_type: cosine

num_generations: 7
num_train_epochs: 1

per_device_eval_batch_size: 1
per_device_train_batch_size: 16

gradient_accumulation_steps: 16

gradient_checkpointing: false
gradient_checkpointing_kwargs:
  use_reentrant: false

resume_from_checkpoint: false

run_name: MRE
output_dir: MRE
#----------------------------------------------------------------------------------------------------

And this is how I run it

accelerate launch --config_file accelerate.yaml mre.py --config mre.yaml

If you toggle use_peft between true and false in mre.yaml, you will notice the difference.
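
If your TRL version applies command-line overrides on top of the YAML (recent TrlParser releases do, but double-check yours), you can also toggle the flag without editing the file. A convenience sketch of the same run:

accelerate launch --config_file accelerate.yaml mre.py --config mre.yaml --use_peft true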

Below you can find the plots as well.

@AIR-hl
Contributor

AIR-hl commented Feb 14, 2025

(Quoting @AndreiCComan's MRE details from the comment above.)

Compared to training with full parameters, LoRA needs a bigger learning rate, for example changing 5e-7 to 1e-5. This is my experience; please see the change in the figures below.

[Plots: reward curves before and after increasing the learning rate]
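
For reference, the corresponding knobs in the mre.yaml above would look something like the following (a sketch only; it assumes TRL's ModelConfig exposes lora_r and lora_alpha, and the 32/32 values are just an example):

# mre.yaml (excerpt)
learning_rate: 1e-5
lora_r: 32
lora_alpha: 32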

@AndreiCComan
Author

AndreiCComan commented Feb 14, 2025

Compared to training with full parameters, LoRA needs a bigger learning rate, for example changing 5e-7 to 1e-5. This is my experience; please see the change in the figures below.

In my case I'm using 2e-5 which I guess is high enough?

UPDATE: I increased the learning rate from 2e-5 to 2e-4. The only case where it works with use_peft: true is when use_vllm: false, but not when use_vllm: true. I've added the plots below.

@AIR-hl
Contributor

AIR-hl commented Feb 14, 2025

In my case I'm using 2e-5 which I guess is high enough?

Another guess is that when the model is too small, LoRA might leave too few trainable parameters for convergence.
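
One quick way to check how many parameters a given LoRA setup actually trains is PEFT's print_trainable_parameters(). A minimal sketch (the model name and the 32/32 rank/alpha are just taken from the examples in this thread):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Build the same kind of adapter as the MRE and report the trainable-parameter count
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
lora_config = LoraConfig(r=32, lora_alpha=32, target_modules="all-linear", task_type="CAUSAL_LM")
peft_model = get_peft_model(base, lora_config)
peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts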

@matt23654

matt23654 commented Feb 14, 2025

In my case I'm using 2e-5 which I guess is high enough?

Another guess is that when the model is too small, LoRA might leave too few trainable parameters for convergence.

Are you using Unsloth in the screenshots where it works? I also find it works with Unsloth with a 32/32 LoRA and a 1e-5 LR. However, as I mentioned above, it is not working for me in the exact same experiment using TRL directly (via accelerate launch).

@AndreiCComan AndreiCComan changed the title [GRPO] use_peft: true together with use_vllm: true not working as intended [GRPO] use_peft: true not working as intended Feb 14, 2025
@AIR-hl
Contributor

AIR-hl commented Feb 14, 2025

In my case I'm using 2e-5 which I guess is high enough?

Another guess is that when the model is too small, LoRA might leave too few trainable parameters for convergence.

Are you using Unsloth in the screenshots where it works? I also find it works with Unsloth with a 32/32 LoRA and a 1e-5 LR. However, as I mentioned above, it is not working for me in the exact same experiment using TRL directly (via accelerate launch).

Yeah, I used Unsloth. I can't even start training with TRL alone on 2 GPUs (#2864).

@AndreiCComan
Author

AndreiCComan commented Feb 14, 2025

For completeness, here are the updated plots with the four combinations I mentioned in the issue description. The problem appears only with use_peft: true and use_vllm: true.

@AndreiCComan
Author

I have added a new experiment with use_peft: true and use_vllm: true and a higher learning rate (from 2e-5 to 2e-4), and it turns out that the 2e-4 curve is identical to the 2e-5 one.

@AndreiCComan AndreiCComan changed the title [GRPO] use_peft: true not working as intended [GRPO] use_peft: true together with use_vllm: true not working as intended Feb 15, 2025
@AndreiCComan
Author

Since in the previous plots the use_peft: true, use_vllm: false case with a 2e-5 learning rate seems to be working (just slower to learn), I've also increased the learning rate to 2e-4 and updated the plots.

CONCLUSION: The only combination that does not work is use_peft: true and use_vllm: true.

@matt23654

I've run a couple of experiments now as well, and I also concur that PEFT + vLLM is not working. I have tried with and without DeepSpeed; it doesn't seem to affect the results. From looking at the code, the obvious place to look is _move_model_to_vllm. Here is what I found in case others are trying to debug:

  • The list of keys sent to vLLM in the non-PEFT and PEFT cases seems to be the same.
  • vLLM's model has the weights grouped, e.g. self_attn.qkv_proj, whereas the state_dict splits them up separately, e.g. self_attn.q_proj. This is the same for the PEFT and non-PEFT cases though, so I assume load_weights is in fact combining them correctly.

It's very odd; it definitely looks like it should be working, but if the problem is not in that function, where else can it be?
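
One way to narrow it down (a hedged debugging sketch, not code from the thread; the layer key is illustrative, and it would have to run inside _move_model_to_vllm where unwrapped_model is the PEFT-wrapped model) is to snapshot a merged weight and check whether the state_dict still holds it by the time load_weights would read it:

import torch

key = "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"  # illustrative key
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
merged_snapshot = state_dict[key].clone()   # the value we expect vLLM to receive
unwrapped_model.unmerge_adapter()
still_merged = torch.allclose(state_dict[key], merged_snapshot)
print("state_dict still holds the merged weight after unmerge:", still_merged)

If this prints False, the weights handed to load_weights are no longer the merged ones.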

@AndreiCComan
Author

AndreiCComan commented Feb 15, 2025

Thanks @matt23654 for pointing at the right place! I've followed the example of @tgaddair in #2725 and wrote a quick, albeit somewhat hacky, fix. Here is the monkey patch for the MRE:

#----------------------------------------------------------------------------------------------------
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, TrlParser, ModelConfig, get_peft_config
#----------------------------------------------------------------------------------------------------
# PATCHING PART: override GRPOTrainer._move_model_to_vllm so the merged LoRA weights are what actually reaches vLLM
import copy
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from trl.models.utils import unwrap_model_for_generation
def _move_model_to_vllm_patched(
    self
):
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        if is_compiled_module(unwrapped_model):
            unwrapped_model = unwrapped_model._orig_mod
        if is_peft_model(unwrapped_model):
            # Work on a deep copy so merging does not touch the training model, then merge
            # the LoRA adapters into the base weights and take the state_dict from the copy
            state_dict = copy.deepcopy(unwrapped_model).merge_and_unload().state_dict()
            # unwrapped_model.merge_adapter()
            # state_dict = unwrapped_model.state_dict()
            # unwrapped_model.unmerge_adapter()
            # # Remove base_model and base_layer prefixes
            # state_dict = {
            #     k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
            # }
            # # Remove values with adapter prefix (example: "_lora")
            # state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
            # # When module to save, remove its prefix and discard the original module
            # state_dict = {
            #     k.replace("modules_to_save.default.", ""): v
            #     for k, v in state_dict.items()
            #     if "original_module" not in k
            # }
        else:
            state_dict = unwrapped_model.state_dict()
    if self.accelerator.is_main_process:
        llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
        llm_model.load_weights(state_dict.items())
GRPOTrainer._move_model_to_vllm = _move_model_to_vllm_patched
#----------------------------------------------------------------------------------------------------
def main(
    grpo_config,
    model_config
):
    def reward_len(
        completions,
        **kwargs
    ):
        return [-abs(20 - len(completion)) for completion in completions]

    trainer = GRPOTrainer(
        args=grpo_config,
        model=model_config.model_name_or_path,
        peft_config=get_peft_config(model_config),
        reward_funcs=reward_len,
        train_dataset=load_dataset(
            "trl-lib/tldr",
            split="train"
        ),
    )
    trainer.train()
#----------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    parser = TrlParser(
        (
            GRPOConfig,
            ModelConfig
        )
    )
    grpo_config, model_config = parser.parse_args_and_config()
    main(
        grpo_config=grpo_config,
        model_config=model_config
    )
#----------------------------------------------------------------------------------------------------

I'm not particularly fond of the copy.deepcopy part. I'll look for a more elegant solution when I have some time and possibly submit a PR.

As always, here are the plots that confirm that this is a viable patch.

@XZ-X
Contributor

XZ-X commented Feb 15, 2025

@AndreiCComan Thank you for the proposed fix! Would you kindly explain how the change fixes the issue? (What was the issue in the previous implementation?) Thanks!

@AndreiCComan
Author

AndreiCComan commented Feb 15, 2025

@AndreiCComan Thank you for the proposed fix! Would you kindly explain how the change fixes the issue? (What was the issue in the previous implementation?) Thanks!

In the line state_dict = copy.deepcopy(unwrapped_model).merge_and_unload().state_dict():

  • copy.deepcopy(unwrapped_model)
    • Ensures that the original unwrapped_model remains unchanged. This prevents modifications to the in-memory model, avoiding unintended side effects.
  • .merge_and_unload()
    • Merges the PEFT adapter weights into the base model. This step ensures that LoRA or other PEFT-based parameters are properly included in the extracted state_dict.
  • .state_dict()
    • Extracts the necessary weights after merging the adapters. These weights are then passed to llm_model.load_weights(state_dict.items()), ensuring that vLLM receives a correctly merged model.

The issue in the original implementation is in these lines:

unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()

When merge_adapter() and unmerge_adapter() are called, they modify in place the very tensors that the captured state_dict refers to, so the merged weights are reverted before they ever reach vLLM. Using copy.deepcopy() ensures the state dictionary is independent and unaffected by these changes. In fact, one could keep these lines and only deep-copy the state dict, as follows:

unwrapped_model.merge_adapter()
# state_dict = unwrapped_model.state_dict()
state_dict = copy.deepcopy(unwrapped_model.state_dict())
unwrapped_model.unmerge_adapter()

And this will work as expected :)

Now what I'm interested in is the following: Is there a way to avoid copy.deepcopy?
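
One possibility might be to push the weights into vLLM while the adapters are still merged and only unmerge afterwards, so no copy is needed at all. An untested sketch (the key cleanup mirrors the commented-out lines above; whether this interacts well with DeepSpeed ZeRO-3 gathering would need checking, and it is not necessarily what the eventual fix does):

def _move_model_to_vllm_no_copy(self):
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        if is_compiled_module(unwrapped_model):
            unwrapped_model = unwrapped_model._orig_mod
        if is_peft_model(unwrapped_model):
            unwrapped_model.merge_adapter()
            state_dict = unwrapped_model.state_dict()
            # Remove base_model and base_layer prefixes so the keys match what vLLM expects
            state_dict = {
                k.removeprefix("base_model.model.").replace(".base_layer", ""): v
                for k, v in state_dict.items()
            }
            # Drop adapter-only tensors (e.g. "lora_") and handle modules_to_save
            state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
            state_dict = {
                k.replace("modules_to_save.default.", ""): v
                for k, v in state_dict.items()
                if "original_module" not in k
            }
        else:
            state_dict = unwrapped_model.state_dict()
        # Load into vLLM while the adapters are still merged, then undo the merge
        if self.accelerator.is_main_process:
            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            llm_model.load_weights(state_dict.items())
        if is_peft_model(unwrapped_model):
            unwrapped_model.unmerge_adapter()
# GRPOTrainer._move_model_to_vllm = _move_model_to_vllm_no_copy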

@XZ-X
Contributor

XZ-X commented Feb 15, 2025

Thanks for the explanation. I dumped the state dicts obtained from both versions, and it validated your thoughts.
The weights base_model.model.model.<layer-name>.base_layer.weight are almost the same as the original model.<layer-name>.weight (with a minor difference of less than 1e-7).

It seems that the PEFT model's state_dict is a shallow copy of the model's weights.
So if we unmerge the model right after reading out the state dict (as in the original implementation), it essentially just reads out the base model's weights.
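
A minimal PyTorch sketch, independent of TRL/PEFT, illustrating the same sharing: the tensors returned by state_dict() share storage with the live parameters, so in-place changes made afterwards show up in the captured dict as well.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4, bias=False)
sd = layer.state_dict()                     # tensors here share storage with layer.weight
snapshot = sd["weight"].clone()
layer.weight.data.add_(1.0)                 # in-place update, analogous to merge/unmerge
print(torch.equal(sd["weight"], snapshot))  # False: the captured dict changed too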

@AndreiCComan
Author

Thanks for the explanation. I dumped the state dicts obtained from both versions, and it validated your thoughts. The weights base_model.model.model.<layer-name>.base_layer.weight are almost the same as the original model.<layer-name>.weight (with a minor difference of less than 1e-7).

It seems that the PEFT model's state_dict is a shallow copy of the model's weights. So if we unmerge the model right after reading out the state dict (as in the original implementation), it essentially just reads out the base model's weights.

Thanks @XZ-X for the PR! You anticipated me :) Well done!

I've tested your fix-vllm-peft branch and I confirm it's working as intended! Here are the plots that validate it:

@Maghoumi

I ran into some issues when I used the fix-vllm-peft branch on a multi-GPU setup with gradient_checkpointing=True. Sorry, I don't have the full error log handy right now, but it was something along the lines of the gradients being different in the forward pass compared to the recomputed values.

@AndreiCComan
Author

I ran into some issues when I used the fix-vllm-peft branch on a multi-GPU setup with gradient_checkpointing=True. Sorry, I don't have the full error log handy right now, but it was something along the lines of the gradients being different in the forward pass compared to the recomputed values.

Hi @Maghoumi, how are you running the script? I tried setting gradient_checkpointing: true in the mre.yaml configuration above, and it seems to be working on my side.

@Maghoumi

My code is very similar to this one and I'm using accelerate launch to run with 4 GPUs using LLaMa 3.1 8B.

@AndreiCComan
Author

My code is very similar to this one and I'm using accelerate launch to run with 4 GPUs using LLaMa 3.1 8B.

I think the issue might be the deepspeed_zero3.yaml configuration of accelerate. Can you try using the following configuration instead?

num_machines: 1
num_processes: 3 # (4 - 1 GPU for vllm)
mixed_precision: bf16
distributed_type: MULTI_GPU

Note that I'm assuming you're reserving 1 of your 4 GPUs for vLLM.

@XZ-X
Contributor

XZ-X commented Feb 16, 2025

Thanks @Maghoumi for reporting the error, and thanks @AndreiCComan for debugging the script!

I further tried gradient checkpointing in my training script (different from the config above, using ZeRO-2; here's my launch script and my DeepSpeed config) and I did not encounter the error.

@Maghoumi Would you kindly share the specific configuration that triggers the problem? I can dig further if @AndreiCComan's suggestion to change the accelerate config does not fix the error.

@Maghoumi

Maghoumi commented Feb 16, 2025

@AndreiCComan Could you clarify something for me please? For the config you proposed:

num_machines: 1
num_processes: 3 # (4 - 1 GPU for vllm)
mixed_precision: bf16
distributed_type: MULTI_GPU

You mean save these to andrei.yaml then launch with accelerate launch --config andrei.yaml instead of deepspeed_zero3.yaml?

Or do you mean modify deepspeed_zero3.yaml and overwrite only what matches with your suggestion and still launch with accelerate launch --config deepspeed_zero3.yaml?

In any case, to avoid issues on my end, it might be easier if you just paste the entire yaml config content that you want me to try, so I can try with the exact config you have in mind.

@AndreiCComan
Author

AndreiCComan commented Feb 16, 2025

You mean save these to andrei.yaml then launch with accelerate launch --config andrei.yaml instead of deepspeed_zero3.yaml?

Yes, exactly. Substitute the config altogether.
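
For clarity, the whole file would contain just those four lines (saved, e.g., as andrei.yaml next to mre.py and mre.yaml), and the launch command keeps the same shape as in the MRE:

num_machines: 1
num_processes: 3 # (4 - 1 GPU reserved for vLLM)
mixed_precision: bf16
distributed_type: MULTI_GPU

accelerate launch --config_file andrei.yaml mre.py --config mre.yaml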

@Maghoumi

@AndreiCComan @XZ-X I tried the substituted config above, and I can confirm that training with PEFT is progressing without errors. I will let this run for a while to ensure it's actually learning, but at least I no longer see the errors I reported before. Hope this helps!

@zaddy6

zaddy6 commented Feb 17, 2025

@Maghoumi not the case for me

[Plot: reward curve comparison]

Purple is the run with PEFT and vLLM enabled.
