From 4f44925fb441d05c698702fb75ec4963052ff139 Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Thu, 25 Jul 2024 12:30:53 -0700
Subject: [PATCH 1/3] 🐛 skip loading lm_head.weight
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Since lm_head is tied to the embeddings, we should be able to skip
loading lm_head.

Signed-off-by: Prashant Gupta
---
 vllm/model_executor/models/llama.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 306d22e42ed1d..4211ffd43d044 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -496,6 +496,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 break
             else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.
+                if "lm_head" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue

From 0f5867262856c15a6087635c9c79fa7e5d261985 Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Thu, 25 Jul 2024 12:36:24 -0700
Subject: [PATCH 2/3] 🐛 make sure to check tie_word_embeddings config first
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Prashant Gupta
---
 vllm/model_executor/models/llama.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 4211ffd43d044..09184d5003b0a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -496,10 +496,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 break
             else:
-                # lm_head is not used in vllm as it is tied with embed_token.
+                # if word embeddings are tied, 
+                # lm_head will not be used.
                 # To prevent errors, skip loading lm_head.
-                if "lm_head" in name:
-                    continue
+                if self.config.tie_word_embeddings:
+                    if "lm_head" in name:
+                        continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue

From ff663f6a95daf8b3da63541f3913b1c3fa3ed9cc Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Thu, 25 Jul 2024 12:39:19 -0700
Subject: [PATCH 3/3] 🎨 use single if
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Prashant Gupta
---
 vllm/model_executor/models/llama.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 09184d5003b0a..21c0c3f2dd3b5 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -496,12 +496,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 break
             else:
-                # if word embeddings are tied, 
+                # if word embeddings are tied,
                 # lm_head will not be used.
                 # To prevent errors, skip loading lm_head.
-                if self.config.tie_word_embeddings:
-                    if "lm_head" in name:
-                        continue
+                if self.config.tie_word_embeddings and "lm_head" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
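
To see why skipping is safe, here is a minimal, self-contained sketch of the
idea behind the patch series. It is not the vLLM implementation: ToyConfig,
ToyModel, and the checkpoint contents are hypothetical stand-ins. With
tie_word_embeddings, the output projection reuses the embedding matrix, so a
checkpoint's lm_head.weight tensor is redundant; the load_weights method below
mirrors the final state of the patch by skipping it.

from typing import Iterable, Tuple

import torch
import torch.nn as nn


class ToyConfig:
    def __init__(self, vocab_size: int = 8, hidden: int = 4,
                 tie_word_embeddings: bool = True):
        self.vocab_size = vocab_size
        self.hidden = hidden
        self.tie_word_embeddings = tie_word_embeddings


class ToyModel(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden)
        # With tied embeddings there is no separate lm_head parameter at all;
        # logits are computed against the embedding matrix instead.
        if not config.tie_word_embeddings:
            self.lm_head = nn.Linear(config.hidden, config.vocab_size,
                                     bias=False)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = (self.embed_tokens.weight if self.config.tie_word_embeddings
                  else self.lm_head.weight)
        return hidden_states @ weight.t()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Mirrors the final state of the patch: if word embeddings are
            # tied, lm_head will not be used, so skip loading it.
            if self.config.tie_word_embeddings and "lm_head" in name:
                continue
            params_dict[name].data.copy_(loaded_weight)


if __name__ == "__main__":
    config = ToyConfig()
    model = ToyModel(config)
    # A (made-up) checkpoint that still ships an lm_head.weight tensor, as
    # tied-embedding checkpoints often do. Without the skip, the params_dict
    # lookup would raise KeyError for a parameter the model never created.
    checkpoint = [
        ("embed_tokens.weight",
         torch.randn(config.vocab_size, config.hidden)),
        ("lm_head.weight", torch.randn(config.vocab_size, config.hidden)),
    ]
    model.load_weights(checkpoint)
    print(model.logits(torch.randn(1, config.hidden)).shape)  # (1, 8)

Note that patch 2 makes the skip conditional on the config rather than
unconditional, so models that genuinely carry an untied lm_head still load it;
patch 3 only collapses the nested if into a single condition.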