From 343041c4c4db93b4693ba437df7ae8bea485d18e Mon Sep 17 00:00:00 2001
From: Sky Lee <46676799+skylee-01@users.noreply.github.com>
Date: Wed, 20 Nov 2024 14:05:55 +0800
Subject: [PATCH] [model] Reduce medusa weight (#10454)

Signed-off-by: skylee-01 <497627264@qq.com>
---
 vllm/model_executor/models/medusa.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index b4ed6538bddac..66bdcb89a0213 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.truncated_vocab_size = config.truncated_vocab_size
         self.unpadded_vocab_size = self.truncated_vocab_size
 
-        self.lm_heads = nn.ModuleList([
-            ParallelLMHead(
+        if getattr(config, "original_lm_head", False):
+            self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
-        ])
+            )
+            self.lm_heads = [
+                self.lm_head for _ in range(self.config.num_heads)
+            ]
+        else:
+            self.lm_heads = nn.ModuleList([
+                ParallelLMHead(
+                    self.unpadded_vocab_size,
+                    config.hidden_size,
+                    org_num_embeddings=self.truncated_vocab_size,
+                    padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                ) for _ in range(self.config.num_heads)
+            ])
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
@@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                         requires_grad=False)
             elif name in params_dict:
                 weights_map[name] = loaded_weight
+            elif (getattr(self.config, "original_lm_head", False)
+                  and name == "lm_heads.0.weight"):
+                weights_map["lm_head.weight"] = loaded_weight
 
         for name, loaded_weight in weights_map.items():
             if "lm_head" in name and self.token_map is not None and\
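
Note on the change (a reviewer's sketch, not part of the patch): when the
"original_lm_head" config flag is set, the patch builds a single
ParallelLMHead and makes every Medusa head reference it through a plain
Python list, so only one vocabulary projection matrix is ever allocated;
the old nn.ModuleList path, kept in the else branch, still registers one
independent projection per head. The toy module below illustrates the same
sharing pattern in self-contained PyTorch. The names TinyMedusaHeads and
share_lm_head are hypothetical stand-ins for the vLLM classes and the
original_lm_head flag, and nn.Linear stands in for ParallelLMHead.

    import torch
    import torch.nn as nn

    class TinyMedusaHeads(nn.Module):
        """Toy stand-in for Medusa's LM heads (hypothetical names)."""

        def __init__(self, hidden_size: int, vocab_size: int,
                     num_heads: int, share_lm_head: bool) -> None:
            super().__init__()
            if share_lm_head:
                # One projection; the list holds aliases to the same
                # module, so no extra parameters are registered.
                self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
                self.lm_heads = [self.lm_head for _ in range(num_heads)]
            else:
                # One independent projection per speculative head.
                self.lm_heads = nn.ModuleList(
                    nn.Linear(hidden_size, vocab_size, bias=False)
                    for _ in range(num_heads)
                )

        def forward(self, hidden_states: list) -> list:
            # Each head projects its own hidden state to vocab logits.
            return [head(h) for head, h in zip(self.lm_heads, hidden_states)]

    shared = TinyMedusaHeads(64, 1000, num_heads=4, share_lm_head=True)
    separate = TinyMedusaHeads(64, 1000, num_heads=4, share_lm_head=False)
    print(sum(p.numel() for p in shared.parameters()))    # 64000: one matrix
    print(sum(p.numel() for p in separate.parameters()))  # 256000: four matrices

The load_weights hunk follows from this layout: the shared branch registers
its parameter as lm_head.weight, while checkpoints saved with per-head
projections store lm_heads.0.weight, lm_heads.1.weight, and so on; the new
elif remaps the first head's checkpoint tensor onto the shared parameter so
existing checkpoints still load.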