diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index a1921cc99..b71bce026 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -2,21 +2,21 @@
 # Part of the implementation is borrowed from huggingface/trl.
 import inspect
 import os
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 from unittest.mock import patch
 
 import torch
 import torch.nn as nn
-from accelerate.utils import broadcast_object_list, gather, gather_object
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model
 from transformers import PreTrainedModel
 from transformers.utils.versions import require_version
 from trl import GRPOTrainer as HFGRPOTrainer
 from trl.models import unwrap_model_for_generation
 
 from swift.llm import InferRequest, RequestConfig, to_device
-from swift.plugin.orm import orms
+from swift.plugin import orms
 from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger,
                          is_lmdeploy_available, is_vllm_available, is_wandb_available)
 from ..mixin import SwiftMixin
@@ -125,7 +125,7 @@ def __init__(self,
                     logger.warning(
                         f'The requested device {fast_infer_device} is also used for training. '
                         f'This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.')
-                if use_vllm and not use_lmdeploy:
+                if use_vllm:
                     if not is_vllm_available():
                         raise ImportError('vLLM is not available and `use_vllm` is set to True. '
                                           'Please install vLLM with `pip install vllm` to use it.')
@@ -147,8 +147,6 @@ def __init__(self,
                         enforce_eager=args.vllm_enforce_eager,
                         limit_mm_per_prompt=args.vllm_limit_mm_per_prompt,
                         max_model_len=args.vllm_max_model_len)
-                    # compat _move_model_to_vllm
-                    self.llm = namedtuple('LLM', ['llm_engine'])(self.engine.engine.engine)
                     self.engine.default_template = self.template
                 elif use_lmdeploy:
                     # https://github.com/tastelikefeet/lmdeploy.git@feat/reload_state_dict
@@ -173,24 +171,6 @@ def __init__(self,
                         device=[fast_infer_device],
                         session_len=args.lmdeploy_session_len,
                         cache_max_entry_count=args.lmdeploy_cache_max_entry_count)
-                    # compat _move_model_to_vllm
-                    import collections
-                    import collections.abc
-                    for type_name in collections.abc.__all__:
-                        # AttrDict may throw not modules `Mapping` error
-                        setattr(collections, type_name, getattr(collections.abc, type_name))
-                    from attrdict import AttrDict
-                    self.llm = AttrDict({
-                        'llm_engine': {
-                            'model_executor': {
-                                'driver_worker': {
-                                    'model_runner': {
-                                        'model': self.engine.engine.engine
-                                    }
-                                }
-                            }
-                        }
-                    })
                     self.engine.default_template = self.template
 
         self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
@@ -231,34 +211,72 @@ def _template_context(template):
             template.set_mode(mode)
             template.max_length = max_length
 
+    def _move_model_to_vllm_lmdeploy(self, unwrapped_model):
+        with unwrap_model_for_generation(
+                self.model_wrapped, self.accelerator,
+                gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
+            if is_peft_model(unwrapped_model):
+                state_dict = unwrapped_model.state_dict()
+                # Remove base_model and base_layer prefixes
+                state_dict = {
+                    k.removeprefix('base_model.model.').replace('.base_layer', ''): v
+                    for k, v in state_dict.items()
+                }
+                # Remove values with adapter prefix (example: "_lora")
+                state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
+                # When module to save, remove its prefix and discard the original module
+                state_dict = {
+                    k.replace('modules_to_save.default.', ''): v
+                    for k, v in state_dict.items() if 'original_module' not in k
+                }
+            else:
+                state_dict = unwrapped_model.state_dict()
+            if self.accelerator.is_main_process:
+                if self.args.use_vllm:
+                    llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model
+                else:
+                    llm_model = self.engine.engine.engine
+                llm_model.load_weights(state_dict.items())
+
     def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
 
         # Generate completions using either vLLM or regular generation
-        if self.args.use_vllm:
-            # First, have main process load weights if needed
-            if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm()
-                self._last_loaded_step = self.state.global_step
-
-            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
-            all_inputs = gather_object(inputs)
-            if self.accelerator.is_main_process:
-                outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False)
-            else:
-                outputs = [None] * len(all_inputs)
+        if self.args.use_vllm or self.args.use_lmdeploy:
+            # ref: https://github.com/huggingface/trl/issues/2856
+            from accelerate.utils.other import is_compiled_module
+            with unwrap_model_for_generation(
+                    self.model_wrapped, self.accelerator,
+                    gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
+                if is_compiled_module(unwrapped_model):
+                    unwrapped_model = unwrapped_model._orig_mod
+                if is_peft_model(unwrapped_model):
+                    unwrapped_model.merge_adapter()
+                # First, have main process load weights if needed
+                if self.state.global_step != self._last_loaded_step:
+                    self._move_model_to_vllm_lmdeploy(unwrapped_model)
+                    self._last_loaded_step = self.state.global_step
+                # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+                all_inputs = gather_object(inputs)
+                if self.accelerator.is_main_process:
+                    outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False)
+                else:
+                    outputs = [None] * len(all_inputs)
 
-            # Broadcast the completions from the main process to all processes, ensuring each process receives its
-            # corresponding slice.
-            outputs = broadcast_object_list(outputs, from_process=0)
+                # Broadcast the completions from the main process to all processes, ensuring each process receives its
+                # corresponding slice.
+                outputs = broadcast_object_list(outputs, from_process=0)
+                if is_peft_model(unwrapped_model):
+                    unwrapped_model.unmerge_adapter()
         else:
             # Regular generation path
             is_multimodal = self.model.model_meta.is_multimodal
            if is_multimodal:
                 models = self.template.remove_post_encode_hook()
-            with unwrap_model_for_generation(self.model, self.accelerator):
+            with unwrap_model_for_generation(self.model_wrapped, self.accelerator):  # same reference
                 outputs = self.engine.infer(inputs, self.request_config, use_tqdm=False)
+                self.model.train()
             if is_multimodal:
                 self.template.register_post_encode_hook(models)
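
For reference, a minimal standalone sketch of the PEFT key cleanup that _move_model_to_vllm_lmdeploy performs before handing the state dict to load_weights. The sample parameter names and the prefix value are hypothetical stand-ins for what a real PeftModel would expose (its prefix attribute, e.g. 'lora_'); only the three dict-comprehension steps mirror the patch.

    # Illustrative sketch, not part of the patch.
    import torch

    state_dict = {
        'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight': torch.zeros(4, 4),
        'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight': torch.zeros(2, 4),
        'base_model.model.lm_head.modules_to_save.default.weight': torch.zeros(4, 4),
        'base_model.model.lm_head.original_module.weight': torch.zeros(4, 4),
    }
    prefix = 'lora_'  # stand-in for PeftModel.prefix

    # 1. Strip the PEFT wrapper prefixes so keys match the base model.
    state_dict = {k.removeprefix('base_model.model.').replace('.base_layer', ''): v for k, v in state_dict.items()}
    # 2. Drop adapter-only tensors (keys containing the adapter prefix).
    state_dict = {k: v for k, v in state_dict.items() if prefix not in k}
    # 3. Unwrap modules_to_save and discard the retained original module.
    state_dict = {k.replace('modules_to_save.default.', ''): v for k, v in state_dict.items() if 'original_module' not in k}

    print(sorted(state_dict))
    # ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']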