From b86675daa08291ffa288c399cfb6f29c1b54c8c8 Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Sun, 9 Jun 2024 23:23:24 +0300
Subject: [PATCH 01/21] Change model config to support unequal tp division

---
 vllm/config.py              | 30 +++++++++++++++++++++---------
 vllm/worker/cache_engine.py |  9 ++++++---
 vllm/worker/model_runner.py | 12 ++++++++----
 vllm/worker/worker.py       |  3 ++-
 4 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index eee62d2683835..86fb01dbcff8a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -6,6 +6,7 @@
 import torch
 from transformers import PretrainedConfig
 
+from vllm.distributed import get_current_tp_rank_partition_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
@@ -227,11 +228,13 @@ def verify_with_parallel_config(
     ) -> None:
         total_num_attention_heads = self.hf_text_config.num_attention_heads
         tensor_parallel_size = parallel_config.tensor_parallel_size
-        if total_num_attention_heads % tensor_parallel_size != 0:
+        if (total_num_attention_heads % tensor_parallel_size != 0
+                and self.quantization is not None):
             raise ValueError(
-                f"Total number of attention heads ({total_num_attention_heads})"
+                f"Total number of attention heads "
+                f"({total_num_attention_heads})"
                 " must be divisible by tensor parallel size "
-                f"({tensor_parallel_size}).")
+                f"({tensor_parallel_size}) when quantization is used.")
 
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -320,20 +323,29 @@ def get_total_num_kv_heads(self) -> int:
             # equal to the number of attention heads.
             return self.hf_text_config.num_attention_heads
 
-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self,
+                         parallel_config: "ParallelConfig",
+                         tp_rank: int = 0) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
         # If tensor parallelism is used, we divide the number of KV heads by
         # the tensor parallel size. We will replicate the KV heads in the
         # case where the number of KV heads is smaller than the tensor
         # parallel size so each GPU has at least one KV head.
-        return max(1,
-                   total_num_kv_heads // parallel_config.tensor_parallel_size)
+        result = get_current_tp_rank_partition_size(
+            total_num_kv_heads, tp_rank, parallel_config.tensor_parallel_size)
+        return max(1, result)
 
     def get_num_attention_heads(self,
-                                parallel_config: "ParallelConfig") -> int:
-        return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+                                parallel_config: "ParallelConfig",
+                                tp_rank: int = 0) -> int:
+        num_total_kv_heads = self.get_total_num_kv_heads()
+        num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank)
+        num_total_attention_heads = self.hf_text_config.num_attention_heads
+        num_heads_per_kv_head = num_total_attention_heads // num_total_kv_heads
+        # For GQA attention we make sure the whole attention head group is
+        # together on the same GPU.
+ return num_kv_heads * num_heads_per_kv_head def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 2f0e59f7ae7c9..a62526e8e26d7 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -24,6 +24,7 @@ def __init__( cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, + tp_rank: int = 0, ) -> None: self.cache_config = cache_config self.model_config = model_config @@ -31,7 +32,8 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads( + parallel_config, tp_rank) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -44,7 +46,7 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend( - model_config.get_num_attention_heads(parallel_config), + model_config.get_num_attention_heads(parallel_config, tp_rank), self.head_size, self.num_kv_heads, model_config.get_sliding_window(), @@ -96,9 +98,10 @@ def get_cache_block_size( cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, + tp_rank: int = 0, ) -> int: head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) + num_heads = model_config.get_num_kv_heads(parallel_config, tp_rank) num_layers = model_config.get_num_layers(parallel_config) key_cache_block = cache_config.block_size * num_heads * head_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 63ec22d79694f..0b7c093d27352 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -83,6 +83,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tp_rank: int = 0, ): self.model_config = model_config self.parallel_config = parallel_config @@ -97,6 +98,7 @@ def __init__( self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + self.tp_rank = tp_rank self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size @@ -114,9 +116,11 @@ def __init__( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_num_attention_heads(self.parallel_config, + self.tp_rank), self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_num_kv_heads(self.parallel_config, + self.tp_rank), self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, @@ -591,9 +595,9 @@ def _prepare_model_input( paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), + self.parallel_config, self.tp_rank), num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), + self.parallel_config, self.tp_rank), head_dim=self.model_config.get_head_size(), page_size=16, seq_start_loc=seq_start_loc, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 10411a2bf7a10..d52020bf1f37b 100644 --- 
a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -82,6 +82,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + tp_rank=self.rank, ) # Uninitialized cache engine. Will be initialized by # initialize_cache. @@ -197,7 +198,7 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config) + self.parallel_config, self.rank) self.gpu_cache = self.cache_engine.gpu_cache def _warm_up_model(self) -> None: From a789569d8ffd13ede52ca8c54ecb35ad6d44d951 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:23:58 +0300 Subject: [PATCH 02/21] Add unequal tp division util functions --- vllm/distributed/parallel_state.py | 32 ++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0ebd7a15eab9b..b9a7ac3c57b7b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -376,3 +376,35 @@ def destroy_model_parallel(): _PP_DEVICE_GROUP = None global _PP_GLOBAL_RANKS _PP_GLOBAL_RANKS = None + + +def get_current_tp_rank_partition_offset(total_size: int, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + multiple_of: int = 1) -> int: + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + + assert total_size % multiple_of == 0 + total_size = total_size // multiple_of + return ((total_size // tp_size) * tp_rank + + min(total_size % tp_size, tp_rank)) * multiple_of + + +def get_current_tp_rank_partition_size(total_size: int, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + multiple_of: int = 1) -> int: + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + + assert total_size % multiple_of == 0 + total_size = total_size // multiple_of + return ((total_size // tp_size) + + (total_size % tp_size > tp_rank)) * multiple_of From 428b85f653c724e69c65ff18c5380d4553cbcec8 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:25:20 +0300 Subject: [PATCH 03/21] Change parallel layers to support unequal tp division --- vllm/model_executor/layers/linear.py | 79 ++++++++++++------- .../layers/vocab_parallel_embedding.py | 3 +- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f5b6bdd9f7fd7..5165339a432b7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, +from vllm.distributed import (divide, get_current_tp_rank_partition_offset, + get_current_tp_rank_partition_size, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, @@ -236,14 +238,16 @@ def __init__(self, self.gather_output = gather_output # Divide the weight matrix along the last dimension. 
+ tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() assert self.quant_method is not None - self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_size_per_partition = output_size // tp_size + ( + output_size % tp_size > tp_rank) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - divide(output_size, tp_size) + output_size // tp_size + (output_size % tp_size > tp_rank) for output_size in self.output_sizes ] @@ -334,17 +338,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__(self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): self.output_sizes = output_sizes - tp_size = get_tensor_model_parallel_world_size() - assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size=input_size, output_size=sum(output_sizes), bias=bias, @@ -418,8 +422,12 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum( + get_current_tp_rank_partition_size(output_size, tp_rank, + tp_size) + for output_size in self.output_sizes[:loaded_shard_id]) + shard_size = get_current_tp_rank_partition_size( + self.output_sizes[loaded_shard_id], tp_rank, tp_size) # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. @@ -439,7 +447,8 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) - start_idx = tp_rank * shard_size + start_idx = get_current_tp_rank_partition_offset( + loaded_weight.shape[output_dim], tp_rank, tp_size) loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for AQLM codebooks. @@ -520,14 +529,18 @@ def __init__(self, self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. 
tp_size = get_tensor_model_parallel_world_size() - self.num_heads = divide(self.total_num_heads, tp_size) + tp_rank = get_tensor_model_parallel_rank() + self.num_heads_per_kv_head = (self.total_num_heads // + self.total_num_kv_heads) + self.num_kv_heads = get_current_tp_rank_partition_size( + self.total_num_kv_heads, tp_rank, tp_size) + self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head + self.num_kv_head_replicas = 1 if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) - else: - self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) - self.num_kv_head_replicas = 1 + input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size @@ -655,10 +668,13 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = tp_rank + multiple_of = self.head_size * self.num_heads_per_kv_head else: - shard_id = tp_rank // self.num_kv_head_replicas - start_idx = shard_id * shard_size + multiple_of = self.head_size + tp_size = get_tensor_model_parallel_world_size() + total_size = loaded_weight.shape[output_dim] + start_idx = get_current_tp_rank_partition_offset( + total_size, tp_rank, tp_size, multiple_of=multiple_of) loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for for AQLM codebooks. @@ -720,6 +736,8 @@ class RowParallelLinear(LinearBase): We skip adding bias but instead return it. params_dtype: Data type for the parameters. quant_config: Quantization configure. + partition_multiple_of: Partitions will be divided, + so each partition is a multiple of this number. """ def __init__(self, @@ -730,7 +748,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + partition_multiple_of: int = 1): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -739,7 +758,10 @@ def __init__(self, # Divide the weight matrix along the last dimension. 
self.tp_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.tp_size) + self.tp_rank = get_tensor_model_parallel_rank() + self.partition_multiple_of = partition_multiple_of + self.input_size_per_partition = get_current_tp_rank_partition_size( + input_size, self.tp_rank, self.tp_size, partition_multiple_of) assert self.quant_method is not None self.quant_method.create_weights( layer=self, @@ -768,12 +790,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) - tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data if input_dim is not None: shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size + start_idx = get_current_tp_rank_partition_offset( + self.input_size, + self.tp_rank, + self.tp_size, + multiple_of=self.partition_multiple_of) loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 4585b1679cb5c..ce983f95fef0b 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -51,9 +51,10 @@ def __init__(self, embedding_dim: int, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + padding_size: Optional[int] = None): super().__init__() + padding_size = padding_size or get_tensor_model_parallel_world_size() # Keep the input dimensions. self.num_embeddings = num_embeddings self.org_vocab_size = org_num_embeddings or num_embeddings From c485d500739b302b97a1bc3d573dcc13c9cb548f Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:27:45 +0300 Subject: [PATCH 04/21] Add unequal tp division support for opt model --- vllm/model_executor/models/opt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 4bf59105dbabb..064bbe89835de 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -25,7 +25,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_current_tp_rank_partition_size, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -67,11 +69,11 @@ def __init__( ) -> None: super().__init__() self.embed_dim = embed_dim - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() total_num_heads = num_heads - assert num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.num_heads = get_current_tp_rank_partition_size( + total_num_heads, tp_rank, tp_size) self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 @@ -87,6 +89,7 @@ def __init__( embed_dim, bias=bias, quant_config=quant_config, + partition_multiple_of=self.head_dim, ) self.attn = Attention(self.num_heads, self.head_dim, From 
1cf543b10c8b7387af1d4ea76c32ca18987d559c Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:28:04 +0300 Subject: [PATCH 05/21] Add unequal tp division support for commandr model --- vllm/model_executor/models/commandr.py | 39 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 84786921ce1b4..26787522a7fad 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -30,7 +30,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_current_tp_rank_partition_offset, + get_current_tp_rank_partition_size, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -49,7 +51,7 @@ from vllm.sequence import SamplerOutput -@torch.compile +#@torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) @@ -63,11 +65,15 @@ def layer_norm_func(hidden_states, weight, variance_epsilon): class LayerNorm(nn.Module): - def __init__(self, param_shape=None, eps=1e-5): + def __init__(self, + param_shape=None, + eps=1e-5, + partition_multiple_of: int = 1): super().__init__() self.weight = nn.Parameter(torch.ones(param_shape)) self.variance_epsilon = eps set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + self.partition_multiple_of = partition_multiple_of def forward(self, hidden_states, residuals=None): hidden_states = layer_norm_func(hidden_states, self.weight, @@ -76,11 +82,14 @@ def forward(self, hidden_states, residuals=None): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() shard_dim = 0 if param.dim() != 1 else None param_data = param.data if shard_dim is not None: shard_size = param_data.shape[shard_dim] - start_idx = tp_rank * shard_size + start_idx = get_current_tp_rank_partition_offset( + loaded_weight.shape[shard_dim], tp_rank, tp_size, + self.partition_multiple_of) loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) assert param_data.shape == loaded_weight.shape @@ -130,22 +139,29 @@ def __init__( ): super().__init__() tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() self.config = config self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_key_value_heads self.num_heads = self.total_num_heads // tp_size + self.num_kv_heads = max( + 1, + get_current_tp_rank_partition_size(self.total_num_kv_heads, + tp_rank, tp_size)) + num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads + self.num_heads = self.num_kv_heads * num_heads_per_kv_head self.head_dim = self.hidden_size // self.total_num_heads - self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. 
- assert self.total_num_kv_heads % tp_size == 0 + pass + #LATER assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -168,7 +184,7 @@ def __init__( self.hidden_size, bias=False, quant_config=quant_config, - ) + partition_multiple_of=num_heads_per_kv_head * self.head_dim) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -184,9 +200,10 @@ def __init__( cache_config=cache_config, quant_config=quant_config) if self.use_qk_norm: - self.q_norm = LayerNorm(param_shape=(self.num_heads, - self.head_dim), - eps=config.layer_norm_eps) + self.q_norm = LayerNorm( + param_shape=(self.num_heads, self.head_dim), + eps=config.layer_norm_eps, + partition_multiple_of=num_heads_per_kv_head) self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, self.head_dim), eps=config.layer_norm_eps) From a6970c0ec645102462e763a083cb63bc535a0e71 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:28:20 +0300 Subject: [PATCH 06/21] Add unequal tp division support for llama model --- vllm/model_executor/models/llama.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d83ee9a201c0b..6c173e6a9091d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,7 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_current_tp_rank_partition_size, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -99,19 +100,24 @@ def __init__( super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads + self.num_kv_heads = max( + 1, + get_current_tp_rank_partition_size(self.total_num_kv_heads, + tp_rank, tp_size)) + num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads + self.num_heads = self.num_kv_heads * num_heads_per_kv_head if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + pass + #LATER assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -132,6 +138,7 @@ def __init__( output_size=hidden_size, bias=bias, quant_config=quant_config, + partition_multiple_of=num_heads_per_kv_head * self.head_dim, ) self.rotary_emb = get_rope( From 6a4b70e59cbb220911476e3807247bf2eebd86e9 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Sun, 9 Jun 2024 23:31:45 +0300 Subject: [PATCH 07/21] Remove asserts in Llama and CommandR implementation --- vllm/model_executor/models/commandr.py | 9 --------- vllm/model_executor/models/llama.py | 9 --------- 2 files changed, 18 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 26787522a7fad..4ad6eb7d1914d 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -153,15 +153,6 @@ def __init__( num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads self.num_heads = self.num_kv_heads * num_heads_per_kv_head self.head_dim = self.hidden_size // self.total_num_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - pass - #LATER assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 6c173e6a9091d..a0fdc40f68d12 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -109,15 +109,6 @@ def __init__( tp_rank, tp_size)) num_heads_per_kv_head = self.total_num_heads // self.total_num_kv_heads self.num_heads = self.num_kv_heads * num_heads_per_kv_head - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - pass - #LATER assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim From 6b33c8774bdade45f4f74ebf33a03c75461a0151 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Tue, 11 Jun 2024 12:07:48 +0300 Subject: [PATCH 08/21] Add tp_rank to EmbeddingModelRunner class --- vllm/worker/embedding_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 465130d10e2f9..ca4e4e3bfc872 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -32,6 +32,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tp_rank: int = 0, ): super().__init__(model_config, parallel_config, @@ -42,7 +43,8 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + tp_rank=tp_rank) @torch.inference_mode() def execute_model( From 90d9f6c4b06cffb50fc95cbd8306353b91e704fe Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Tue, 11 Jun 2024 12:42:30 +0300 Subject: [PATCH 09/21] Fix QKVLinear to work with packed dim --- vllm/model_executor/layers/linear.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5165339a432b7..1c08e5f2f909a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -247,7 +247,8 @@ def __init__(self, # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - output_size // tp_size + (output_size % tp_size > tp_rank) + get_current_tp_rank_partition_size(output_size, tp_rank, + tp_size) for output_size in self.output_sizes ] @@ -630,13 +631,17 @@ def weight_loader(self, if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size + multiple_of = self.head_size * self.num_heads_per_kv_head elif loaded_shard_id == "k": shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size + multiple_of = self.head_size elif loaded_shard_id == "v": shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + multiple_of = self.head_size + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. @@ -644,6 +649,7 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + multiple_of = multiple_of // param.pack_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( @@ -667,10 +673,7 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) - if loaded_shard_id == "q": - multiple_of = self.head_size * self.num_heads_per_kv_head - else: - multiple_of = self.head_size + tp_size = get_tensor_model_parallel_world_size() total_size = loaded_weight.shape[output_dim] start_idx = get_current_tp_rank_partition_offset( From 014b682b6c332e20d31ba966fb601d902c5be7ee Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Tue, 11 Jun 2024 14:49:38 +0300 Subject: [PATCH 10/21] Fix imports formatting in layer/linear.py file --- vllm/model_executor/layers/linear.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1c08e5f2f909a..3de5a2c0abdeb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,13 +5,11 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.distributed import (divide, get_current_tp_rank_partition_offset, - get_current_tp_rank_partition_size, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, get_current_tp_rank_partition_offset, + get_current_tp_rank_partition_size, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, split_tensor_along_last_dim, + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) From cdb2e27449b25827cab7fca2d927dae2fd85a98c Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Wed, 3 Jul 2024 10:30:52 +0300 Subject: [PATCH 11/21] Remove unused variable --- vllm/model_executor/layers/linear.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e3d4a6b92c801..cb83b551c5022 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -748,10 +748,6 @@ def __init__(self, self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - # Special case for Fp8 scales. 
- fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", - None) - input_dim = getattr(param, "input_dim", None) param_data = param.data if input_dim is not None: From b9e530988f2ea833c3427392ef1bb655d9a83bfe Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Wed, 3 Jul 2024 14:17:35 +0300 Subject: [PATCH 12/21] Fix failing tests --- vllm/config.py | 3 +++ vllm/spec_decode/draft_model_runner.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index c5e52b92991fd..711f89bc13f69 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -389,6 +389,9 @@ def get_num_kv_heads(self, def get_num_attention_heads(self, parallel_config: "ParallelConfig", tp_rank: int = 0) -> int: + if getattr(self.hf_text_config, "num_attention_heads", None) is None: + return 0 + num_total_kv_heads = self.get_total_num_kv_heads() num_kv_heads = self.get_num_kv_heads(parallel_config, tp_rank) num_total_attention_heads = self.hf_text_config.num_attention_heads diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1c7b8c07e89e5..bfbc4de347db8 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -49,6 +49,7 @@ def __init__( is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, return_hidden_states: bool = False, + **kwargs, ): if return_hidden_states: raise ValueError( @@ -67,6 +68,7 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, return_hidden_states=return_hidden_states, + **kwargs, ) # TODO: Remove this cache when we are able to update model_input From a268f20df7071124f061b77d9ea531587a50b916 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Wed, 3 Jul 2024 14:33:27 +0300 Subject: [PATCH 13/21] Fix formatting --- vllm/config.py | 2 +- vllm/engine/metrics.py | 2 +- vllm/model_executor/layers/linear.py | 4 ++++ vllm/model_executor/models/llama.py | 5 ++--- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 711f89bc13f69..06cf7c4a6fbd1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,8 +7,8 @@ import torch from transformers import PretrainedConfig -from vllm.distributed import get_current_tp_rank_partition_size import vllm.envs as envs +from vllm.distributed import get_current_tp_rank_partition_size from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 77de42bc0ed5d..2c1210c90c632 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -457,4 +457,4 @@ def log(self, stats: Stats): class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" - _metrics_cls = RayMetrics \ No newline at end of file + _metrics_cls = RayMetrics diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cb83b551c5022..7f2393ae634dc 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,11 +5,15 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +# fixme Isort and yapf conflict for this, so we disable isort for this block +# isort: off from vllm.distributed import ( divide, get_current_tp_rank_partition_offset, get_current_tp_rank_partition_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, 
split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +# isort: on + from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f46930b6ccc00..096a3469925a6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,9 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_current_tp_rank_partition_size, - get_pp_group, get_pp_indices, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_current_tp_rank_partition_size, get_pp_group, + get_pp_indices, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm From b033a439f852bc25222443c0b50bfa00f06aadd4 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Wed, 3 Jul 2024 23:12:50 +0300 Subject: [PATCH 14/21] Add uneven tensor parallel test cases --- .../test_basic_distributed_correctness.py | 37 ++++++ .../test_chunked_prefill_distributed.py | 50 +++++++ .../e2e/test_integration_dist_tp3.py | 123 ++++++++++++++++++ 3 files changed, 210 insertions(+) create mode 100644 tests/spec_decode/e2e/test_integration_dist_tp3.py diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 7a0e5673b2cc4..ccdeb997ed735 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -61,3 +61,40 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.skipif(cuda_device_count_stateless() < 3, + reason="Need at least 3 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +def test_models_uneven_tensor_parallel( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=3, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 1ef085b933793..c561b15f2755e 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -73,3 +73,53 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.skipif(cuda_device_count_stateless() < 3, + reason="Need at least 3 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models_uneven_tensor_parallel( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=3, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp3.py b/tests/spec_decode/e2e/test_integration_dist_tp3.py new file mode 100644 index 0000000000000..e6178b0930546 --- /dev/null +++ b/tests/spec_decode/e2e/test_integration_dist_tp3.py @@ -0,0 +1,123 @@ +"""Tests which cover integration of the speculative decoding framework with +tensor parallelism. +""" + +import pytest +import torch + +from vllm.utils import is_hip + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.skipif(torch.cuda.device_count() < 3, + reason="Need at least 3 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 3, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. 
NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when tensor parallelism is used. + """ + if is_hip(): + pytest.skip("hip is not well-supported yet") + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.skipif(torch.cuda.device_count() < 3, + reason="Need at least 3 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 3, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs, test_llm_kwargs", + [ + ( + { + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a + # tokenizer. + "model": "JackFram/llama-68m", + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_draft_tensor_parallel_size": 1, + }), + ({ + "model": "ibm-granite/granite-3b-code-instruct", + }, { + "speculative_model": + "ibm-granite/granite-3b-code-instruct-accelerator", + "num_speculative_tokens": 5, + "speculative_draft_tensor_parallel_size": 1, + }) + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_draft_model_tp_lt_target_model_tp2(test_llm_generator, + baseline_llm_generator, + batch_size: int): + """Verify spec decode works well with smaller tp for draft models. 
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=32, + force_output_len=True) From 34f98500177564eec383efda49330e133a32de87 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Wed, 3 Jul 2024 23:13:31 +0300 Subject: [PATCH 15/21] Fix review comments --- vllm/model_executor/layers/linear.py | 4 ++-- vllm/model_executor/models/commandr.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7f2393ae634dc..e6b0118684446 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -266,8 +266,8 @@ def __init__(self, tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() assert self.quant_method is not None - self.output_size_per_partition = output_size // tp_size + ( - output_size % tp_size > tp_rank) + self.output_size_per_partition = get_current_tp_rank_partition_size( + output_size, tp_rank, tp_size) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. if hasattr(self, "output_sizes"): diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 14b1df4835207..516ea98418aca 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput -#@torch.compile +@torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) From a154adefeba6fefb65298de906a4b9490b5ad335 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Thu, 4 Jul 2024 10:46:00 +0300 Subject: [PATCH 16/21] Fix uneven TP tests and add to .buildkite --- .buildkite/test-pipeline.yaml | 3 + .../test_basic_distributed_correctness.py | 37 --------- .../test_uneven_distributed_correctness.py | 45 +++++++++++ .../e2e/test_integration_dist_tp3.py | 79 ++----------------- 4 files changed, 56 insertions(+), 108 deletions(-) create mode 100644 tests/distributed/test_uneven_distributed_correctness.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d127278aaae2d..27bd1db22a685 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -72,7 +72,10 @@ steps: # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. 
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_uneven_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py + - pytest -v -s spec_decode/e2e/test_integration_dist_tp3.py - label: Pipeline Parallelism Test working_dir: "/vllm-workspace/tests" diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index ccdeb997ed735..7a0e5673b2cc4 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -61,40 +61,3 @@ def test_models( name_0="hf", name_1="vllm", ) - - -@pytest.mark.skipif(cuda_device_count_stateless() < 3, - reason="Need at least 3 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -def test_models_uneven_tensor_parallel( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=3, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_uneven_distributed_correctness.py b/tests/distributed/test_uneven_distributed_correctness.py new file mode 100644 index 0000000000000..513c19a434072 --- /dev/null +++ b/tests/distributed/test_uneven_distributed_correctness.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from tests.basic_correctness.test_basic_correctness import MODELS +from tests.distributed.test_basic_distributed_correctness import DISTRIBUTED_EXECUTOR_BACKEND +from tests.models.utils import check_outputs_equal +from vllm.utils import cuda_device_count_stateless + + +@pytest.mark.skipif(cuda_device_count_stateless() < 3, + reason="Need at least 3 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +def test_models_uneven_tensor_parallel( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. 
+ # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=3, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) \ No newline at end of file diff --git a/tests/spec_decode/e2e/test_integration_dist_tp3.py b/tests/spec_decode/e2e/test_integration_dist_tp3.py index e6178b0930546..c7dac199322c7 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp3.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp3.py @@ -5,8 +5,6 @@ import pytest import torch -from vllm.utils import is_hip - from .conftest import run_greedy_equality_correctness_test @@ -15,6 +13,8 @@ @pytest.mark.parametrize( "common_llm_kwargs", [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. "model": "JackFram/llama-68m", # Skip cuda graph recording for fast test. @@ -31,87 +31,24 @@ # second run of the test to fail with internal NCCL error. "use_async": True, }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ { "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - }, - { - "speculative_model": "[ngram]", "num_speculative_tokens": 5, - "ngram_prompt_lookup_max": 3, }, ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, - batch_size: int, output_len: int): - """Verify greedy equality when tensor parallelism is used. - """ - if is_hip(): - pytest.skip("hip is not well-supported yet") - run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True) - - -@pytest.mark.skipif(torch.cuda.device_count() < 3, - reason="Need at least 3 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, - "tensor_parallel_size": 3, - - # Use AsyncLLM engine, so that the engine runs in its own process. - # Otherwise, since vLLM does not follow true SPMD, the test runner - # process will have both the engine and the rank0 worker. NCCL is not - # cleaned up properly, and its server host thread leaks, causing the - # second run of the test to fail with internal NCCL error. - "use_async": True, - }]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize( - "per_test_common_llm_kwargs, test_llm_kwargs", + "test_llm_kwargs", [ - ( - { - # Use a small model for a fast test. - # Note this is repeated in the test body; to initialize a - # tokenizer. 
- "model": "JackFram/llama-68m", - }, - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "speculative_draft_tensor_parallel_size": 1, - }), - ({ - "model": "ibm-granite/granite-3b-code-instruct", - }, { - "speculative_model": - "ibm-granite/granite-3b-code-instruct-accelerator", - "num_speculative_tokens": 5, + #TODO(wooyeon): add spec_draft_dp=2 case + { "speculative_draft_tensor_parallel_size": 1, - }) + }, ]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) -def test_draft_model_tp_lt_target_model_tp2(test_llm_generator, +def test_draft_model_tp_lt_target_model_tp3(test_llm_generator, baseline_llm_generator, batch_size: int): """Verify spec decode works well with smaller tp for draft models. From fe906b51ef70ff68b77b1ec3e968a7f29878e204 Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Thu, 4 Jul 2024 10:49:48 +0300 Subject: [PATCH 17/21] Fix formatting and imports in new uneven TP tests --- .../test_uneven_distributed_correctness.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_uneven_distributed_correctness.py b/tests/distributed/test_uneven_distributed_correctness.py index 513c19a434072..2e2f3773c54f8 100644 --- a/tests/distributed/test_uneven_distributed_correctness.py +++ b/tests/distributed/test_uneven_distributed_correctness.py @@ -1,12 +1,30 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. +Run: +```sh +cd $VLLM_PATH/tests + +TEST_DIST_MODEL=facebook/opt-125m pytest \ + distributed/test_basic_distributed_correctness.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + distributed/test_basic_distributed_correctness.py +``` +""" import os import pytest -from tests.basic_correctness.test_basic_correctness import MODELS -from tests.distributed.test_basic_distributed_correctness import DISTRIBUTED_EXECUTOR_BACKEND -from tests.models.utils import check_outputs_equal from vllm.utils import cuda_device_count_stateless +from ..models.utils import check_outputs_equal + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" + @pytest.mark.skipif(cuda_device_count_stateless() < 3, reason="Need at least 3 GPUs to run the test.") @@ -42,4 +60,4 @@ def test_models_uneven_tensor_parallel( outputs_1_lst=vllm_outputs, name_0="hf", name_1="vllm", - ) \ No newline at end of file + ) From 537e16b00e393defa0c1b351daf07a50e8d6032b Mon Sep 17 00:00:00 2001 From: Nadav Shmayovits Date: Thu, 4 Jul 2024 13:19:18 +0300 Subject: [PATCH 18/21] Fix uneven TP chunked prefill tests and buildkit config --- .buildkite/test-pipeline.yaml | 2 + .../test_chunked_prefill_distributed.py | 50 ------------- ...test_uneven_chunked_prefill_distributed.py | 75 +++++++++++++++++++ 3 files changed, 77 insertions(+), 50 deletions(-) create mode 100644 tests/distributed/test_uneven_chunked_prefill_distributed.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27bd1db22a685..95daa3639f0c4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -74,6 +74,8 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s 
From 537e16b00e393defa0c1b351daf07a50e8d6032b Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Thu, 4 Jul 2024 13:19:18 +0300
Subject: [PATCH 18/21] Fix uneven TP chunked prefill tests and buildkite config

---
 .buildkite/test-pipeline.yaml                 |  2 +
 .../test_chunked_prefill_distributed.py       | 50 -------------
 ...test_uneven_chunked_prefill_distributed.py | 75 +++++++++++++++++++
 3 files changed, 77 insertions(+), 50 deletions(-)
 create mode 100644 tests/distributed/test_uneven_chunked_prefill_distributed.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 27bd1db22a685..95daa3639f0c4 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -74,6 +74,8 @@ steps:
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_uneven_distributed_correctness.py
+  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_chunked_prefill_distributed.py
+  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_uneven_chunked_prefill_distributed.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp3.py
diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py
index c561b15f2755e..1ef085b933793 100644
--- a/tests/distributed/test_chunked_prefill_distributed.py
+++ b/tests/distributed/test_chunked_prefill_distributed.py
@@ -73,53 +73,3 @@ def test_models(
         name_0="hf",
         name_1="vllm",
     )
-
-
-@pytest.mark.skipif(cuda_device_count_stateless() < 3,
-                    reason="Need at least 3 GPUs to run the test.")
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [5])
-@pytest.mark.parametrize("chunked_prefill_token_size", [16])
-def test_models_uneven_tensor_parallel(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    chunked_prefill_token_size: int,
-) -> None:
-    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
-
-    # Add a chunked prefill config.
-    max_num_seqs = min(chunked_prefill_token_size, 256)
-    assert chunked_prefill_token_size != -1
-    enable_chunked_prefill = True
-    max_num_batched_tokens = chunked_prefill_token_size
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-
-    with vllm_runner(
-            model,
-            dtype=dtype,
-            tensor_parallel_size=3,
-            max_num_seqs=max_num_seqs,
-            enable_chunked_prefill=enable_chunked_prefill,
-            max_num_batched_tokens=max_num_batched_tokens,
-            distributed_executor_backend=distributed_executor_backend,
-    ) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
diff --git a/tests/distributed/test_uneven_chunked_prefill_distributed.py b/tests/distributed/test_uneven_chunked_prefill_distributed.py
new file mode 100644
index 0000000000000..14a76f018580d
--- /dev/null
+++ b/tests/distributed/test_uneven_chunked_prefill_distributed.py
@@ -0,0 +1,75 @@
+"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
+vLLM will allocate all the available memory, so we need to run the tests one
+by one. The solution is to pass arguments (model name) by environment
+variables.
+
+Run:
+```sh
+TEST_DIST_MODEL=facebook/opt-125m pytest \
+    test_uneven_chunked_prefill_distributed.py
+TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
+    test_uneven_chunked_prefill_distributed.py
+```
+"""
+import os
+
+import pytest
+
+from vllm.utils import cuda_device_count_stateless
+
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    os.environ["TEST_DIST_MODEL"],
+]
+DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 3,
+                    reason="Need at least 3 GPUs to run the test.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("chunked_prefill_token_size", [16])
+def test_models_uneven_tensor_parallel(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+) -> None:
+    distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
+
+    # Add a chunked prefill config.
+    max_num_seqs = min(chunked_prefill_token_size, 256)
+    assert chunked_prefill_token_size != -1
+    enable_chunked_prefill = True
+    max_num_batched_tokens = chunked_prefill_token_size
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tensor_parallel_size=3,
+            max_num_seqs=max_num_seqs,
+            enable_chunked_prefill=enable_chunked_prefill,
+            max_num_batched_tokens=max_num_batched_tokens,
+            distributed_executor_backend=distributed_executor_backend,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
\ No newline at end of file

From 5639427367cfde7129626b9d1b1b7b33f0e57717 Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Thu, 4 Jul 2024 16:06:07 +0300
Subject: [PATCH 19/21] Change default padding size of ParallelLMHead to None

---
 vllm/model_executor/layers/vocab_parallel_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index b485a1d76862e..7a4fe02251883 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -382,7 +382,7 @@ def __init__(self,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 padding_size: Optional[int] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config)
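Patch 19 only swaps the default value of padding_size; the behavior behind a None default is not visible in this excerpt. The sketch below therefore rests on an assumption, not a confirmed reading of the vLLM code: that an explicit padding size rounds the vocabulary up to a multiple of that value, while None skips the round-up so each tensor-parallel rank can hold an unpadded, possibly uneven shard.

```python
from typing import Optional


# Assumed semantics, for illustration only (not the actual ParallelLMHead
# code): a concrete padding_size rounds the vocabulary up, None leaves it
# untouched.
def padded_vocab_size(vocab_size: int, padding_size: Optional[int]) -> int:
    if padding_size is None:
        return vocab_size
    # Round up to the nearest multiple of padding_size.
    return ((vocab_size + padding_size - 1) // padding_size) * padding_size


print(padded_vocab_size(32003, 64))    # 32064
print(padded_vocab_size(32003, None))  # 32003
```

Under that reading, dropping the forced padding is what allows an lm_head whose vocabulary is not a multiple of the padding unit to be split across 3 GPUs without adding extra rows.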
From 6f7c0debbddbc466637eaf1fd57e9b1d51e0c50c Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Mon, 8 Jul 2024 12:10:46 +0300
Subject: [PATCH 20/21] Add validation for LoRA with tensor parallel

---
 vllm/config.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/vllm/config.py b/vllm/config.py
index 06cf7c4a6fbd1..a6d9a305585bf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1309,6 +1309,11 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         if scheduler_config.chunked_prefill_enabled:
             raise ValueError("LoRA is not supported with chunked prefill yet.")
 
+    def verify_with_parallel_config(self, parallel_config: ParallelConfig):
+        if self.lora_vocab_padding_size % parallel_config.world_size != 0:
+            raise ValueError("LoRA vocab padding size must be divisible "
+                             "by world size.")
+
 
 # TODO: To be replaced by MultiModalConfig.
 @dataclass
@@ -1577,6 +1582,7 @@ def __post_init__(self):
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
+            self.lora_config.verify_with_parallel_config(self.parallel_config)
 
     def to_dict(self):
         """Return the configs as a dictionary, for use in **kwargs.

From b8e870a7e8c39d22f44a00b9554f4fa4fafe4f13 Mon Sep 17 00:00:00 2001
From: Nadav Shmayovits
Date: Mon, 8 Jul 2024 12:11:12 +0300
Subject: [PATCH 21/21] Fix Llama uneven TP lm head

---
 tests/distributed/test_uneven_chunked_prefill_distributed.py | 2 +-
 vllm/model_executor/models/llama.py                          | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/distributed/test_uneven_chunked_prefill_distributed.py b/tests/distributed/test_uneven_chunked_prefill_distributed.py
index 14a76f018580d..c95b98ca8aa19 100644
--- a/tests/distributed/test_uneven_chunked_prefill_distributed.py
+++ b/tests/distributed/test_uneven_chunked_prefill_distributed.py
@@ -72,4 +72,4 @@ def test_models_uneven_tensor_parallel(
         outputs_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
-    )
\ No newline at end of file
+    )
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 096a3469925a6..4e38e91731702 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -43,7 +43,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -373,7 +373,7 @@ def __init__(
             self.unpadded_vocab_size,
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
-            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            padding_size=None
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
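The two closing patches work as a pair: the non-LoRA lm_head gives up its forced padding, while the LoRA path keeps its own padding size and the new config check rejects combinations where that padding does not divide evenly by the world size. Below is a minimal, self-contained sketch of when the check trips; the Fake* classes and the 256 default are stand-ins invented for illustration, not vLLM's real config objects.

```python
from dataclasses import dataclass


# Stand-in config objects (illustrative only) to show when the new
# verify_with_parallel_config check would raise.
@dataclass
class FakeParallelConfig:
    world_size: int


@dataclass
class FakeLoRAConfig:
    lora_vocab_padding_size: int = 256  # assumed default, for illustration

    def verify_with_parallel_config(self, parallel: FakeParallelConfig) -> None:
        if self.lora_vocab_padding_size % parallel.world_size != 0:
            raise ValueError("LoRA vocab padding size must be divisible "
                             "by world size.")


FakeLoRAConfig().verify_with_parallel_config(FakeParallelConfig(world_size=4))
try:
    FakeLoRAConfig().verify_with_parallel_config(FakeParallelConfig(world_size=3))
except ValueError as err:
    print(err)  # 256 % 3 != 0, so this combination is rejected up front
```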