diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py
new file mode 100644
index 0000000000000..2d4d07dd27522
--- /dev/null
+++ b/tests/distributed/test_pipeline_partition.py
@@ -0,0 +1,53 @@
+import os
+
+import pytest
+
+from vllm.distributed.utils import get_pp_indices
+
+
+def test_custom_layer_partition():
+    """Check that VLLM_PP_LAYER_PARTITION drives get_pp_indices.
+
+    Valid partition strings must yield the expected (start, end) layer range
+    for every pipeline rank, and malformed or inconsistent strings must raise
+    ValueError.
+    """
+
+    def _verify(partition_str, num_layers, pp_size, goldens):
+        # Compare every rank's range against `goldens` while the env var is
+        # temporarily set to `partition_str`.
+        bak = os.environ.get("VLLM_PP_LAYER_PARTITION", None)
+        os.environ["VLLM_PP_LAYER_PARTITION"] = partition_str
+        try:
+            for pp_rank, golden in enumerate(goldens):
+                assert get_pp_indices(num_layers, pp_rank, pp_size) == golden
+        finally:
+            # Always put the environment back, including when the check
+            # raises (the ValueError cases below) and when the variable was
+            # not set beforehand, so the override cannot leak into other
+            # tests running in the same process.
+            if bak is None:
+                os.environ.pop("VLLM_PP_LAYER_PARTITION", None)
+            else:
+                os.environ["VLLM_PP_LAYER_PARTITION"] = bak
+
+    # Even partition
+    _verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Balanced partition
+    _verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)])
+    # Put remainder somewhere
+    _verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)])
+    # Invalid partition strings
+    with pytest.raises(ValueError):
+        _verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    with pytest.raises(ValueError):
+        _verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Wrong number of partitions
+    with pytest.raises(ValueError):
+        _verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Wrong number of layers
+    with pytest.raises(ValueError):
+        _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+    # Negative partition size
+    with pytest.raises(ValueError):
+        _verify("10,-5,10,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index b5cf6c45f478f..8c94ef8cb10ce 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -6,6 +6,11 @@
 
 import torch
 
+import vllm.envs as envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""
@@ -54,11 +59,45 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
     If the number of layers is not divisible by the number of partitions,
     the last partition will have the remaining layers.
+
+    If the VLLM_PP_LAYER_PARTITION environment variable is set, it is parsed
+    as a comma-separated list of per-stage layer counts and used instead of
+    the even split.
+
+    Returns the (start_layer, end_layer) pair for `pp_rank`; consecutive
+    ranks share a boundary, so the range is start-inclusive, end-exclusive.
+
+    Raises ValueError if the partition string contains a non-integer or
+    negative entry, if it does not have exactly `pp_size` entries, or if the
+    entries do not sum to `num_hidden_layers`.
     """
-    layers_per_partition = num_hidden_layers // pp_size
-    start_layer = pp_rank * layers_per_partition
-    end_layer = start_layer + layers_per_partition
+    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
+    if partition_list_str is not None:
+        try:
+            partitions = [
+                int(layer) for layer in partition_list_str.split(",")
+            ]
+        except ValueError as err:
+            raise ValueError(
+                f"Invalid partition string: {partition_list_str}") from err
+        if any(num_layers < 0 for num_layers in partitions):
+            # A negative count would pass the length and sum checks below
+            # yet produce a range whose end precedes its start.
+            raise ValueError(
+                "Negative layer count in partition string: "
+                f"{partition_list_str}")
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(
+                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+    else:
+        layers_per_partition = num_hidden_layers // pp_size
+        start_layer = pp_rank * layers_per_partition
+        end_layer = start_layer + layers_per_partition
 
-    if pp_rank == pp_size - 1:
-        end_layer = num_hidden_layers
+        if pp_rank == pp_size - 1:
+            end_layer = num_hidden_layers
 
     return (start_layer, end_layer)
diff --git a/vllm/envs.py b/vllm/envs.py
index f06b6d66ea6f4..aef7ac385ec66 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -28,6 +28,7 @@
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
+    VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
     VLLM_OPENVINO_KVCACHE_SPACE: int = 0
@@ -242,6 +243,11 @@ def get_default_config_root():
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
 
+    # Custom pipeline-parallel layer partition: a comma-separated list of
+    # per-stage layer counts, e.g. "4,6,6,4". Unset means an even split.
+    "VLLM_PP_LAYER_PARTITION":
+    lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
+
     # (CPU backend only) CPU key-value cache space.
     # default is 4GB
     "VLLM_CPU_KVCACHE_SPACE":