Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix GGUF tp>1 models when vocab_size is not divisible by 64 #12230

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/models/decoder_only/language/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,20 @@ def gguf_model(self):
gguf_filename="starcoder2-3b.Q6_K.gguf",
)

DOLPHIN_CONFIG = GGUFTestConfig(
# Test VocabParallelEmbedding sharding issue.
original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
)

MODELS = [
LLAMA_CONFIG,
QWEN2_CONFIG,
PHI3_CONFIG,
GPT2_CONFIG,
STABLELM_CONFIG,
DOLPHIN_CONFIG
# STARCODER_CONFIG, # broken
]

Expand Down Expand Up @@ -107,6 +115,7 @@ def test_models(

# Run unquantized model.
with vllm_runner(model_name=model.original_model,
enforce_eager=True, # faster tests
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=tp_size) as original_model:
Expand All @@ -115,6 +124,7 @@ def test_models(

# Run gguf model.
with vllm_runner(model_name=model.gguf_model,
enforce_eager=True,
tokenizer_name=model.original_model,
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
elif isinstance(param, UninitializedParameter):
shape = list(loaded_weight.shape)
if output_dim is not None:
shape[output_dim] = shape[output_dim] // self.tp_size
shape[output_dim] = self.num_embeddings_per_partition
param.materialize(tuple(shape), dtype=loaded_weight.dtype)

# If parameter does not have output dim, then it should
Expand All @@ -381,7 +381,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size

# Copy the data.
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

if current_platform.is_hpu():
Expand Down
Loading