From 959ceb9d1e8845a5bbe7717c1d679d2b00d392dc Mon Sep 17 00:00:00 2001
From: Taemin Lee
Date: Thu, 21 Mar 2024 19:23:13 +0900
Subject: [PATCH 1/2] fix gemma loading after quantization or LoRA

lm_head is not used in vLLM because its weight is tied with embed_tokens.
Duplicate lm_head layers are sometimes added when the model structure is
newly created by quantization, LoRA, etc. To avoid the resulting error,
skip loading lm_head.weight.
---
 vllm/model_executor/models/gemma.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index fd3dbe798cd8e..4e0c7a3b8ad19 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -340,6 +340,11 @@ def load_weights(self,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in vllm as it is tied weight with embed_token.
+                # Sometimes duplicate lm_head layers are added when the structure of the model is newly created by quantization, LORA, etc.
+                # To avoid the error that occurs, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue

From 8c2b7b4ccb319c588f9118a48cb55560335c0287 Mon Sep 17 00:00:00 2001
From: Taemin Lee
Date: Thu, 21 Mar 2024 19:40:39 +0900
Subject: [PATCH 2/2] refine comments for lm_head gemma error fix

---
 vllm/model_executor/models/gemma.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 4e0c7a3b8ad19..fa8ce60e74056 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -340,9 +340,8 @@ def load_weights(self,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # lm_head is not used in vllm as it is tied weight with embed_token.
-                # Sometimes duplicate lm_head layers are added when the structure of the model is newly created by quantization, LORA, etc.
-                # To avoid the error that occurs, skip loading lm_head.weight.
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.
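
For context on why the skip is needed, here is a minimal runnable sketch. The `TinyTiedLM` class, the `checkpoint_weights` dict, and the loader loop below are hypothetical illustrations, not vLLM's actual Gemma code beyond what the diff shows: PyTorch deduplicates tied parameters in `named_parameters()`, so a checkpoint that serializes a duplicate `lm_head.weight` tensor has no matching entry in the loader's `params_dict`, and the lookup raises a `KeyError` unless that name is skipped.

```python
import torch
import torch.nn as nn

# Toy model with tied weights, mirroring how Gemma reuses the embedding
# matrix as the output projection (hypothetical class, for illustration).
class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 8, hidden: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed_tokens(input_ids))

model = TinyTiedLM()
params_dict = dict(model.named_parameters())

# named_parameters() yields a tied tensor only once, under the first
# name it was registered with, so the tied alias is absent:
print("embed_tokens.weight" in params_dict)  # True
print("lm_head.weight" in params_dict)       # False

# A checkpoint re-exported after quantization or LoRA merging may still
# serialize the duplicate tensor (hypothetical stand-in for such a file):
checkpoint_weights = {
    "embed_tokens.weight": torch.zeros(8, 4),
    "lm_head.weight": torch.zeros(8, 4),
}

# Simplified loader loop showing the patched behavior:
for name, loaded_weight in checkpoint_weights.items():
    if "lm_head.weight" in name:
        continue  # tied with embed_tokens; params_dict has no such key
    param = params_dict[name]  # would raise KeyError without the skip
    param.data.copy_(loaded_weight)
```

With the `continue` in place, the stale `lm_head.weight` entry is ignored and loading proceeds; without it, the `params_dict[name]` lookup fails for any checkpoint that carries the duplicate tensor, which is the error the patch addresses.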