diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1bb6bc753d37c..3ea6217d7c0ef 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) - compute_type = (tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16) + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states