fix grpo vllm lora (modelscope#3134)
Jintao-Huang authored Feb 17, 2025
1 parent 1ccbea8 commit 3a41cca
Showing 17 changed files with 136 additions and 163 deletions.
8 changes: 5 additions & 3 deletions docs/source/Instruction/GRPO.md
@@ -17,7 +17,7 @@ pip install "trl>=0.15"
The reward function takes the model-generated text `completions`, together with other columns from the dataset, as arguments and scores the generated text. The following example shows how to implement a simple length-based reward function: it gives a reward signal of 1.0 when the generated text is longer than 1024 characters, and 0.0 otherwise.

```python
from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms
class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]
@@ -134,7 +134,8 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
```

Single GPU
@@ -167,5 +168,6 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
8 changes: 5 additions & 3 deletions docs/source_en/Instruction/GRPO.md
@@ -19,7 +19,7 @@ pip install "trl>=0.15"
A reward function takes the text `completions` generated by a model and other columns from the dataset as parameters, and scores the model's generated text. Below is an example that demonstrates how to implement a simple length-based reward function. This function will give a reward signal of 1.0 if the length of the generated text exceeds 1024; otherwise, the reward signal will be 0.0.

```python
from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms

class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
@@ -138,7 +138,8 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
```

Single-GPU
@@ -171,5 +172,6 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
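
For reference, a custom reward function like the one above is typically made selectable by adding it to the `orms` registry in a plugin file (the same pattern used by `examples/train/grpo/plugin/plugin.py` further down in this commit). A minimal sketch, assuming `orms` behaves as a dict-like name-to-class registry and that the key used here (`dummy_length`, illustrative) can then be referenced via `--reward_funcs`:

```python
from swift.plugin import ORM, orms


class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
        # 1.0 reward once the completion exceeds 1024 characters, else 0.0
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]


# Assumption: registering under a name exposes it to `--reward_funcs dummy_length`.
orms['dummy_length'] = DummyLengthRewardFunction
```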
3 changes: 2 additions & 1 deletion examples/train/grpo/full_lmdeploy.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 3 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero3
--deepspeed zero3 \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/full_vllm.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
98 changes: 0 additions & 98 deletions examples/train/grpo/grpo.py

This file was deleted.

5 changes: 3 additions & 2 deletions examples/train/grpo/grpo.sh
@@ -5,7 +5,7 @@
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B-Instruct \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format \
--train_type lora \
--lora_rank 8 \
@@ -30,4 +30,5 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
40 changes: 40 additions & 0 deletions examples/train/grpo/lora_vllm.sh
@@ -0,0 +1,40 @@
# pip install math_verify # reward function
# pip install "trl>=0.15"
# GPU memory: 2 * 80GiB

MASTER_PORT=29501 \
CUDA_VISIBLE_DEVICES=0,1 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format \
--train_type lora \
--use_vllm true \
--vllm_device auto \
--vllm_gpu_memory_utilization 0.7 \
--vllm_max_model_len 8192 \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#1000' \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 16 \
--temperature 0.9 \
--deepspeed zero2 \
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/multi_node/multi_node1.sh
@@ -36,4 +36,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
3 changes: 2 additions & 1 deletion examples/train/grpo/multi_node/multi_node2.sh
@@ -34,4 +34,5 @@ swift rlhf \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2
--deepspeed zero2 \
--log_completions true
2 changes: 1 addition & 1 deletion examples/train/grpo/plugin/plugin.py
@@ -1,7 +1,7 @@
import re
from typing import List

from swift.plugin.orm import ORM, orms
from swift.plugin import ORM, orms
from swift.utils import get_logger

logger = get_logger()
3 changes: 2 additions & 1 deletion examples/train/grpo/plugin/run_external_rm.sh
@@ -31,4 +31,5 @@ swift rlhf \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt'
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -6,7 +6,7 @@
if TYPE_CHECKING:
# Recommend using `xxx_main`
from .infer import (VllmEngine, RequestConfig, LmdeployEngine, PtEngine, InferEngine, infer_main, deploy_main,
InferClient, run_deploy, AdapterRequest, prepare_model_template)
InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine)
from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
from .eval import eval_main
from .app import app_main
@@ -35,7 +35,7 @@
'rlhf': ['rlhf_main'],
'infer': [
'deploy_main', 'VllmEngine', 'RequestConfig', 'LmdeployEngine', 'PtEngine', 'infer_main', 'InferClient',
'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template'
'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine'
],
'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
'app': ['app_main'],
4 changes: 2 additions & 2 deletions swift/llm/infer/__init__.py
@@ -9,7 +9,7 @@
from .protocol import RequestConfig
from .utils import prepare_model_template
from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferClient,
prepare_generation_config, AdapterRequest)
prepare_generation_config, AdapterRequest, BaseInferEngine)
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
@@ -19,7 +19,7 @@
'utils': ['prepare_model_template'],
'infer_engine': [
'InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferClient', 'prepare_generation_config',
'AdapterRequest'
'AdapterRequest', 'BaseInferEngine'
],
}

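
With `BaseInferEngine` now re-exported both here and from `swift.llm` above, engine-agnostic code can import it from the top-level package. A minimal usage sketch, assuming the concrete engines (`PtEngine`, `VllmEngine`, ...) subclass this base (an assumption, not verified against the source):

```python
from swift.llm import BaseInferEngine, PtEngine


# Assumption: PtEngine derives from BaseInferEngine, so the base class can serve
# as a type hint for code that should accept any inference backend.
def engine_name(engine: BaseInferEngine) -> str:
    return type(engine).__name__


print(issubclass(PtEngine, BaseInferEngine))  # expected: True under the assumption above
```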
3 changes: 1 addition & 2 deletions swift/llm/sampling/base.py
@@ -1,8 +1,7 @@
from typing import Any, Dict, List

from swift.llm import SamplingArguments
from swift.plugin.orm import orms
from swift.plugin.prm import prms
from swift.plugin import orms, prms
from swift.utils import get_logger

logger = get_logger()
4 changes: 4 additions & 0 deletions swift/plugin/__init__.py
@@ -11,6 +11,8 @@
from .optimizer import optimizers_map
from .tools import get_tools_prompt, get_tools_keyword
from .tuner import Tuner, extra_tuners
from .prm import prms, PRM
from .orm import orms, ORM

else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
@@ -22,6 +24,8 @@
'optimizer': ['optimizers_map'],
'tools': ['get_tools_prompt', 'get_tools_keyword'],
'tuner': ['Tuner', 'extra_tuners'],
'prm': ['prms', 'PRM'],
'orm': ['orms', 'ORM']
}

import sys
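
The `prm`/`orm` entries are added in two places because of the lazy-import layout: the `TYPE_CHECKING` block serves static type checkers, while `_import_structure` tells the lazy module loader which submodule to import when a name such as `ORM` or `orms` is first accessed, so `from swift.plugin import ORM, orms` works at runtime. An illustrative, simplified sketch of that pattern (not swift's actual loader):

```python
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static analyzers see the real symbols.
    from .orm import ORM, orms
    from .prm import PRM, prms

# Name -> submodule map, consulted only when an attribute is first accessed.
_import_structure = {
    'orm': ['ORM', 'orms'],
    'prm': ['PRM', 'prms'],
}
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):  # PEP 562 module-level __getattr__
    if name in _name_to_module:
        module = importlib.import_module(f'.{_name_to_module[name]}', __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
```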
3 changes: 0 additions & 3 deletions swift/trainers/rlhf_arguments.py
@@ -49,9 +49,6 @@ class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
repetition_penalty: Optional[float] = None

def __post_init__(self):
if self.use_lmdeploy:
# In case trl GRPOTrainer need use_vllm
self.use_vllm = True
super().__post_init__()
if self.cosine_max_len is None:
self.cosine_max_len = self.max_completion_length