From 124c64624fe103c54db36659492c49a447b86a2a Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Mon, 17 Jun 2024 15:26:41 -0700 Subject: [PATCH] [Bugfix] Fix KV head calculation for MPT models when using GQA (#5142) --- vllm/config.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b1a3a82f5a6c0..d95faf52db1a0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -302,7 +302,11 @@ def get_total_num_kv_heads(self) -> int: return 1 # For DBRX and MPT - if self.hf_config.model_type in ["dbrx", "mpt"]: + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads)