
[testing][rollout] feat: support integration of vllm>=0.7.0 (spmd-version) #209

Merged: 41 commits, Feb 14, 2025

Commits
be4cd50
[test] test for vllm-spmd
Jan 17, 2025
d76c04d
[test] test for sync weight in OpenRLHF style
Jan 17, 2025
e20ba1b
[chore] Remove dependencies on vllm<=0.6.3
Jan 17, 2025
ac4c91d
[test] Add time profiling on vllm sync weight
Jan 17, 2025
6fb4999
[test] Some formatting changes
Jan 17, 2025
b64a473
Merge branch 'volcengine:main' into zsl/vllm-spmd
ZSL98 Jan 17, 2025
4fe511a
Merge branch 'volcengine:main' into zsl/vllm-spmd
ZSL98 Jan 18, 2025
c77bbec
Merge branch 'volcengine:main' into zsl/vllm-spmd
ZSL98 Jan 22, 2025
234a52d
Add a tiny version of run_qwen2-7b_seq_balance.sh
Jan 22, 2025
bc689b6
init some files
Jan 22, 2025
6f55342
Merge remote-tracking branch 'upstream/main' into zsl/vllm-spmd
Jan 27, 2025
5a2d526
update
Jan 28, 2025
c0a5099
update
Feb 3, 2025
4a6d686
update
Feb 4, 2025
6c78554
support fsdp
Feb 5, 2025
ef47177
support vllm>=0.7.0 and fsdp
Feb 5, 2025
e8a7487
Merge remote-tracking branch 'origin/main' into latest
Feb 5, 2025
d71b5a2
Merge branch 'volcengine:main' into latest
ZSL98 Feb 5, 2025
18ed87f
Merge branch 'volcengine:main' into latest
ZSL98 Feb 6, 2025
a27ee29
remove redundant files
Feb 6, 2025
e936114
update
Feb 6, 2025
ffa88ed
[test] update run_fsdp_vllm_spmd.py
Feb 6, 2025
0fd82c6
Merge branch 'volcengine:main' into latest
ZSL98 Feb 7, 2025
c34d67e
fix
Feb 8, 2025
a6237a6
license
Feb 8, 2025
71f8a84
update
Feb 11, 2025
f6d2ef9
update
Feb 11, 2025
6d956e3
Merge branch 'volcengine:main' into latest
ZSL98 Feb 11, 2025
7ba97d8
update
Feb 11, 2025
6efc130
update
Feb 11, 2025
bf33df3
doc
Feb 13, 2025
c99d258
update doc
Feb 13, 2025
de27933
update
Feb 13, 2025
eb8242d
Merge branch 'volcengine:main' into latest
ZSL98 Feb 14, 2025
a85d88f
quick fix
Feb 14, 2025
43979c7
Merge branch 'latest' of https://github.com/ZSL98/verl into latest
Feb 14, 2025
fcd1474
doc
Feb 14, 2025
fcc0451
Merge branch 'volcengine:main' into latest
ZSL98 Feb 14, 2025
fe1ad99
reformat
Feb 14, 2025
1bf9e5a
Merge branch 'latest' of https://github.com/ZSL98/verl into latest
Feb 14, 2025
8f2ffae
fix ci
Feb 14, 2025
5 changes: 5 additions & 0 deletions .github/workflows/vllm.yml
@@ -40,3 +40,8 @@ jobs:
run: |
cd tests/rollout
torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py
- name: Test the latest vLLM
run: |
pip3 install --upgrade vllm
cd tests/rollout
torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
3 changes: 3 additions & 0 deletions README.md
@@ -89,6 +89,9 @@ Checkout this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/ex
## Performance Tuning Guide
Performance is essential for on-policy RL algorithms. We provide a detailed performance tuning guide to help people tune performance. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details.

## vLLM v0.7 testing version
We have released a testing version of veRL that supports vLLM>=0.7.0. Please refer to [this document](https://github.com/volcengine/verl/docs/README_vllm0.7.md) for the installation guide and more information.

## Contribution Guide
Contributions from the community are welcome!

62 changes: 62 additions & 0 deletions docs/README_vllm0.7.md
@@ -0,0 +1,62 @@
# README for verl with vllm>=0.7

## Installation

Note: This version of veRL supports **FSDP** for training and **vLLM** for rollout. (Megatron-LM is not supported yet.)

```
# Create the conda environment
conda create -n verl python==3.10
conda activate verl

# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .
# Install vLLM>=0.7
pip3 install "vllm>=0.7.0"
# Install flash-attn
pip3 install flash-attn --no-build-isolation

```

For existing stable vLLM versions (<=0.7.2), you also need to apply a few small manual patches to vllm (under /path/to/site-packages/vllm after installation) after the above steps:

- vllm/distributed/parallel_state.py: Remove the assertion below:

```
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

```

- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ["LOCAL_RANK"])`
- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` in `pt_weights_iterator`

These modifications have already been merged into the main branch of vLLM. To avoid modifying these files manually, you can directly build vLLM from source.
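
If you are unsure whether your installed vLLM still needs these patches, a small Python helper like the one below can locate the package and print the files mentioned above. This is a hypothetical convenience script (not part of verl or vLLM), given as a minimal sketch:

```
# Hypothetical helper (not part of verl): locate the installed vLLM and list the
# files that may still need the manual patches described above (vllm<=0.7.2).
import importlib.util
from pathlib import Path

from packaging import version

import vllm

vllm_root = Path(importlib.util.find_spec("vllm").origin).parent
print(f"vLLM {vllm.__version__} installed at {vllm_root}")

if version.parse(vllm.__version__) <= version.parse("0.7.2"):
    files_to_check = [
        "distributed/parallel_state.py",                # remove the world_size assertion
        "executor/uniproc_executor.py",                 # local_rank = int(os.environ["LOCAL_RANK"])
        "model_executor/model_loader/weight_utils.py",  # drop torch.cuda.empty_cache() in pt_weights_iterator
    ]
    for rel in files_to_check:
        print("check:", vllm_root / rel)
else:
    print("This vLLM version should not need the manual patches.")
```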

## Features

### Use CUDA graph

After installation, you can run the examples that use FSDP as the training backend. By default, `enforce_eager` is set to True, which disables the CUDA graph. To enable CUDA graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script:

```
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \

```

For a typical job such as examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 115 seconds with vLLM 0.6.3 and 85 seconds with vLLM 0.7.0. Enabling the CUDA graph further reduces the generation time to 62 seconds.

**Note:** Currently, if `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential issue with the stability of the rollout generation time (some iterations see bursts in generation time). We are working with the vLLM team to investigate this issue.
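
For reference, outside of verl these settings roughly correspond to arguments of vLLM's `LLM` constructor. The following is a minimal standalone sketch (the model name and memory settings are illustrative), not verl's rollout worker:

```
# Minimal standalone sketch: enforce_eager=False allows vLLM to capture CUDA graphs,
# and enable_sleep_mode keeps the sleep mode available. In verl, free_cache_engine=False
# keeps the cache engine alive between rollouts (the doc above pairs it with enforce_eager=False).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",  # illustrative model
    enforce_eager=False,             # enable CUDA graph capture
    enable_sleep_mode=True,
    dtype="bfloat16",
    gpu_memory_utilization=0.5,
)
outputs = llm.generate(
    ["The founder of Apple is"],
    SamplingParams(temperature=0, max_tokens=16),
)
print(outputs[0].outputs[0].text)
```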

### Other features in vLLM

1. **num_scheduler_step>1:** not supported yet (weight loading has not been aligned with `MultiStepModelRunner`)
2. **Prefix caching:** not supported yet (vLLM sleep mode does not support prefix caching)
3. **Chunked prefill:** supported
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"ray>=2.38",
"tensordict",
"transformers<4.48",
"vllm<=0.6.3",
"vllm<=0.7.3",
"peft",
"liger-kernel",
"pylatexenc",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,7 +12,7 @@ pybind11
ray>=2.38
tensordict<0.6
transformers<4.48
vllm<=0.6.3
vllm
wandb
liger-kernel
pylatexenc
20 changes: 19 additions & 1 deletion tests/rollout/run_fsdp_vllm.py
@@ -23,6 +23,9 @@

from vllm import SamplingParams

import time
import torch.distributed as dist


def main():
assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example'
@@ -112,10 +115,25 @@ def main():
enforce_eager=True,
dtype='bfloat16',
load_format='dummy_dtensor',
gpu_memory_utilization=0.1,
gpu_memory_utilization=0.8,
trust_remote_code=True)

# Warmup iterations
for _ in range(10):
torch.cuda.synchronize()
llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
torch.cuda.synchronize()
dist.barrier()

start_time = time.time()
llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor')
torch.cuda.synchronize()
dist.barrier()
end_time = time.time()

# Calculate elapsed time
elapsed_time = end_time - start_time
print(f"Time taken: {elapsed_time:.6f} seconds")

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
162 changes: 162 additions & 0 deletions tests/rollout/test_vllm_spmd.py
@@ -0,0 +1,162 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import transformers

from vllm import LLM, SamplingParams
from verl.utils.model import update_model_config
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

from transformers import GenerationConfig

from verl.utils.torch_functional import pad_sequence_to_length


def levenshtein(s1, s2):
m, n = len(s1), len(s2)
# Initialize matrix of zeros
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Initialize first column and first row of the matrix
for i in range(m + 1):
dp[i][0] = i # Deletion from s1 to empty string
for j in range(n + 1):
dp[0][j] = j # Insertion to s1 from empty string
# Compute the Levenshtein distance matrix
for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match
dp[i][j] = min(
dp[i - 1][j] + 1, # Deletion
dp[i][j - 1] + 1, # Insertion
dp[i - 1][j - 1] + cost # Substitution
)
return dp[m][n]


def are_lists_similar(a, b):
if len(a) != len(b):
print("The lists are of different lengths.")
return False

total_length = 0
total_diff = 0

for s1, s2 in zip(a, b):
max_len = max(len(s1), len(s2))
total_length += max_len
diff = levenshtein(s1, s2)
total_diff += diff
print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n")

percentage_difference = (total_diff / total_length) * 100
print(f"Total difference: {percentage_difference:.2f}%")

return percentage_difference <= 10


def test_vllm_spmd():
assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.'

# fill rollout config
max_prompt_length = 16
max_response_length = 16

# Initialize model and token
local_cache_path = '~/.cache/verl/rlhf'
local_cache_path = os.path.expanduser(local_cache_path)
hdfs_path = 'Qwen/Qwen2-7B-Instruct'
from verl.utils.fs import copy_local_path_from_hdfs
local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left')

preencode_prompts = [
"Who won the Champions League in 2019?",
"The founder of Apple is",
"What's your name",
]
tokenizer.pad_token = tokenizer.eos_token
prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True)
input_ids = prompts['input_ids']
attention_mask = prompts['attention_mask']

input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True)
attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True)

actor_model = AutoModelForCausalLM.from_pretrained(local_model_path)
actor_model.to(torch.bfloat16)

actor_model_config = AutoConfig.from_pretrained(local_model_path)

temperature = 0
top_p = 1

kwargs = dict(n=1,
temperature=temperature,
top_p=top_p,
max_tokens=max_response_length,
logprobs=1,
ignore_eos=True)

sampling_params = SamplingParams(**kwargs)
tensor_parallel_size = 4

llm = LLM(model=local_model_path,
enable_sleep_mode=True,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend="external_launcher",
dtype='bfloat16',
gpu_memory_utilization=0.5)

print('start generation')
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
batch_size = input_ids.size(0)

generation_config = GenerationConfig(do_sample=False)
actor_model.cuda()
output = actor_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_response_length,
# max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
generation_config=generation_config,
# renormalize_logits=True,
output_scores=False, # this is potentially very large
return_dict_in_generate=True,
use_cache=False) # may OOM when use_cache = True
seq = output.sequences
response = seq[:, max_prompt_length:]

hf_response_tokens = tokenizer.batch_decode(response)
print(f'hf response: {hf_response_tokens}')

outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False)
vllm_response_tokens = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
vllm_response_tokens.append(generated_text)

print(f'vllm response: {vllm_response_tokens}')
assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \
f'Strings differ more than 10%:\n'
print('Check Pass')


# if __name__ == "__main__":
# test_vllm_spmd()
11 changes: 10 additions & 1 deletion verl/third_party/vllm/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from importlib.metadata import version, PackageNotFoundError
from packaging import version as vs


def get_version(pkg):
@@ -24,6 +25,7 @@ def get_version(pkg):

package_name = 'vllm'
package_version = get_version(package_name)
vllm_version = None

if package_version == '0.3.1':
vllm_version = '0.3.1'
@@ -45,7 +47,14 @@ def get_version(pkg):
from .vllm_v_0_6_3.llm import LLM
from .vllm_v_0_6_3.llm import LLMEngine
from .vllm_v_0_6_3 import parallel_state
elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
# From 0.6.6.post2 on, vllm supports SPMD inference
# See https://github.com/vllm-project/vllm/pull/12071

from vllm import LLM
from vllm.distributed import parallel_state
from .vllm_spmd.dtensor_weight_loaders import load_dtensor_weights
else:
raise ValueError(
f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+'
)
13 changes: 13 additions & 0 deletions verl/third_party/vllm/vllm_spmd/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.