diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 1879d2855d93d..ea15270ee1910 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -115,11 +115,13 @@ def apply_fp8_linear(
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
     else:
+        batched = input.dim() == 3
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        inp_view = input.view(-1, input.shape[-1]) if batched else input
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            inp_view,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +140,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+            return (torch.narrow(
+                output, 0, 0, input.shape[0]) if not batched else output.view(
+                    input.shape[0], input.shape[1], weight.shape[1]))
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 5cf5272cae878..3ba83a92993b5 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -48,7 +48,8 @@
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SequenceData
@@ -1396,6 +1397,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue
+
                 param = params_dict.pop(name)
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
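
The `w8a8_utils.py` change is purely shape bookkeeping around the padded GEMM: a 3D activation is flattened to 2D before `ops.scaled_fp8_quant` (so the `torch._scaled_mm` path still sees a plain token-by-hidden matrix), and the batch dimension is restored on the way out. Below is a minimal sketch of that flatten/pad/un-pad pattern, with an ordinary float matmul standing in for the quantized GEMM; the function name `padded_gemm_with_batch_support` is illustrative only and not part of vLLM.

```python
import torch


def padded_gemm_with_batch_support(x: torch.Tensor,
                                   w: torch.Tensor) -> torch.Tensor:
    """Flatten a possibly-3D activation, pad the token dim, matmul, restore shape.

    A plain float matmul stands in for the quantized torch._scaled_mm call;
    only the reshape/pad/narrow bookkeeping mirrors apply_fp8_linear.
    """
    batched = x.dim() == 3
    x2d = x.view(-1, x.shape[-1]) if batched else x  # (num_tokens, hidden)

    # Mirror num_token_padding=17: pad the token dimension to at least 17 rows
    # so the GEMM sees a batch dimension > 16.
    pad_rows = max(0, 17 - x2d.shape[0])
    x2d = torch.nn.functional.pad(x2d, (0, 0, 0, pad_rows))

    out = x2d @ w  # stand-in for the fused, padded GEMM

    # Drop the padded rows, then restore the (batch, seq, out_features) shape.
    num_tokens = x.shape[0] * x.shape[1] if batched else x.shape[0]
    out = torch.narrow(out, 0, 0, num_tokens)
    return out.view(x.shape[0], x.shape[1], w.shape[1]) if batched else out


# Example: a (2, 4, 8) activation against an (8, 16) weight
x = torch.randn(2, 4, 8)
w = torch.randn(8, 16)
print(padded_gemm_with_batch_support(x, w).shape)  # torch.Size([2, 4, 16])
```

The `mllama.py` change reuses the existing `maybe_remap_kv_scale_name` helper so that checkpoint KV-cache scale entries are remapped to the model's parameter names before lookup, and weights whose names cannot be remapped are skipped with a debug log instead of raising a KeyError.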