diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml
index bffdfaea..97f3aab1 100644
--- a/.github/workflows/vllm.yml
+++ b/.github/workflows/vllm.yml
@@ -40,3 +40,8 @@ jobs:
         run: |
           cd tests/rollout
           torchrun --standalone --nnodes=1 --nproc_per_node=8 $(which pytest) -s test_vllm_hf_loader.py
+      - name: Test the latest vLLM
+        run: |
+          pip3 install --upgrade vllm
+          cd tests/rollout
+          torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
diff --git a/README.md b/README.md
index bfc021a7..77740e99 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,9 @@ Checkout this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/ex
 ## Performance Tuning Guide
 The performance is essential for on-policy RL algorithm. We write a detailed performance tuning guide to allow people tune the performance. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details.
 
+## vLLM v0.7 testing version
+We have released a testing version of veRL that supports vLLM>=0.7.0. Please refer to [this document](https://github.com/volcengine/verl/docs/README_vllm0.7.md) for the installation guide and more information.
+
 ## Contribution Guide
 
 Contributions from the community are welcome!
diff --git a/docs/README_vllm0.7.md b/docs/README_vllm0.7.md
new file mode 100644
index 00000000..4cae9d44
--- /dev/null
+++ b/docs/README_vllm0.7.md
@@ -0,0 +1,62 @@
+# Readme for the verl (vllm>=0.7) version
+
+## Installation
+
+Note: This version of veRL supports **FSDP** for training and **vLLM** for rollout. (Megatron-LM is not supported yet.)
+
+```
+# Create the conda environment
+conda create -n verl python==3.10
+conda activate verl
+
+# Install verl
+git clone https://github.com/volcengine/verl.git
+cd verl
+pip3 install -e .
+# Install vLLM>=0.7
+pip3 install "vllm>=0.7.0"
+# Install flash-attn
+pip3 install flash-attn --no-build-isolation
+```
+
+For the existing stable vLLM releases (<=0.7.2), you also need to apply a few small patches manually to vLLM (`/path/to/site-packages/vllm` after installation) after the steps above:
+
+- vllm/distributed/parallel_state.py: Remove the assertion below:
+
+```
+if (world_size
+        != tensor_model_parallel_size * pipeline_model_parallel_size):
+    raise RuntimeError(
+        f"world_size ({world_size}) is not equal to "
+        f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
+        f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
+```
+
+- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ["LOCAL_RANK"])`
+- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` call in `pt_weights_iterator`
+
+These modifications have already been merged into the main branch of vLLM. To avoid patching these files manually, you can build vLLM from source instead.
+
+## Features
+
+### Use CUDA graph
+
+After installation, you can run the examples that use FSDP as the training backend. By default, `enforce_eager` is set to True, which disables CUDA graph. To enable CUDA graph and the sleep mode of vLLM>=0.7, add the following lines to the bash script:
+
+```
+actor_rollout_ref.rollout.enforce_eager=False \
+actor_rollout_ref.rollout.free_cache_engine=False \
+```
+
+For a typical job such as examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, rollout generation takes 115 seconds with vLLM 0.6.3 and 85 seconds with vLLM 0.7.0. Enabling CUDA graph reduces the generation time further, to 62 seconds.
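For reference, here is a minimal sketch of what these two overrides correspond to at the vLLM level, based on the SPMD rollout added later in this patch (`verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py`). The model name and sizes are placeholders, not values taken from the patch:

```python
# Sketch only: assumes a vLLM>=0.7 install with sleep mode available.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",   # placeholder model
    enforce_eager=False,              # keep CUDA graph capture enabled
    enable_sleep_mode=True,           # required for llm.sleep() / llm.wake_up()
    tensor_parallel_size=2,           # placeholder TP size
    gpu_memory_utilization=0.6,
)

# The SPMD rollout creates the engine once with enable_sleep_mode=True and then
# offloads/restores it around generation via sleep()/wake_up() instead of
# rebuilding the cache engine each iteration.
llm.sleep(level=1)
llm.wake_up()
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16), use_tqdm=False)
```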
+ +**Note:** Currently, if the `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue on the stability of rollout generation time (Some iterations would see generation time bursts). We are working with the vLLM team to check this issue. + +### Other features in vLLM + +1. **num_scheduler_step>1:** not supported yet (weight loading has not been aligned with `MultiStepModelRunner`) +2. **Prefix caching:** not supported yet (vLLM sleep mode does not support prefix caching) +3. **Chunked prefill:** supported \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7b28eee5..f4425045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "ray>=2.38", "tensordict", "transformers<4.48", - "vllm<=0.6.3", + "vllm<=0.7.3", "peft", "liger-kernel", "pylatexenc", diff --git a/requirements.txt b/requirements.txt index 1e3e3b8b..50684f03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ pybind11 ray>=2.38 tensordict<0.6 transformers<4.48 -vllm<=0.6.3 +vllm wandb liger-kernel pylatexenc diff --git a/tests/rollout/run_fsdp_vllm.py b/tests/rollout/run_fsdp_vllm.py index d9e165a9..93fbba37 100644 --- a/tests/rollout/run_fsdp_vllm.py +++ b/tests/rollout/run_fsdp_vllm.py @@ -23,6 +23,9 @@ from vllm import SamplingParams +import time +import torch.distributed as dist + def main(): assert torch.cuda.is_available(), 'CUDA must be present to run FSDP vLLM example' @@ -112,10 +115,25 @@ def main(): enforce_eager=True, dtype='bfloat16', load_format='dummy_dtensor', - gpu_memory_utilization=0.1, + gpu_memory_utilization=0.8, trust_remote_code=True) + # Warmup iterations + for _ in range(10): + torch.cuda.synchronize() + llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor') + torch.cuda.synchronize() + dist.barrier() + + start_time = time.time() llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor') + torch.cuda.synchronize() + dist.barrier() + end_time = time.time() + + # Calculate elapsed time + elapsed_time = end_time - start_time + print(f"Time taken: {elapsed_time:.6f} seconds") input_ids = input_ids.cuda() attention_mask = attention_mask.cuda() diff --git a/tests/rollout/test_vllm_spmd.py b/tests/rollout/test_vllm_spmd.py new file mode 100644 index 00000000..244a2855 --- /dev/null +++ b/tests/rollout/test_vllm_spmd.py @@ -0,0 +1,162 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
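The benchmark added to `tests/rollout/run_fsdp_vllm.py` above times the dtensor weight sync with warmup iterations, a CUDA synchronize, and a barrier before reading the wall clock. A compact sketch of the same measurement pattern, factored into a helper (the helper name is ours, not part of the patch):

```python
# Sketch only: assumes torch.distributed is initialized and CUDA is available.
import time

import torch
import torch.distributed as dist


def time_collective(fn, warmup: int = 10) -> float:
    """Warm up, then time a single call with all ranks synchronized
    before the clock starts and after the work finishes."""
    for _ in range(warmup):
        torch.cuda.synchronize()
        fn()
        torch.cuda.synchronize()
        dist.barrier()

    start = time.time()
    fn()
    torch.cuda.synchronize()
    dist.barrier()
    return time.time() - start


# e.g.: time_collective(lambda: llm.sync_model_weights(actor_weights=state_dict, load_format='dtensor'))
```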
+ +import os +import torch +import transformers + +from vllm import LLM, SamplingParams +from verl.utils.model import update_model_config +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM + +from transformers import GenerationConfig + +from verl.utils.torch_functional import pad_sequence_to_length + + +def levenshtein(s1, s2): + m, n = len(s1), len(s2) + # Initialize matrix of zeros + dp = [[0] * (n + 1) for _ in range(m + 1)] + # Initialize first column and first row of the matrix + for i in range(m + 1): + dp[i][0] = i # Deletion from s1 to empty string + for j in range(n + 1): + dp[0][j] = j # Insertion to s1 from empty string + # Compute the Levenshtein distance matrix + for i in range(1, m + 1): + for j in range(1, n + 1): + cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match + dp[i][j] = min( + dp[i - 1][j] + 1, # Deletion + dp[i][j - 1] + 1, # Insertion + dp[i - 1][j - 1] + cost # Substitution + ) + return dp[m][n] + + +def are_lists_similar(a, b): + if len(a) != len(b): + print("The lists are of different lengths.") + return False + + total_length = 0 + total_diff = 0 + + for s1, s2 in zip(a, b): + max_len = max(len(s1), len(s2)) + total_length += max_len + diff = levenshtein(s1, s2) + total_diff += diff + print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") + + percentage_difference = (total_diff / total_length) * 100 + print(f"Total difference: {percentage_difference:.2f}%") + + return percentage_difference <= 10 + + +def test_vllm_spmd(): + assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.' + + # fill rollout config + max_prompt_length = 16 + max_response_length = 16 + + # Initialize model and token + local_cache_path = '~/.cache/verl/rlhf' + local_cache_path = os.path.expanduser(local_cache_path) + hdfs_path = 'Qwen/Qwen2-7B-Instruct' + from verl.utils.fs import copy_local_path_from_hdfs + local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path) + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left') + + preencode_prompts = [ + "Who won the Champions League in 2019?", + "The founder of Apple is", + "What's your name", + ] + tokenizer.pad_token = tokenizer.eos_token + prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) + input_ids = prompts['input_ids'] + attention_mask = prompts['attention_mask'] + + input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) + attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) + + actor_model = AutoModelForCausalLM.from_pretrained(local_model_path) + actor_model.to(torch.bfloat16) + + actor_model_config = AutoConfig.from_pretrained(local_model_path) + + temperature = 0 + top_p = 1 + + kwargs = dict(n=1, + temperature=temperature, + top_p=top_p, + max_tokens=max_response_length, + logprobs=1, + ignore_eos=True) + + sampling_params = SamplingParams(**kwargs) + tensor_parallel_size = 4 + + llm = LLM(model=local_model_path, + enable_sleep_mode=True, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="external_launcher", + dtype='bfloat16', + gpu_memory_utilization=0.5) + + print('start generation') + input_ids = input_ids.cuda() + attention_mask = attention_mask.cuda() + batch_size = input_ids.size(0) + + generation_config = GenerationConfig(do_sample=False) + actor_model.cuda() + output = actor_model.generate( + input_ids=input_ids, + 
attention_mask=attention_mask, + max_new_tokens=max_response_length, + # max_length=max_length, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config=generation_config, + # renormalize_logits=True, + output_scores=False, # this is potentially very large + return_dict_in_generate=True, + use_cache=False) # may OOM when use_cache = True + seq = output.sequences + response = seq[:, max_prompt_length:] + + hf_response_tokens = tokenizer.batch_decode(response) + print(f'hf response: {hf_response_tokens}') + + outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) + vllm_response_tokens = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + vllm_response_tokens.append(generated_text) + + print(f'vllm response: {vllm_response_tokens}') + assert are_lists_similar(hf_response_tokens, vllm_response_tokens), \ + f'Strings differ more than 10%:\n' + print('Check Pass') + + +# if __name__ == "__main__": +# test_vllm_spmd() diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index 290c8378..ab925c1a 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from importlib.metadata import version, PackageNotFoundError +from packaging import version as vs def get_version(pkg): @@ -24,6 +25,7 @@ def get_version(pkg): package_name = 'vllm' package_version = get_version(package_name) +vllm_version = None if package_version == '0.3.1': vllm_version = '0.3.1' @@ -45,7 +47,14 @@ def get_version(pkg): from .vllm_v_0_6_3.llm import LLM from .vllm_v_0_6_3.llm import LLMEngine from .vllm_v_0_6_3 import parallel_state +elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'): + # From 0.6.6.post2 on, vllm supports SPMD inference + # See https://github.com/vllm-project/vllm/pull/12071 + + from vllm import LLM + from vllm.distributed import parallel_state + from .vllm_spmd.dtensor_weight_loaders import load_dtensor_weights else: raise ValueError( - f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.' + f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4, 0.6.3 and 0.7.0+' ) diff --git a/verl/third_party/vllm/vllm_spmd/__init__.py b/verl/third_party/vllm/vllm_spmd/__init__.py new file mode 100644 index 00000000..1ce90c5e --- /dev/null +++ b/verl/third_party/vllm/vllm_spmd/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py new file mode 100644 index 00000000..a3042cab --- /dev/null +++ b/verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py @@ -0,0 +1,380 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. 
+ continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +from vllm.model_executor.layers.fused_moe import FusedMoE + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert ( + param_name + in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, + placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, + "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. 
" + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 29a521f8..32bfda22 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -225,6 +225,20 @@ def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> List[Tens return tensors.split(batch_size) +def pad_2d_list_to_length(response, pad_token_id, max_length=None): + """ + pad a 2D list (e.g. responses, logprobs) to a 2D tensor. + """ + response_length = max(len(sub_list) for sub_list in response) + if max_length is not None and max_length > response_length: + target_length = max_length + else: + target_length = response_length + padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response] + tensor = torch.tensor(padded_response) + return tensor + + def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False): """ pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length. diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index ed96127d..a995317c 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -300,13 +300,23 @@ def _build_rollout(self): rollout_sharding_manager = BaseShardingManager() # TODO: a sharding manager that do nothing? elif self.config.rollout.name == 'vllm': - from verl.workers.rollout.vllm_rollout import vLLMRollout + from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode from verl.workers.sharding_manager import FSDPVLLMShardingManager log_gpu_memory_usage('Before building vllm rollout', logger=None) - rollout = vLLMRollout(actor_module=self.actor_module_fsdp, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config) + local_path = copy_local_path_from_hdfs(self.config.model.path) + if vllm_mode == 'customized': + rollout = vLLMRollout(actor_module=self.actor_module_fsdp, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config) + elif vllm_mode == 'spmd': + rollout = vLLMRollout(model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh) + else: + raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") log_gpu_memory_usage('After building vllm rollout', logger=None) if torch.distributed.get_world_size() == 1: self.config.rollout.load_format = 'dummy_hf' diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 3464a6e5..894d64cf 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -223,7 +223,7 @@ def megatron_actor_model_provider(pre_process, post_process): def _build_rollout(self): if self.config.rollout.name == 'vllm': - from verl.workers.rollout.vllm_rollout import vLLMRollout + from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode from verl.workers.sharding_manager import MegatronVLLMShardingManager from verl.utils.model import normalize_pp_vpp_params @@ -247,6 +247,7 @@ def _build_rollout(self): params = normalize_pp_vpp_params(params=params, num_hidden_layers=self.actor_model_config.num_hidden_layers, layer_name='layers') + assert vllm_mode == 'customized', "Support for vllm>=0.7 for Megatron-LM backend has not been implemented 
yet." rollout = vLLMRollout(actor_module=params, config=self.config.rollout, tokenizer=self.tokenizer, diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 4f06d209..43e6aab4 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -12,4 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .vllm_rollout import vLLMRollout \ No newline at end of file +from importlib.metadata import version, PackageNotFoundError + + +def get_version(pkg): + try: + return version(pkg) + except PackageNotFoundError: + return None + + +package_name = 'vllm' +package_version = get_version(package_name) + +if package_version <= '0.6.3': + vllm_mode = 'customized' + from .vllm_rollout import vLLMRollout +else: + vllm_mode = 'spmd' + from .vllm_rollout_spmd import vLLMRollout diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py new file mode 100644 index 00000000..bcee3544 --- /dev/null +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -0,0 +1,234 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The vllm_rollout that can be applied in different backend +When working with FSDP: +- Use DTensor weight loader (recommended) or HF weight loader +- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM +When working with Megatron: +- Use Megatron weight loader +- During training, only the current pp stage holds the parameters +- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) +- Bind the parameters to the inference engine +- Do inference in tp. pp is treated as additional dp +- After inference, all the parameters that doesn't belong to this pp rank is freed. +""" +from typing import List +from contextlib import contextmanager +from omegaconf import DictConfig +import torch +import torch.distributed +from tensordict import TensorDict +from torch import nn + +from verl import DataProto +from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length +from verl.workers.rollout.base import BaseRollout +from vllm.distributed import parallel_state as vllm_ps +from vllm import LLM, SamplingParams +from verl.third_party.vllm import vllm_version + +# TODO +# 1. support pp in vllm +# 2. passing tokenizer is not necessary? no encoding/decoding is happending here +# 3. simplify init logics + + +# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. 
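The dispatch above chooses between the 'customized' rollout (for the patched vLLM forks up to 0.6.3) and the new SPMD rollout based solely on the installed vLLM version. Below is a self-contained sketch of the same gate written with `packaging.version` (already imported as `vs` in `verl/third_party/vllm/__init__.py`), which also orders pre-release tags such as `0.6.6.post2.dev252+g8027a724` correctly:

```python
# Sketch only: mirrors the version gate in verl/third_party/vllm/__init__.py and
# verl/workers/rollout/vllm_rollout/__init__.py; not the code shipped in the patch.
from importlib.metadata import PackageNotFoundError, version

from packaging import version as vs


def get_vllm_mode() -> str:
    """Return 'customized' for patched vLLM releases (<=0.6.3) and 'spmd'
    for builds that ship SPMD inference (vllm-project/vllm#12071)."""
    try:
        pkg_version = version("vllm")
    except PackageNotFoundError as exc:
        raise RuntimeError("vllm is not installed") from exc

    if vs.parse(pkg_version) <= vs.parse("0.6.3"):
        return "customized"
    return "spmd"
```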
+def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]: + # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + token_ids = prompt_token_ids[non_pad_index:].tolist() + return token_ids + + +class vLLMRollout(BaseRollout): + + def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): + """A vLLM rollout. It requires the module is supported by the vllm. + + Args: + module: module here follows huggingface APIs + config: DictConfig + tokenizer: the task/model tokenizer + model_hf_config: the huggingface config to initiallize the generating model in vllm + **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group + """ + super().__init__() + self.config = config + assert not (not config.enforce_eager and config.free_cache_engine), \ + "disable CUDA graph (enforce_eager = False) if free cache engine" + + tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) + assert tensor_parallel_size <= torch.distributed.get_world_size(), \ + "tensor parallel size should be less than or equal to the world size" + max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192) + + if kwargs.get('train_tp', None) is not None: + # deployed with megatron + import os + os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' + os.environ['MEGATRON_IMPORT_TIMERS'] = '0' + train_tp = kwargs.get('train_tp', None) + num_tp_per_train_tp = train_tp // tensor_parallel_size + vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, + num_tp_per_train_tp=num_tp_per_train_tp) + + assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ + "model context length should be greater than total sequence length" + + self.inference_engine = LLM( + model=model_path, + enable_sleep_mode=True, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="external_launcher", + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + disable_custom_all_reduce=True, + skip_tokenizer_init=False, + max_model_len=config.prompt_length + config.response_length, + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + ) + + # Offload vllm model to reduce peak memory usage + self.inference_engine.sleep(level=1) + + kwargs = dict( + n=1, + logprobs=1, # can be set to 0 and let actor to recompute + max_tokens=config.response_length, + ) + + # # we may detokenize the result all together later + if vllm_version != '0.3.1': + kwargs['detokenize'] = False + + # supporting adding any sampling params from the config file + for k in config.keys(): + if hasattr(SamplingParams(), str(k)): + kwargs[k] = config.get(k) + + print(f"kwargs: {kwargs}") + self.sampling_params = SamplingParams(**kwargs) + + self.pad_token_id = tokenizer.pad_token_id + + @contextmanager + def update_sampling_params(self, **kwargs): + # update sampling params + old_sampling_params_args = {} + if kwargs: + for key, value in kwargs.items(): + if hasattr(self.sampling_params, key): + old_value = getattr(self.sampling_params, key) + old_sampling_params_args[key] = old_value + 
setattr(self.sampling_params, key, value) + yield + # roll back to previous sampling params + # if len(old_sampling_params_args): + for key, value in old_sampling_params_args.items(): + setattr(self.sampling_params, key, value) + + @torch.no_grad() + def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + # rebuild vllm cache engine + if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: + self.inference_engine.init_cache_engine() + + idx = prompts.batch['input_ids'] # (bs, prompt_length) + # left-padded attention_mask + attention_mask = prompts.batch['attention_mask'] + position_ids = prompts.batch['position_ids'] + + # used to construct attention_mask + eos_token_id = prompts.meta_info['eos_token_id'] + + batch_size = idx.size(0) + + idx_list = [] + # parse idx from torch.Tensor to List[List[str]] + for i in range(batch_size): + idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i])) + + do_sample = prompts.meta_info.get('do_sample', True) + if not do_sample: + kwargs = { + 'best_of': 1, + 'top_p': 1.0, + 'top_k': -1, + 'min_p': 0.0, + 'temperature': 0, + 'n': 1 # if greedy, only 1 response + } + + # users can customize different sampling_params at different run + with self.update_sampling_params(**kwargs): + outputs = self.inference_engine.generate( + prompts=None, # because we have already convert it to prompt token id + sampling_params=self.sampling_params, + prompt_token_ids=idx_list, + use_tqdm=False) + + # TODO(sgm): disable logprob when recompute_log_prob is enable + # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) + + response = [] + for output in outputs: + for sample_id in range(len(output.outputs)): + response.append(output.outputs[sample_id].token_ids) + + response = pad_2d_list_to_length(response, self.pad_token_id, + max_length=self.config.response_length).to(idx.device) + + if self.config.n > 1 and do_sample: + idx = idx.repeat_interleave(self.config.n, dim=0) + attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0) + position_ids = position_ids.repeat_interleave(self.config.n, dim=0) + batch_size = batch_size * self.config.n + seq = torch.cat([idx, response], dim=-1) + + response_length = response.size(1) + delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) + + # TODO(sgm): fix position_ids on right_pad + # prompt: left pad + response: right pad + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + response_position_ids = position_ids[:, -1:] + delta_position_id + position_ids = torch.cat([position_ids, response_position_ids], dim=-1) + response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) + + # all the tp ranks should contain the same data here. 
data in all ranks are valid + batch = TensorDict( + { + 'prompts': idx, + 'responses': response, + 'input_ids': seq, # here input_ids become the whole sentences + # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'attention_mask': attention_mask, + 'position_ids': position_ids + }, + batch_size=batch_size) + + # free vllm cache engine + if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine: + self.inference_engine.free_cache_engine() + + return DataProto(batch=batch) diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 19490f4e..c79d3031 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -24,6 +24,7 @@ from verl import DataProto from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors) from verl.utils.debug import log_gpu_memory_usage +from verl.third_party.vllm import vllm_version from .base import BaseShardingManager @@ -72,7 +73,17 @@ def __enter__(self): log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger) # Copy, not share memory load_format = 'hf' if self.full_params else 'dtensor' - self.inference_engine.sync_model_weights(params, load_format=load_format) + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + self.inference_engine.sync_model_weights(params, load_format=load_format) + else: + self.inference_engine.wake_up() + # TODO(ZSL): deal with 'hf' format + if load_format == 'dtensor': + from verl.third_party.vllm import load_dtensor_weights + load_dtensor_weights( + params, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model) + else: + raise NotImplementedError(f'load_format {load_format} not implemented') log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger) del params @@ -92,7 +103,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): log_gpu_memory_usage('Before vllm offload in sharding manager', logger=logger) - self.inference_engine.offload_model_weights() + # TODO(ZSL): check this + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): + self.inference_engine.offload_model_weights() + else: + self.inference_engine.sleep(level=1) log_gpu_memory_usage('After vllm offload in sharding manager', logger=logger) # self.module.to('cuda') @@ -111,18 +126,29 @@ def __exit__(self, exc_type, exc_value, traceback): def preprocess_data(self, data: DataProto) -> DataProto: # TODO: Current impl doesn't consider FSDP with torch micro-dp - data.batch = allgather_dict_tensors(data.batch.contiguous(), - size=vllm_ps.get_tensor_model_parallel_world_size(), - group=vllm_ps.get_tensor_model_parallel_group(), - dim=0) + if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): + data.batch = allgather_dict_tensors(data.batch.contiguous(), + size=vllm_ps.get_tensor_model_parallel_world_size(), + group=vllm_ps.get_tensor_model_parallel_group(), + dim=0) + else: + data.batch = allgather_dict_tensors(data.batch.contiguous(), + size=vllm_ps.get_tensor_model_parallel_world_size(), + group=vllm_ps.get_tensor_model_parallel_group().device_group, + dim=0) return data def postprocess_data(self, data: DataProto) -> DataProto: # TODO: Current impl doesn't consider FSDP with torch micro-dp - broadcast_dict_tensor(data.batch, - src=vllm_ps.get_tensor_model_parallel_src_rank(), - group=vllm_ps.get_tensor_model_parallel_group()) + local_world_size = vllm_ps.get_tensor_model_parallel_world_size() + 
src_rank = (torch.distributed.get_rank() // local_world_size) * local_world_size + if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'): + broadcast_dict_tensor(data.batch, src=src_rank, group=vllm_ps.get_tensor_model_parallel_group()) + else: + broadcast_dict_tensor(data.batch, + src=src_rank, + group=vllm_ps.get_tensor_model_parallel_group().device_group) dp_rank = torch.distributed.get_rank() dp_size = torch.distributed.get_world_size() # not consider torch micro-dp tp_size = vllm_ps.get_tensor_model_parallel_world_size()
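Taken together, the sharding-manager changes above give the following enter/exit flow for vLLM>=0.7. This is a simplified sketch of what `FSDPVLLMShardingManager` now does on the new code path (the class name below is illustrative; FSDP state-dict handling, logging, and random-state bookkeeping are omitted):

```python
# Sketch only: simplified view of the vLLM>=0.7 branch in FSDPVLLMShardingManager.
from verl.third_party.vllm import load_dtensor_weights


class SpmdShardingManagerSketch:  # illustrative name, not a class in the patch
    def __init__(self, module, inference_engine):
        self.module = module                        # FSDP-wrapped actor module
        self.inference_engine = inference_engine    # vllm.LLM(..., enable_sleep_mode=True)

    def __enter__(self):
        params = self.module.state_dict()
        # Wake the sleeping engine, then push the actor's DTensor weights into the
        # vLLM model held by the driver worker's model runner.
        self.inference_engine.wake_up()
        model = (self.inference_engine.llm_engine.model_executor
                 .driver_worker.worker.model_runner.model)
        load_dtensor_weights(params, model)
        del params

    def __exit__(self, exc_type, exc_value, traceback):
        # Offload the engine again until the next rollout step.
        self.inference_engine.sleep(level=1)
```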