Merge branch 'fix_async_vllm_lora' into async_fix_vllm_lora
Jintao-Huang committed Feb 17, 2025
2 parents da57f53 + 89c39d5 commit c08e561
Showing 1 changed file with 7 additions and 24 deletions.
swift/trainers/rlhf_trainer/grpo_trainer.py (7 additions, 24 deletions)
@@ -2,7 +2,7 @@
 # Part of the implementation is borrowed from huggingface/trl.
 import inspect
 import os
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 from unittest.mock import patch
@@ -147,8 +147,6 @@ def __init__(self,
                 enforce_eager=args.vllm_enforce_eager,
                 limit_mm_per_prompt=args.vllm_limit_mm_per_prompt,
                 max_model_len=args.vllm_max_model_len)
-            # compat _move_model_to_vllm
-            self.llm = namedtuple('LLM', ['llm_engine'])(self.engine.engine.engine)
             self.engine.default_template = self.template
         elif use_lmdeploy:
             # https://github.com/tastelikefeet/lmdeploy.git@feat/reload_state_dict
@@ -173,24 +171,6 @@
                 device=[fast_infer_device],
                 session_len=args.lmdeploy_session_len,
                 cache_max_entry_count=args.lmdeploy_cache_max_entry_count)
-            # compat _move_model_to_vllm
-            import collections
-            import collections.abc
-            for type_name in collections.abc.__all__:
-                # AttrDict may throw not modules `Mapping` error
-                setattr(collections, type_name, getattr(collections.abc, type_name))
-            from attrdict import AttrDict
-            self.llm = AttrDict({
-                'llm_engine': {
-                    'model_executor': {
-                        'driver_worker': {
-                            'model_runner': {
-                                'model': self.engine.engine.engine
-                            }
-                        }
-                    }
-                }
-            })
             self.engine.default_template = self.template
         self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation

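Side note on the deleted lmdeploy block: `attrdict` imports ABCs such as `Mapping` from `collections`, which Python 3.10+ moved to `collections.abc`, hence the monkey-patch loop over `collections.abc.__all__`; the `AttrDict` itself existed only to fake vLLM's nested handle so the shared `self.llm.llm_engine.model_executor.driver_worker.model_runner.model` lookup would also resolve against a raw lmdeploy engine. A dependency-free sketch of the same trick (illustrative only; `fake_vllm_handle` is not repo API):

from types import SimpleNamespace

# Illustrative reconstruction (not repo code) of what the removed AttrDict
# shim achieved: wrap a raw engine so the vLLM-shaped attribute chain
# llm.llm_engine.model_executor.driver_worker.model_runner.model resolves.
def fake_vllm_handle(raw_engine):
    return SimpleNamespace(
        llm_engine=SimpleNamespace(
            model_executor=SimpleNamespace(
                driver_worker=SimpleNamespace(
                    model_runner=SimpleNamespace(model=raw_engine)))))

llm = fake_vllm_handle(raw_engine='lmdeploy-engine-stub')
assert llm.llm_engine.model_executor.driver_worker.model_runner.model == 'lmdeploy-engine-stub'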
@@ -231,7 +211,7 @@ def _template_context(template):
         template.set_mode(mode)
         template.max_length = max_length
 
-    def _move_model_to_vllm(self, unwrapped_model):
+    def _move_model_to_vllm_lmdeploy(self, unwrapped_model):
         with unwrap_model_for_generation(
                 self.model_wrapped, self.accelerator,
                 gather_deepspeed3_params=self.args.ds3_gather_for_generation) as unwrapped_model:
@@ -252,7 +232,10 @@ def _move_model_to_vllm(self, unwrapped_model):
         else:
             state_dict = unwrapped_model.state_dict()
         if self.accelerator.is_main_process:
-            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
+            if self.args.use_vllm:
+                llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model
+            else:
+                llm_model = self.engine.engine.engine
             llm_model.load_weights(state_dict.items())
 
     def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
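With both compat shims gone, the weight push reads off the backend directly. A condensed sketch of the resulting dispatch (standalone illustration; `sync_policy_weights` and the free-standing `trainer` argument are assumptions, not repo API):

def sync_policy_weights(trainer, unwrapped_model):
    # Gather the (merged) policy weights once per optimizer step.
    state_dict = unwrapped_model.state_dict()
    if trainer.accelerator.is_main_process:
        raw_engine = trainer.engine.engine.engine  # unwrap swift wrapper -> backend engine
        if trainer.args.use_vllm:
            # vLLM keeps the torch module on the driver worker's model runner.
            llm_model = raw_engine.model_executor.driver_worker.model_runner.model
        else:
            # The patched lmdeploy fork exposes load_weights on the engine itself.
            llm_model = raw_engine
        llm_model.load_weights(state_dict.items())

Branching on `args.use_vllm` trades the duck-typed `self.llm` indirection for one explicit conditional, and drops the abandoned `attrdict` dependency along the way.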
@@ -271,7 +254,7 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
                 unwrapped_model.merge_adapter()
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
-                self._move_model_to_vllm(unwrapped_model)
+                self._move_model_to_vllm_lmdeploy(unwrapped_model)
                 self._last_loaded_step = self.state.global_step
             # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
             all_inputs = gather_object(inputs)
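The `_last_loaded_step` comparison above is what keeps the sync cheap under gradient accumulation: `_prepare_inputs` runs once per micro-batch, but `global_step` only advances on optimizer steps. A toy trace of the guard (plain Python, not repo code):

last_loaded_step = 0  # weights are already in the engine at construction time
syncs = 0
# 4 micro-batches per optimizer step -> global_step repeats 4 times
for global_step in [0, 0, 0, 0, 1, 1, 1, 1]:
    if global_step != last_loaded_step:
        syncs += 1  # push updated weights to the inference engine
        last_loaded_step = global_step
print(syncs)  # 1 -- one sync when step 1 begins, none for the repeated micro-batches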
