diff --git a/vllm/config.py b/vllm/config.py index 4d05b4ea36d5c..51a7a4c4d22d5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -292,7 +292,11 @@ def get_total_num_kv_heads(self) -> int: return 1 # For DBRX and MPT - if self.hf_config.model_type in ["dbrx", "mpt"]: + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads)