From 2454f4a5e02d96f80502124d03ef7c6dca502e9c Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Tue, 29 Oct 2024 14:28:04 +0000
Subject: [PATCH 1/3] Fix support for non-quantized visual layers in an otherwise quantized mllama model, including missing scaling factors

Signed-off-by: Gregory Shtrasberg
---
 vllm/model_executor/layers/quantization/fp8.py |  4 ++++
 .../layers/quantization/utils/w8a8_utils.py    | 10 +++++++---
 vllm/model_executor/models/mllama.py           |  9 +++++++++
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index d34579b7099bb..64b43798e8bdf 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -206,9 +206,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale.data[layer.weight_scale.data == torch.finfo(
+                torch.float32).min] = 1
             layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                     requires_grad=False)
             if self.quant_config.activation_scheme == "static":
+                layer.input_scale.data[layer.input_scale.data == torch.finfo(
+                    torch.float32).min] = 1
                 layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                        requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 1879d2855d93d..38cda41fe7232 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -118,8 +118,10 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        batched = input.dim() > 2
+        inp_view = input.view(-1, input.shape[-1]) if batched else input
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            inp_view,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +140,10 @@
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+            return (torch.narrow(
+                output, 0, 0, input.shape[0]) if not batched else output.view(
+                    input.shape[0], input.shape[1], weight.shape[1]))
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 5cf5272cae878..b107f91af5472 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1396,6 +1396,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                from vllm.model_executor.model_loader.weight_utils import (
+                    maybe_remap_kv_scale_name)
+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue
+
                 param = params_dict.pop(name)
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

From a23a23c126d04469a3e9499f7d15ca403fc0e792 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Tue, 29 Oct 2024 16:31:19 +0000
Subject: [PATCH 2/3] Reorganize imports; Restrict additional supported tensors in _scaled_mm to 3D; Use constant for default fp8 scale

Signed-off-by: Gregory Shtrasberg
---
 vllm/model_executor/layers/quantization/fp8.py | 14 ++++++++------
 .../layers/quantization/utils/w8a8_utils.py    |  2 +-
 vllm/model_executor/models/mllama.py           |  5 ++---
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 64b43798e8bdf..fb56897879dd6 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -126,6 +126,8 @@ def __init__(self, quant_config: Fp8Config):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        self.default_scale = torch.finfo(torch.float32).min
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -168,7 +170,7 @@ def create_weights(
                 len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)
-            scale[:] = torch.finfo(torch.float32).min
+            scale[:] = self.default_scale
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
@@ -177,7 +179,7 @@ def create_weights(
                     len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
-                scale[:] = torch.finfo(torch.float32).min
+                scale[:] = self.default_scale
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
@@ -206,13 +208,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
-            layer.weight_scale.data[layer.weight_scale.data == torch.finfo(
-                torch.float32).min] = 1
+            layer.weight_scale.data[layer.weight_scale.data ==
+                                    self.default_scale] = 1
             layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                     requires_grad=False)
             if self.quant_config.activation_scheme == "static":
"static": - layer.input_scale.data[layer.input_scale.data == torch.finfo( - torch.float32).min] = 1 + layer.input_scale.data[layer.input_scale.data == + self.default_scale] = 1 layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False) # If using marlin (w8a16), kernel uses channelwise weights, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 38cda41fe7232..ea15270ee1910 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -115,10 +115,10 @@ def apply_fp8_linear( # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token else: + batched = input.dim() == 3 # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. - batched = input.dim() > 2 inp_view = input.view(-1, input.shape[-1]) if batched else input qinput, x_scale = ops.scaled_fp8_quant( inp_view, diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index b107f91af5472..3ba83a92993b5 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -48,7 +48,8 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SequenceData @@ -1396,8 +1397,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - from vllm.model_executor.model_loader.weight_utils import ( - maybe_remap_kv_scale_name) orig_name = name name = maybe_remap_kv_scale_name(name, params_dict) if name is None: From 49c0a3d70bcc38ee9fadebdb84fadb63e440dfa8 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 29 Oct 2024 13:48:12 -0500 Subject: [PATCH 3/3] Revert weight scaling fix; this is meant to be handled through skipped layers in the config Signed-off-by: Gregory Shtrasberg --- vllm/model_executor/layers/quantization/fp8.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fb56897879dd6..d34579b7099bb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -126,8 +126,6 @@ def __init__(self, quant_config: Fp8Config): if current_platform.is_rocm(): self.use_marlin = False - self.default_scale = torch.finfo(torch.float32).min - def create_weights( self, layer: torch.nn.Module, @@ -170,7 +168,7 @@ def create_weights( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader) - scale[:] = self.default_scale + scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE @@ -179,7 +177,7 @@ def create_weights( len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader) - scale[:] = self.default_scale + scale[:] = 
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
@@ -208,13 +206,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
-            layer.weight_scale.data[layer.weight_scale.data ==
-                                    self.default_scale] = 1
             layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                     requires_grad=False)
             if self.quant_config.activation_scheme == "static":
-                layer.input_scale.data[layer.input_scale.data ==
-                                       self.default_scale] = 1
                 layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                        requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
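
Note on the apply_fp8_linear change above (illustrative, not part of the patches): torch._scaled_mm only accepts 2D operands, so a batched 3D activation, such as the mllama vision layers' [batch, seq, hidden] hidden states, has to be flattened to 2D before the GEMM and reshaped afterwards. The sketch below shows that flatten/restore pattern in isolation; it is a minimal sketch with made-up shapes, using plain torch.matmul as a stand-in for the quantized torch._scaled_mm path, not the vLLM implementation.

```python
import torch


def linear_2d_only(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Run a GEMM that only supports 2D inputs on a 2D or 3D tensor.

    Mirrors the pattern added to apply_fp8_linear: collapse a batched
    input to 2D, run the kernel, then restore the batch dimension.
    torch.matmul stands in for torch._scaled_mm here.
    """
    batched = inp.dim() == 3
    # [batch, seq, hidden] -> [batch * seq, hidden] for the 2D-only kernel.
    inp_view = inp.view(-1, inp.shape[-1]) if batched else inp
    out = torch.matmul(inp_view, weight)  # the kernel always sees a 2D problem
    if batched:
        # Restore [batch, seq, out_features] on the way out.
        out = out.view(inp.shape[0], inp.shape[1], weight.shape[-1])
    return out


# Hypothetical shapes, for illustration only.
x3d = torch.randn(2, 1601, 1280)   # batched vision activations
x2d = torch.randn(1601, 1280)      # ordinary 2D activations
w = torch.randn(1280, 1024)        # [in_features, out_features]
assert linear_2d_only(x3d, w).shape == (2, 1601, 1024)
assert linear_2d_only(x2d, w).shape == (1601, 1024)
```

The real path has one extra wrinkle the sketch omits: ops.scaled_fp8_quant(..., num_token_padding=17) pads the quantized input rows for torch._scaled_mm performance, so the non-batched branch trims the result back with torch.narrow, while the batched branch reshapes to (input.shape[0], input.shape[1], weight.shape[1]) instead.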