Skip to content

Commit

Permalink
[model] Reduce medusa weight (#10454)
Browse files: browse the repository at this point in the history
Signed-off-by: skylee-01 <[email protected]>
  • Loading branch information
skylee-01 authored Nov 20, 2024
1 parent ed701ca commit 343041c
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions vllm/model_executor/models/medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size

- self.lm_heads = nn.ModuleList([
-     ParallelLMHead(
+ if getattr(config, "original_lm_head", False):
+     self.lm_head = ParallelLMHead(
          self.unpadded_vocab_size,
          config.hidden_size,
          org_num_embeddings=self.truncated_vocab_size,
          padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-     ) for _ in range(self.config.num_heads)
- ])
+     )
+     self.lm_heads = [
+         self.lm_head for _ in range(self.config.num_heads)
+     ]
+ else:
+     self.lm_heads = nn.ModuleList([
+         ParallelLMHead(
+             self.unpadded_vocab_size,
+             config.hidden_size,
+             org_num_embeddings=self.truncated_vocab_size,
+             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+         ) for _ in range(self.config.num_heads)
+     ])

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
Expand Down Expand Up @@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
+ elif (getattr(self.config, "original_lm_head", False)
+       and name == "lm_heads.0.weight"):
+     weights_map["lm_head.weight"] = loaded_weight

for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\
Expand Down

0 comments on commit 343041c

Please sign in to comment.