
[testing][rollout] feat: support integration of vllm>=0.7.0 (spmd-version) #209

Merged
merged 41 commits into volcengine:main on Feb 14, 2025

Conversation

ZSL98
Contributor

@ZSL98 ZSL98 commented Feb 5, 2025

This PR aims to integrate vllm>=0.7.0 while preserving:

- **Backward compatibility**: 0.3.1, 0.4.2, 0.5.4, and 0.6.3 are still supported.
- **Forward compatibility**: future versions of vllm (>=0.7.0) will be supported without requiring manual maintenance for each new release.

The readme of this beta version is located at docs/README_vllm0.7.md, where users can find the installation method and related features. The readme is copied below.


Readme for verl (vllm>=0.7) version

Installation

Note: This version of veRL supports FSDP for training and vLLM for rollout. (Megatron-LM is not supported yet.)

```
# Create the conda environment
conda create -n verl python==3.10
conda activate verl

# Install verl
git clone https://github.com/volcengine/verl.git
cd verl
pip3 install -e .

# Install vLLM>=0.7
pip3 install vllm==0.7.0

# Install flash-attn
pip3 install flash-attn --no-build-isolation
```
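
As an optional sanity check (not part of the original instructions), you can confirm that the key packages import cleanly and that a GPU is visible:

```
# Optional sanity check of the freshly created environment.
import torch
import vllm
import flash_attn

print("torch:", torch.__version__)
print("vllm:", vllm.__version__)              # expected to be >= 0.7.0
print("flash-attn:", flash_attn.__version__)
print("CUDA available:", torch.cuda.is_available())
```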

For existing stable vllm versions (<=0.7.2), you also need to apply a few small patches manually to vllm (under /path/to/site-packages/vllm after installation) after completing the steps above:

- vllm/distributed/parallel_state.py: remove the assertion below:

```
if (world_size
        != tensor_model_parallel_size * pipeline_model_parallel_size):
    raise RuntimeError(
        f"world_size ({world_size}) is not equal to "
        f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
        f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
```

- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ["LOCAL_RANK"])`
- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` call in `pt_weights_iterator`

These modifications have already been merged into the main branch of vLLM. To avoid modifying these files manually, you can directly build vLLM from source.
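To locate the installed vllm directory mentioned above (the /path/to/site-packages/vllm to patch), one quick way, shown here purely as a convenience, is:

```
# Print the directory of the installed vllm package,
# i.e. the /path/to/site-packages/vllm referenced above.
import os
import vllm

print(os.path.dirname(vllm.__file__))
```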

Features

Use CUDA graphs

After installation, the examples that use FSDP as the training backend can be run. By default, `enforce_eager` is set to True, which disables CUDA graphs. To enable CUDA graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script:

```
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
```

For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 115 seconds with vLLM 0.6.3 and 85 seconds with vLLM 0.7.0. Enabling the CUDA graph further reduces the generation time to 62 seconds.

Note: Currently, if `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue affecting the stability of rollout generation time (some iterations see bursts in generation time). We are working with the vLLM team to investigate this issue.
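
The sketch below only illustrates how multiple samples per prompt are requested through vLLM's offline API; the model name and sampling values are placeholders, not verl's actual rollout configuration.

```
from vllm import LLM, SamplingParams

# Placeholder model and sampling values, for illustration only.
llm = LLM(model="Qwen/Qwen2-7B-Instruct", enforce_eager=False)
sampling_params = SamplingParams(n=4, temperature=1.0, max_tokens=512)

outputs = llm.generate(["What is reinforcement learning?"], sampling_params)
for output in outputs:
    # With n>1, each prompt yields several completions in output.outputs.
    for sample in output.outputs:
        print(len(sample.token_ids), "generated tokens")
```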

Other features in vLLM

  1. num_scheduler_step>1: not supported yet (weight loading has not been aligned with MultiStepModelRunner)
  2. Prefix caching: not supported yet (vLLM sleep mode does not support prefix caching)
  3. Chunked prefill: supported

@ZSL98 ZSL98 marked this pull request as ready for review February 6, 2025 08:54
@ZSL98
Contributor Author

ZSL98 commented Feb 10, 2025

If you are loading pt weight files and using vllm's sleep mode, please comment out https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/weight_utils.py#L465 because of a potential pytorch issue: pytorch/pytorch#145168
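
For readers unfamiliar with the sleep mode mentioned here, the sketch below shows roughly how it is driven through vLLM's offline API in vLLM>=0.7; the model name is a placeholder and this is not verl's rollout wrapper code.

```
from vllm import LLM, SamplingParams

# Sleep mode must be enabled when the engine is constructed (vLLM >= 0.7).
llm = LLM(model="Qwen/Qwen2-7B-Instruct", enable_sleep_mode=True)

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))

# Release GPU memory between rollout phases...
llm.sleep(level=1)
# ...and reclaim it before the next generation phase.
llm.wake_up()
```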

@YangWang92

BTW, I got an error with the current SPMD vLLMRollout on multi-node training; I am still trying to figure out what happened.

(main_task pid=3370254) RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1280 but got size 256 for tensor number 1 in the list.

@ZSL98
Contributor Author

ZSL98 commented Feb 13, 2025


@YangWang92 You may try the latest commit in this PR, which handles the `SamplingParams` n>1 case in vllm_rollout_spmd.py:

```
# Flatten the n samples generated for each prompt into a single response list.
response = []
for output in outputs:
    for sample_id in range(len(output.outputs)):
        response.append(output.outputs[sample_id].token_ids)
```
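
For context, the mismatch reported above (1280 vs. 256 in the batch dimension) is consistent with concatenating n flattened responses per prompt against an un-repeated prompt batch. The sketch below uses made-up shapes and a hypothetical pad id, not verl's actual post-processing, to show how the two are kept aligned when n>1.

```
import torch
from torch.nn.utils.rnn import pad_sequence

n = 4                     # hypothetical SamplingParams.n
pad_token_id = 0          # hypothetical pad id
prompts = torch.randint(1, 100, (2, 8))  # (batch, prompt_len), made-up data

# Pad the flattened, variable-length responses to a common length.
responses = [torch.randint(1, 100, (length,)) for length in (5, 7, 6, 4, 9, 3, 8, 5)]
response_batch = pad_sequence(responses, batch_first=True, padding_value=pad_token_id)

# With n>1 there are n responses per prompt, so repeat each prompt n times
# before concatenating along the sequence dimension.
prompts_repeated = prompts.repeat_interleave(n, dim=0)
sequences = torch.cat([prompts_repeated, response_batch], dim=-1)
print(sequences.shape)  # torch.Size([8, 17]) = (batch * n, prompt_len + max_response_len)
```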

@YangWang92

Thanks for your help, and let me try.

@YangWang92

I confirmed that the current code works well. Thanks!

@ZSL98 ZSL98 changed the title [WIP] Integrating vllm>=0.7.0 Integration of vllm>=0.7.0 Feb 14, 2025
@PeterSH6 PeterSH6 changed the title Integration of vllm>=0.7.0 [testing][rollout]feat: support integration of vllm>=0.7.0 (spmd-version) Feb 14, 2025
@PeterSH6 PeterSH6 changed the title [testing][rollout]feat: support integration of vllm>=0.7.0 (spmd-version) [testing][rollout] feat: support integration of vllm>=0.7.0 (spmd-version) Feb 14, 2025
@PeterSH6 PeterSH6 merged commit f8b4d08 into volcengine:main Feb 14, 2025
12 checks passed
as12138 pushed a commit to as12138/verl that referenced this pull request Feb 20, 2025