From 87270f9f24fa51d544000ab51c32951977a05935 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 17 Feb 2025 20:08:36 +0800 Subject: [PATCH 1/2] update --- docs/source/Instruction/GRPO.md | 8 ++-- docs/source_en/Instruction/GRPO.md | 8 ++-- examples/train/grpo/full_lmdeploy.sh | 3 +- examples/train/grpo/full_vllm.sh | 3 +- examples/train/grpo/grpo.sh | 5 ++- examples/train/grpo/lora_vllm.sh | 40 +++++++++++++++++++ examples/train/grpo/multi_node/multi_node1.sh | 3 +- examples/train/grpo/multi_node/multi_node2.sh | 3 +- examples/train/grpo/plugin/plugin.py | 2 +- examples/train/grpo/plugin/run_external_rm.sh | 3 +- swift/llm/__init__.py | 4 +- swift/llm/infer/__init__.py | 4 +- swift/llm/sampling/base.py | 3 +- swift/plugin/__init__.py | 4 ++ swift/trainers/rlhf_arguments.py | 3 -- 15 files changed, 73 insertions(+), 23 deletions(-) create mode 100644 examples/train/grpo/lora_vllm.sh diff --git a/docs/source/Instruction/GRPO.md b/docs/source/Instruction/GRPO.md index ee5fc49c7a..15f306f1da 100644 --- a/docs/source/Instruction/GRPO.md +++ b/docs/source/Instruction/GRPO.md @@ -17,7 +17,7 @@ pip install "trl>=0.15" 奖励函数接受模型生成的文本 completions 以及其他数据集中的列作为参数,并对模型生成的文本进行打分。以下是一个示例,展示了如何实现一个简单的长度奖励函数。该函数会在模型生成的文本长度超过 1024 时,给予 1.0 的奖励信号;否则,奖励信号为 0.0。 ```python -from swift.plugin.orm import ORM, orms +from swift.plugin import ORM, orms class DummyLengthRewardFunction(ORM) def __call__(completions, **kwargs): return [1.0 if len(completion) > 1024 else 0.0 for completion in completions] @@ -134,7 +134,8 @@ swift rlhf \ --num_generations 7 \ --temperature 0.9 \ --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero2 + --deepspeed zero2 \ + --log_completions true ``` 单卡 @@ -167,5 +168,6 @@ swift rlhf \ --dataset_num_proc 4 \ --num_generations 4 \ --temperature 0.9 \ - --system 'examples/train/grpo/prompt.txt' + --system 'examples/train/grpo/prompt.txt' \ + --log_completions true ``` diff --git a/docs/source_en/Instruction/GRPO.md b/docs/source_en/Instruction/GRPO.md index 84b08bf67c..7d854ee62d 100644 --- a/docs/source_en/Instruction/GRPO.md +++ b/docs/source_en/Instruction/GRPO.md @@ -19,7 +19,7 @@ pip install "trl>=0.15" A reward function takes the text `completions` generated by a model and other columns from the dataset as parameters, and scores the model's generated text. Below is an example that demonstrates how to implement a simple length-based reward function. This function will give a reward signal of 1.0 if the length of the generated text exceeds 1024; otherwise, the reward signal will be 0.0. 
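For reference alongside the diff hunk below, a consolidated, runnable version of this length-based reward function might look like the following sketch. It relies on the `ORM` base class and `orms` registry that this patch re-exports from `swift.plugin`; the registry key `'dummy_length'` is only an illustrative name, and selecting a custom key through `--reward_funcs` is an assumption about usage rather than something this diff shows.

```python
from typing import List

from swift.plugin import ORM, orms


class DummyLengthRewardFunction(ORM):

    def __call__(self, completions: List[str], **kwargs) -> List[float]:
        # Reward 1.0 when the generated text exceeds 1024 characters, else 0.0.
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]


# Register the class under a name so it can be referenced from the CLI,
# e.g. `--reward_funcs dummy_length` (assumed usage; 'dummy_length' is illustrative).
orms['dummy_length'] = DummyLengthRewardFunction
```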
```python -from swift.plugin.orm import ORM, orms +from swift.plugin import ORM, orms class DummyLengthRewardFunction(ORM): def __call__(self, completions, **kwargs): @@ -138,7 +138,8 @@ swift rlhf \ --num_generations 7 \ --temperature 0.9 \ --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero2 + --deepspeed zero2 \ + --log_completions true ``` Single-GPU @@ -171,5 +172,6 @@ swift rlhf \ --dataset_num_proc 4 \ --num_generations 4 \ --temperature 0.9 \ - --system 'examples/train/grpo/prompt.txt' + --system 'examples/train/grpo/prompt.txt' \ + --log_completions true ``` diff --git a/examples/train/grpo/full_lmdeploy.sh b/examples/train/grpo/full_lmdeploy.sh index c3f0993966..d06d1a0a1a 100644 --- a/examples/train/grpo/full_lmdeploy.sh +++ b/examples/train/grpo/full_lmdeploy.sh @@ -34,4 +34,5 @@ swift rlhf \ --num_generations 3 \ --temperature 0.9 \ --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero3 + --deepspeed zero3 \ + --log_completions true diff --git a/examples/train/grpo/full_vllm.sh b/examples/train/grpo/full_vllm.sh index 292e1cb4d2..6a96641ce3 100644 --- a/examples/train/grpo/full_vllm.sh +++ b/examples/train/grpo/full_vllm.sh @@ -34,4 +34,5 @@ swift rlhf \ --num_generations 7 \ --temperature 0.9 \ --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero2 + --deepspeed zero2 \ + --log_completions true diff --git a/examples/train/grpo/grpo.sh b/examples/train/grpo/grpo.sh index 1f39720d1d..952d3b59da 100644 --- a/examples/train/grpo/grpo.sh +++ b/examples/train/grpo/grpo.sh @@ -5,7 +5,7 @@ CUDA_VISIBLE_DEVICES=0 \ swift rlhf \ --rlhf_type grpo \ - --model Qwen/Qwen2.5-7B-Instruct \ + --model Qwen/Qwen2.5-7B \ --reward_funcs accuracy format \ --train_type lora \ --lora_rank 8 \ @@ -30,4 +30,5 @@ swift rlhf \ --dataset_num_proc 4 \ --num_generations 4 \ --temperature 0.9 \ - --system 'examples/train/grpo/prompt.txt' + --system 'examples/train/grpo/prompt.txt' \ + --log_completions true diff --git a/examples/train/grpo/lora_vllm.sh b/examples/train/grpo/lora_vllm.sh new file mode 100644 index 0000000000..854dc21b29 --- /dev/null +++ b/examples/train/grpo/lora_vllm.sh @@ -0,0 +1,40 @@ +# pip install math_verify # reward function +# pip install "trl>=0.15" +# GPU memory: 2 * 80GiB + +MASTER_PORT=29501 \ +CUDA_VISIBLE_DEVICES=0,1 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-7B \ + --reward_funcs accuracy format \ + --train_type lora \ + --use_vllm true \ + --vllm_device auto \ + --vllm_gpu_memory_utilization 0.7 \ + --vllm_max_model_len 8192 \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --torch_dtype bfloat16 \ + --dataset 'AI-MO/NuminaMath-TIR#1000' \ + --max_completion_length 1024 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 16 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 1 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 5 \ + --max_length 2048 \ + --output_dir output \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --dataset_num_proc 4 \ + --num_generations 16 \ + --temperature 0.9 \ + --deepspeed zero2 \ + --system 'examples/train/grpo/prompt.txt' \ + --log_completions true diff --git a/examples/train/grpo/multi_node/multi_node1.sh b/examples/train/grpo/multi_node/multi_node1.sh index 22d2772b6b..f4526157bc 100755 --- a/examples/train/grpo/multi_node/multi_node1.sh +++ b/examples/train/grpo/multi_node/multi_node1.sh @@ -36,4 +36,5 @@ swift rlhf \ --num_generations 7 \ --temperature 0.9 \ --system 
'examples/train/grpo/prompt.txt' \ - --deepspeed zero2 + --deepspeed zero2 \ + --log_completions true diff --git a/examples/train/grpo/multi_node/multi_node2.sh b/examples/train/grpo/multi_node/multi_node2.sh index 8909286909..519e5489f6 100755 --- a/examples/train/grpo/multi_node/multi_node2.sh +++ b/examples/train/grpo/multi_node/multi_node2.sh @@ -34,4 +34,5 @@ swift rlhf \ --num_generations 7 \ --temperature 0.9 \ --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero2 + --deepspeed zero2 \ + --log_completions true diff --git a/examples/train/grpo/plugin/plugin.py b/examples/train/grpo/plugin/plugin.py index a478742602..c1d7a9e870 100644 --- a/examples/train/grpo/plugin/plugin.py +++ b/examples/train/grpo/plugin/plugin.py @@ -1,7 +1,7 @@ import re from typing import List -from swift.plugin.orm import ORM, orms +from swift.plugin import ORM, orms from swift.utils import get_logger logger = get_logger() diff --git a/examples/train/grpo/plugin/run_external_rm.sh b/examples/train/grpo/plugin/run_external_rm.sh index 14bb16c872..c0d232b666 100644 --- a/examples/train/grpo/plugin/run_external_rm.sh +++ b/examples/train/grpo/plugin/run_external_rm.sh @@ -31,4 +31,5 @@ swift rlhf \ --dataset_num_proc 4 \ --num_generations 4 \ --temperature 0.9 \ - --system 'examples/train/grpo/prompt.txt' + --system 'examples/train/grpo/prompt.txt' \ + --log_completions true diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py index b0d10dafc5..d3818631f3 100644 --- a/swift/llm/__init__.py +++ b/swift/llm/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: # Recommend using `xxx_main` from .infer import (VllmEngine, RequestConfig, LmdeployEngine, PtEngine, InferEngine, infer_main, deploy_main, - InferClient, run_deploy, AdapterRequest, prepare_model_template) + InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine) from .export import (export_main, merge_lora, quantize_model, export_to_ollama) from .eval import eval_main from .app import app_main @@ -35,7 +35,7 @@ 'rlhf': ['rlhf_main'], 'infer': [ 'deploy_main', 'VllmEngine', 'RequestConfig', 'LmdeployEngine', 'PtEngine', 'infer_main', 'InferClient', - 'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template' + 'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine' ], 'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'], 'app': ['app_main'], diff --git a/swift/llm/infer/__init__.py b/swift/llm/infer/__init__.py index a89bf85d61..b29ef951ed 100644 --- a/swift/llm/infer/__init__.py +++ b/swift/llm/infer/__init__.py @@ -9,7 +9,7 @@ from .protocol import RequestConfig from .utils import prepare_model_template from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferClient, - prepare_generation_config, AdapterRequest) + prepare_generation_config, AdapterRequest, BaseInferEngine) else: _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} _import_structure = { @@ -19,7 +19,7 @@ 'utils': ['prepare_model_template'], 'infer_engine': [ 'InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferClient', 'prepare_generation_config', - 'AdapterRequest' + 'AdapterRequest', 'BaseInferEngine' ], } diff --git a/swift/llm/sampling/base.py b/swift/llm/sampling/base.py index ca3c818d6f..b5967e1234 100644 --- a/swift/llm/sampling/base.py +++ b/swift/llm/sampling/base.py @@ -1,8 +1,7 @@ from typing import Any, Dict, List from swift.llm import SamplingArguments -from swift.plugin.orm import orms -from swift.plugin.prm 
import prms +from swift.plugin import orms, prms from swift.utils import get_logger logger = get_logger() diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 1cdb3a785f..f5d0612645 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -11,6 +11,8 @@ from .optimizer import optimizers_map from .tools import get_tools_prompt, get_tools_keyword from .tuner import Tuner, extra_tuners + from .prm import prms, PRM + from .orm import orms, ORM else: _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} @@ -22,6 +24,8 @@ 'optimizer': ['optimizers_map'], 'tools': ['get_tools_prompt', 'get_tools_keyword'], 'tuner': ['Tuner', 'extra_tuners'], + 'prm': ['prms', 'PRM'], + 'orm': ['orms', 'ORM'] } import sys diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py index 160437a571..603d1f91d8 100644 --- a/swift/trainers/rlhf_arguments.py +++ b/swift/trainers/rlhf_arguments.py @@ -49,9 +49,6 @@ class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig): repetition_penalty: Optional[float] = None def __post_init__(self): - if self.use_lmdeploy: - # In case trl GRPOTrainer need use_vllm - self.use_vllm = True super().__post_init__() if self.cosine_max_len is None: self.cosine_max_len = self.max_completion_length From 89c39d5c3d99a7dae1f9a595f070c44f81f2098c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 17 Feb 2025 20:25:26 +0800 Subject: [PATCH 2/2] update --- swift/trainers/rlhf_trainer/grpo_trainer.py | 98 ++++++++++++--------- 1 file changed, 58 insertions(+), 40 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index a1921cc99e..b71bce0269 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -2,21 +2,21 @@ # Part of the implementation is borrowed from huggingface/trl. import inspect import os -from collections import defaultdict, namedtuple +from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Union from unittest.mock import patch import torch import torch.nn as nn -from accelerate.utils import broadcast_object_list, gather, gather_object +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model from transformers import PreTrainedModel from transformers.utils.versions import require_version from trl import GRPOTrainer as HFGRPOTrainer from trl.models import unwrap_model_for_generation from swift.llm import InferRequest, RequestConfig, to_device -from swift.plugin.orm import orms +from swift.plugin import orms from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, is_lmdeploy_available, is_vllm_available, is_wandb_available) from ..mixin import SwiftMixin @@ -125,7 +125,7 @@ def __init__(self, logger.warning( f'The requested device {fast_infer_device} is also used for training. ' f'This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.') - if use_vllm and not use_lmdeploy: + if use_vllm: if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. 
' 'Please install vLLM with `pip install vllm` to use it.') @@ -147,8 +147,6 @@ def __init__(self, enforce_eager=args.vllm_enforce_eager, limit_mm_per_prompt=args.vllm_limit_mm_per_prompt, max_model_len=args.vllm_max_model_len) - # compat _move_model_to_vllm - self.llm = namedtuple('LLM', ['llm_engine'])(self.engine.engine.engine) self.engine.default_template = self.template elif use_lmdeploy: # https://github.com/tastelikefeet/lmdeploy.git@feat/reload_state_dict @@ -173,24 +171,6 @@ def __init__(self, device=[fast_infer_device], session_len=args.lmdeploy_session_len, cache_max_entry_count=args.lmdeploy_cache_max_entry_count) - # compat _move_model_to_vllm - import collections - import collections.abc - for type_name in collections.abc.__all__: - # AttrDict may throw not modules `Mapping` error - setattr(collections, type_name, getattr(collections.abc, type_name)) - from attrdict import AttrDict - self.llm = AttrDict({ - 'llm_engine': { - 'model_executor': { - 'driver_worker': { - 'model_runner': { - 'model': self.engine.engine.engine - } - } - } - } - }) self.engine.default_template = self.template self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation @@ -231,34 +211,72 @@ def _template_context(template): template.set_mode(mode) template.max_length = max_length + def _move_model_to_vllm_lmdeploy(self, unwrapped_model): + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model: + if is_peft_model(unwrapped_model): + state_dict = unwrapped_model.state_dict() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix('base_model.model.').replace('.base_layer', ''): v + for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace('modules_to_save.default.', ''): v + for k, v in state_dict.items() if 'original_module' not in k + } + else: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + if self.args.use_vllm: + llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model + else: + llm_model = self.engine.engine.engine + llm_model.load_weights(state_dict.items()) + def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device # Generate completions using either vLLM or regular generation - if self.args.use_vllm: - # First, have main process load weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - all_inputs = gather_object(inputs) - if self.accelerator.is_main_process: - outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False) - else: - outputs = [None] * len(all_inputs) + if self.args.use_vllm or self.args.use_lmdeploy: + # ref: https://github.com/huggingface/trl/issues/2856 + from accelerate.utils.other import is_compiled_module + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model: + if is_compiled_module(unwrapped_model): + unwrapped_model = unwrapped_model._orig_mod + if 
is_peft_model(unwrapped_model): + unwrapped_model.merge_adapter() + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm_lmdeploy(unwrapped_model) + self._last_loaded_step = self.state.global_step + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_inputs = gather_object(inputs) + if self.accelerator.is_main_process: + outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False) + else: + outputs = [None] * len(all_inputs) - # Broadcast the completions from the main process to all processes, ensuring each process receives its - # corresponding slice. - outputs = broadcast_object_list(outputs, from_process=0) + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. + outputs = broadcast_object_list(outputs, from_process=0) + if is_peft_model(unwrapped_model): + unwrapped_model.unmerge_adapter() else: # Regular generation path is_multimodal = self.model.model_meta.is_multimodal if is_multimodal: models = self.template.remove_post_encode_hook() - with unwrap_model_for_generation(self.model, self.accelerator): + with unwrap_model_for_generation(self.model_wrapped, self.accelerator): # same reference outputs = self.engine.infer(inputs, self.request_config, use_tqdm=False) + self.model.train() if is_multimodal: self.template.register_post_encode_hook(models)
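As a standalone illustration of the weight-synchronization step added in `_move_model_to_vllm_lmdeploy`, the sketch below isolates the state-dict key cleanup that turns a merged PEFT checkpoint back into plain base-model weights before they are pushed to the vLLM or LMDeploy engine. The function name `clean_peft_state_dict` and the `adapter_prefix` argument are invented for the sketch; in the trainer the prefix comes from `unwrapped_model.prefix` (typically `'lora_'` for LoRA).

```python
from typing import Dict

import torch


def clean_peft_state_dict(state_dict: Dict[str, torch.Tensor],
                          adapter_prefix: str = 'lora_') -> Dict[str, torch.Tensor]:
    # Drop the PEFT wrapper prefix and the `.base_layer` indirection so keys
    # match the underlying base model's parameter names.
    cleaned = {
        k.removeprefix('base_model.model.').replace('.base_layer', ''): v
        for k, v in state_dict.items()
    }
    # Discard the adapter tensors themselves; after merge_adapter() only the
    # merged base weights need to reach the inference engine.
    cleaned = {k: v for k, v in cleaned.items() if adapter_prefix not in k}
    # For modules_to_save, keep the trained copy and drop the frozen original.
    cleaned = {
        k.replace('modules_to_save.default.', ''): v
        for k, v in cleaned.items() if 'original_module' not in k
    }
    return cleaned
```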
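The fast-inference branch of `_prepare_inputs` follows a gather, generate-on-rank-0, broadcast pattern. The sketch below shows that pattern in isolation using the same accelerate helpers; `generate_fn` stands in for the engine call, and the final per-rank slicing (which assumes every rank contributed the same number of prompts) reflects the usual trl convention rather than code visible in this diff.

```python
from typing import Any, Callable, List

from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather_object


def distributed_generate(accelerator: Accelerator, local_inputs: List[Any],
                         generate_fn: Callable[[List[Any]], List[Any]]) -> List[Any]:
    # Collect every rank's prompts so one engine instance can batch them together.
    all_inputs = gather_object(local_inputs)
    if accelerator.is_main_process:
        all_outputs = generate_fn(all_inputs)
    else:
        # Non-main ranks provide placeholders of matching length for the broadcast.
        all_outputs = [None] * len(all_inputs)
    # Ship the completed outputs from rank 0 to every rank.
    all_outputs = broadcast_object_list(all_outputs, from_process=0)
    # Each rank keeps only the slice that corresponds to its own prompts.
    start = accelerator.process_index * len(local_inputs)
    return all_outputs[start:start + len(local_inputs)]
```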
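Finally, the `__init__.py` edits in `swift/plugin` and `swift/llm/infer` extend a lazy-import layout: public names are declared once in `_import_structure` (and mirrored under `TYPE_CHECKING` so type checkers see real symbols) and are only imported on first attribute access. The generic sketch below borrows transformers' `_LazyModule` purely for illustration; swift wires up its own equivalent, so the exact constructor call is an assumption.

```python
# Sketch of a package __init__.py using the lazy-import layout.
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Eager imports here exist only for type checkers and IDEs.
    from .orm import ORM, orms
    from .prm import PRM, prms
else:
    from transformers.utils import _LazyModule

    # Maps submodule name -> public names it provides; nothing is imported yet.
    _import_structure = {
        'orm': ['ORM', 'orms'],
        'prm': ['PRM', 'prms'],
    }
    # Replace this package module with a proxy that imports on first access.
    sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
```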