diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 97f7ec74292bb..d3aec06a92fdb 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -110,7 +110,7 @@ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: ]) self.head = nn.ModuleList([ - nn.Linear(self.inner_dim, self.vocab_size, bias=False) + ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) for _ in range(self.max_speculative_tokens) ]) self.ln = nn.ModuleList([