From 919dc5d9de4cac4bda252fc99b99e4c650d8c55d Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Fri, 28 Jun 2024 14:43:49 -0400
Subject: [PATCH] [ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)

Co-authored-by: Robert Shaw
---
 vllm/model_executor/layers/linear.py          | 14 ++++++-
 .../model_executor/layers/quantization/fp8.py | 41 +++++++++----------
 2 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index fe7c2a295b70c..d221fecd66ff1 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -383,8 +383,13 @@ def weight_loader(self,
                                                  None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -567,8 +572,13 @@ def weight_loader(self,
                                                  None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index bbf3cde54782d..1c760566c28d7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ def _create_scale_param(
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
@@ -169,11 +171,15 @@ def create_weights(
                 **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, param: torch.Tensor, loaded_weight: torch.Tensor,
+        shard_id: Optional[Union[str,
+                                 int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
@@ -205,15 +211,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # WEIGHT_SCALE / WEIGHT
             #   Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
+
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
@@ -227,10 +235,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
@@ -317,11 +321,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
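
Note: the following is a minimal standalone sketch of the idea behind this patch, not vLLM's actual classes; FusedScaleLoader, num_shards, and the printed values are illustrative assumptions. It shows how a fused-on-disk checkpoint (shard_id=None) can map its single scale to slot 0 and flag the module so per-shard requantization is skipped later.

    from typing import Optional, Tuple, Union

    import torch


    class FusedScaleLoader:
        """Toy stand-in for the fp8 scale-loading path (illustration only)."""

        def __init__(self, num_shards: int):
            # One scale slot per logical shard (e.g. q/k/v). Slots start at the
            # fp8 minimum so that max() later ignores slots that are never written,
            # mirroring the scale[:] = torch.finfo(torch.float8_e4m3fn).min init above.
            self.scales = torch.full((num_shards, ),
                                     torch.finfo(torch.float8_e4m3fn).min,
                                     dtype=torch.float32)
            self.fused_module_in_checkpoint = False

        def shard_indexer(
            self, loaded_scale: torch.Tensor,
            shard_id: Optional[Union[str, int]]
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            qkv_idxs = {"q": 0, "k": 1, "v": 2}
            if shard_id is None:
                # Checkpoint already stores the fused weight: one scale covers the
                # whole tensor, so write it into slot 0 and remember the module
                # was fused on disk so requantization can be skipped later.
                shard_id = 0
                self.fused_module_in_checkpoint = True
            elif isinstance(shard_id, str):
                shard_id = qkv_idxs[shard_id]
            return self.scales[shard_id], loaded_scale


    loader = FusedScaleLoader(num_shards=3)
    param_slice, value = loader.shard_indexer(torch.tensor(0.5), shard_id=None)
    param_slice.copy_(value)
    print(loader.fused_module_in_checkpoint)  # True
    print(loader.scales)  # slot 0 holds 0.5, unused slots stay at the fp8 minimum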