Commit

update
Jintao-Huang committed Feb 17, 2025
1 parent 87270f9 commit 89c39d5
Showing 1 changed file with 58 additions and 40 deletions.
98 changes: 58 additions & 40 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -2,21 +2,21 @@
# Part of the implementation is borrowed from huggingface/trl.
import inspect
import os
from collections import defaultdict, namedtuple
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch

import torch
import torch.nn as nn
from accelerate.utils import broadcast_object_list, gather, gather_object
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model
from transformers import PreTrainedModel
from transformers.utils.versions import require_version
from trl import GRPOTrainer as HFGRPOTrainer
from trl.models import unwrap_model_for_generation

from swift.llm import InferRequest, RequestConfig, to_device
from swift.plugin.orm import orms
from swift.plugin import orms
from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, is_lmdeploy_available,
is_vllm_available, is_wandb_available)
from ..mixin import SwiftMixin
@@ -125,7 +125,7 @@ def __init__(self,
logger.warning(
f'The requested device {fast_infer_device} is also used for training. '
f'This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.')
if use_vllm and not use_lmdeploy:
if use_vllm:
if not is_vllm_available():
raise ImportError('vLLM is not available and `use_vllm` is set to True. '
'Please install vLLM with `pip install vllm` to use it.')
@@ -147,8 +147,6 @@
enforce_eager=args.vllm_enforce_eager,
limit_mm_per_prompt=args.vllm_limit_mm_per_prompt,
max_model_len=args.vllm_max_model_len)
# compatibility shim so trl's _move_model_to_vllm can locate the engine
self.llm = namedtuple('LLM', ['llm_engine'])(self.engine.engine.engine)
self.engine.default_template = self.template
elif use_lmdeploy:
# https://github.com/tastelikefeet/lmdeploy.git@feat/reload_state_dict
@@ -173,24 +171,6 @@
device=[fast_infer_device],
session_len=args.lmdeploy_session_len,
cache_max_entry_count=args.lmdeploy_cache_max_entry_count)
# compatibility shim so trl's _move_model_to_vllm can locate the engine
import collections
import collections.abc
for type_name in collections.abc.__all__:
# AttrDict may fail because `Mapping` is no longer exported from `collections` on newer Python
setattr(collections, type_name, getattr(collections.abc, type_name))
from attrdict import AttrDict
self.llm = AttrDict({
'llm_engine': {
'model_executor': {
'driver_worker': {
'model_runner': {
'model': self.engine.engine.engine
}
}
}
}
})
self.engine.default_template = self.template
self._last_loaded_step = 0  # flag to avoid redundant weight loading during gradient accumulation

@@ -231,34 +211,72 @@ def _template_context(template):
template.set_mode(mode)
template.max_length = max_length

def _move_model_to_vllm_lmdeploy(self, unwrapped_model):
with unwrap_model_for_generation(
self.model_wrapped, self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
if is_peft_model(unwrapped_model):
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix('base_model.model.').replace('.base_layer', ''): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
# For modules_to_save entries, remove the prefix and discard the original module
state_dict = {
k.replace('modules_to_save.default.', ''): v
for k, v in state_dict.items() if 'original_module' not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
if self.args.use_vllm:
llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model
else:
llm_model = self.engine.engine.engine
llm_model.load_weights(state_dict.items())

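# Illustrative note, not part of this commit: the cleanup above maps PEFT parameter
# names back to the base-model layout the inference engine expects, e.g.
#   'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight'
#     -> 'model.layers.0.self_attn.q_proj.weight'
# while adapter-only tensors (keys containing the LoRA prefix) are dropped and the
# trained 'modules_to_save.default.' copies replace their 'original_module' counterparts.
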
def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_inputs = gather_object(inputs)
if self.accelerator.is_main_process:
outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False)
else:
outputs = [None] * len(all_inputs)
if self.args.use_vllm or self.args.use_lmdeploy:
# ref: https://github.com/huggingface/trl/issues/2856
from accelerate.utils.other import is_compiled_module
with unwrap_model_for_generation(
self.model_wrapped, self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = unwrapped_model._orig_mod
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
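# Clarifying comment (added, not in the original diff): merging folds the LoRA deltas
# into the base weights, so the state dict synced below already reflects the adapter
# update; the adapter is unmerged again once generation has finished.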
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm_lmdeploy(unwrapped_model)
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_inputs = gather_object(inputs)
if self.accelerator.is_main_process:
outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False)
else:
outputs = [None] * len(all_inputs)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
outputs = broadcast_object_list(outputs, from_process=0)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
outputs = broadcast_object_list(outputs, from_process=0)
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
else:
# Regular generation path
is_multimodal = self.model.model_meta.is_multimodal
if is_multimodal:
models = self.template.remove_post_encode_hook()
with unwrap_model_for_generation(self.model, self.accelerator):
with unwrap_model_for_generation(self.model_wrapped, self.accelerator):
# same reference
outputs = self.engine.infer(inputs, self.request_config, use_tqdm=False)
self.model.train()
if is_multimodal:
self.template.register_post_encode_hook(models)


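The single-engine generation path above follows a gather -> main-process inference -> broadcast -> slice pattern. Below is a minimal, self-contained sketch of that pattern, not taken from this commit: `fake_infer` stands in for `self.engine.infer`, and the per-rank slicing at the end mirrors trl's GRPOTrainer, which is an assumption since that code sits outside the hunks shown.

from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather_object

accelerator = Accelerator()
local_prompts = [f'prompt-{accelerator.process_index}-{i}' for i in range(2)]

def fake_infer(prompts):
    # Placeholder for self.engine.infer(all_inputs, self.request_config, ...)
    return [p + ' -> completion' for p in prompts]

# Every rank contributes its prompts; only the main process runs the engine.
all_prompts = gather_object(local_prompts)
if accelerator.is_main_process:
    outputs = fake_infer(all_prompts)
else:
    outputs = [None] * len(all_prompts)

# Share the completions with every rank, then keep only this rank's slice.
outputs = broadcast_object_list(outputs, from_process=0)
start = accelerator.process_index * len(local_prompts)
print(outputs[start:start + len(local_prompts)])

Launched with `accelerate launch --num_processes 2 sketch.py`, each rank prints only its own two completions even though inference ran once on rank 0; run directly, it degenerates to a single-process gather/broadcast that returns the local prompts unchanged.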