fix grpo vllm lora #3134

Merged (13 commits) on Feb 17, 2025
8 changes: 5 additions & 3 deletions docs/source/Instruction/GRPO.md
@@ -17,7 +17,7 @@ pip install "trl>=0.15"
A reward function takes the model-generated text `completions`, together with other columns from the dataset, as arguments and scores the generated text. Below is an example showing how to implement a simple length-based reward function: it gives a reward signal of 1.0 when the generated text is longer than 1024 characters, and 0.0 otherwise.

```python
from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms
class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]
@@ -134,7 +134,8 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
```

Single-GPU
@@ -167,5 +168,6 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
8 changes: 5 additions & 3 deletions docs/source_en/Instruction/GRPO.md
@@ -19,7 +19,7 @@ pip install "trl>=0.15"
A reward function takes the text `completions` generated by a model and other columns from the dataset as parameters, and scores the model's generated text. Below is an example that demonstrates how to implement a simple length-based reward function. This function will give a reward signal of 1.0 if the length of the generated text exceeds 1024; otherwise, the reward signal will be 0.0.

```python
from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms

class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
@@ -138,7 +138,8 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
```

Single-GPU
@@ -171,5 +172,6 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
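
As a note on the `DummyLengthRewardFunction` example above: to be selectable from the command line, the reward function also has to be registered. Below is a minimal sketch, assuming `orms` is a plain dict registry keyed by reward name; the key `dummy` and the `--reward_funcs dummy` wiring are illustrative, not part of this diff.

```python
from swift.plugin import ORM, orms


class DummyLengthRewardFunction(ORM):

    def __call__(self, completions, **kwargs):
        # 1.0 for completions longer than 1024 characters, 0.0 otherwise
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]


# register under a name that could then be passed as `--reward_funcs dummy`
orms['dummy'] = DummyLengthRewardFunction
```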
3 changes: 2 additions & 1 deletion examples/train/grpo/full_lmdeploy.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 3 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero3
--deepspeed zero3 \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/full_vllm.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
98 changes: 0 additions & 98 deletions examples/train/grpo/grpo.py

This file was deleted.

5 changes: 3 additions & 2 deletions examples/train/grpo/grpo.sh
@@ -5,7 +5,7 @@
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B-Instruct \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format \
--train_type lora \
--lora_rank 8 \
@@ -30,4 +30,5 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
40 changes: 40 additions & 0 deletions examples/train/grpo/lora_vllm.sh
@@ -0,0 +1,40 @@
# pip install math_verify # reward function
# pip install "trl>=0.15"
# GPU memory: 2 * 80GiB

MASTER_PORT=29501 \
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format \
--train_type lora \
--use_vllm true \
--vllm_device auto \
--vllm_gpu_memory_utilization 0.7 \
--vllm_max_model_len 8192 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#1000' \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 16 \
--temperature 0.9 \
--deepspeed zero2 \
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/multi_node/multi_node1.sh
@@ -36,4 +36,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/multi_node/multi_node2.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
2 changes: 1 addition & 1 deletion examples/train/grpo/plugin/plugin.py
@@ -1,7 +1,7 @@
import re
from typing import List

from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms
from swift.utils import get_logger

logger = get_logger()
3 changes: 2 additions & 1 deletion examples/train/grpo/plugin/run_external_rm.sh
@@ -31,4 +31,5 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -6,7 +6,7 @@
if TYPE_CHECKING:
# Recommend using `xxx_main`
from .infer import (VllmEngine, RequestConfig, LmdeployEngine, PtEngine, InferEngine, infer_main, deploy_main,
InferClient, run_deploy, AdapterRequest, prepare_model_template)
InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine)
from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
from .eval import eval_main
from .app import app_main
@@ -35,7 +35,7 @@
'rlhf': ['rlhf_main'],
'infer': [
'deploy_main', 'VllmEngine', 'RequestConfig', 'LmdeployEngine', 'PtEngine', 'infer_main', 'InferClient',
'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template'
'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine'
],
'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
'app': ['app_main'],
4 changes: 2 additions & 2 deletions swift/llm/infer/__init__.py
@@ -9,7 +9,7 @@
from .protocol import RequestConfig
from .utils import prepare_model_template
from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferClient,
prepare_generation_config, AdapterRequest)
prepare_generation_config, AdapterRequest, BaseInferEngine)
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
@@ -19,7 +19,7 @@
'utils': ['prepare_model_template'],
'infer_engine': [
'InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferClient', 'prepare_generation_config',
'AdapterRequest'
'AdapterRequest', 'BaseInferEngine'
],
}

3 changes: 1 addition & 2 deletions swift/llm/sampling/base.py
@@ -1,8 +1,7 @@
from typing import Any, Dict, List

from swift.llm import SamplingArguments
from swift.plugin.orm import orms
from swift.plugin.prm import prms
from swift.plugin import orms, prms
from swift.utils import get_logger

logger = get_logger()
4 changes: 4 additions & 0 deletions swift/plugin/__init__.py
@@ -11,6 +11,8 @@
from .optimizer import optimizers_map
from .tools import get_tools_prompt, get_tools_keyword
from .tuner import Tuner, extra_tuners
from .prm import prms, PRM
from .orm import orms, ORM

else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
@@ -22,6 +24,8 @@
'optimizer': ['optimizers_map'],
'tools': ['get_tools_prompt', 'get_tools_keyword'],
'tuner': ['Tuner', 'extra_tuners'],
'prm': ['prms', 'PRM'],
'orm': ['orms', 'ORM']
}

import sys
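
With the re-exports added above, `ORM`/`orms` and `PRM`/`prms` become importable from the package root instead of the `swift.plugin.orm` and `swift.plugin.prm` submodules, which is what the updated call sites in this PR rely on. A quick hedged sanity check (registry contents will vary by version; this assumes both registries are plain name-to-class mappings):

```python
from swift.plugin import ORM, PRM, orms, prms

# list the registered outcome and process reward plugins
print(sorted(orms))
print(sorted(prms))
```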
3 changes: 0 additions & 3 deletions swift/trainers/rlhf_arguments.py
@@ -49,9 +49,6 @@ class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
repetition_penalty: Optional[float] = None

def __post_init__(self):
if self.use_lmdeploy:
# In case trl GRPOTrainer need use_vllm
self.use_vllm = True
super().__post_init__()
if self.cosine_max_len is None:
self.cosine_max_len = self.max_completion_length