Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[BugFix] Fix Falcon tied embeddings (vllm-project#3590)
Browse files Browse the repository at this point in the history
Co-authored-by: 44670 <[email protected]>
  • Loading branch information
WoosukKwon and 44670 authored Mar 24, 2024
1 parent f8a12ec commit af9e534
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions vllm/model_executor/models/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
Expand Down Expand Up @@ -370,10 +370,7 @@ def __init__(
self.config = config
self.linear_method = linear_method
self.transformer = FalconModel(config, linear_method)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

Expand All @@ -394,7 +391,7 @@ def forward(

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits

Expand All @@ -419,9 +416,12 @@ def load_weights(self,
else:
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters())
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if name == "lm_head.weight":
# Falcon uses tied embeddings.
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
Expand Down

0 comments on commit af9e534

Please sign in to comment.