Skip to content

Commit

Permalink
Make vllm compatible with verl (vllm-project#12824)
Browse files Browse the repository at this point in the history
Co-authored-by: zhangshulai <[email protected]>
  • Loading branch information
ZSL98 and zhangshulai authored Feb 7, 2025
1 parent ef533d2 commit 433c4a4
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
7 changes: 0 additions & 7 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,13 +1024,6 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)

if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
Expand Down
2 changes: 1 addition & 1 deletion vllm/executor/uniproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _init_executor(self) -> None:
# - MASTER_PORT
distributed_init_method = "env://"
rank = int(os.environ["RANK"])
local_rank = rank
local_rank = int(os.environ["LOCAL_RANK"])
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
Expand Down

0 comments on commit 433c4a4

Please sign in to comment.