[Bugfix][Quantization]Fix support for non quantized visual layers in otherwise quantized mllama model, including missing scaling factors #9800

Closed
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -206,9 +206,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale.data[layer.weight_scale.data == torch.finfo(
+                torch.float32).min] = 1
Member

Could you pull out torch.finfo(torch.float32).min into a constant in this file so we can use a single reference when also using it in create_weights()? i.e. UNINITIALIZED_SCALE = torch.finfo(torch.float32).min

In what case would this happen? This seems like it may cover up failed weight loading, where we might want to raise an exception instead.
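
As a rough illustration of the suggested refactor (the constant is taken verbatim from the comment above; the helper function and its placement are assumptions, not part of the PR):

```python
import torch

# Constant suggested in the review: a sentinel meaning "no scale was loaded".
# create_weights() would initialize scale tensors to this value, and
# process_weights_after_loading() would compare against the same reference.
UNINITIALIZED_SCALE = torch.finfo(torch.float32).min


def reset_uninitialized_scales(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: any scale the checkpoint never overwrote still
    # holds the sentinel and is reset to a neutral 1.0.
    scale[scale == UNINITIALIZED_SCALE] = 1.0
    return scale
```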

Contributor Author
@gshtras gshtras Oct 29, 2024

> Could you pull out torch.finfo(torch.float32).min into a constant in this file

Will do.
This is for the case when an otherwise quantized model has unquantized layers. These layers get converted to fp8 with a scale of 1.0, but that scale never gets loaded, since it is not present in the model checkpoint.
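
A toy illustration of the scenario described here (the checkpoint keys are invented; only the sentinel-to-1.0 reset mirrors the change in this commit):

```python
import torch

# In a partially quantized checkpoint, the quantized layers ship a
# weight_scale, but the unquantized (e.g. vision) layers do not, so their
# scale parameters keep the default that create_weights() gave them.
checkpoint = {
    "language_model.layers.0.mlp.down_proj.weight_scale": torch.tensor([0.02]),
    # no corresponding weight_scale entry exists for the vision layers
}

sentinel = torch.finfo(torch.float32).min
vision_weight_scale = torch.tensor([sentinel])  # never touched by the loader

# The fix in this commit: treat a still-sentinel scale as "unquantized",
# i.e. an identity scale of 1.0.
vision_weight_scale[vision_weight_scale == sentinel] = 1.0
print(vision_weight_scale)  # tensor([1.])
```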

Member

I see; in that case I think this is an incorrect fix. We have infrastructure for checking the module name against the ignore list in the quantization config. We should fix the issue if the name is not being matched or detected from the ignore list here:

if isinstance(layer, LinearBase):
    if is_layer_skipped(prefix, self.ignored_layers):
        return UnquantizedLinearMethod()
    return Fp8LinearMethod(self)
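
A simplified stand-in for that check, showing how an ignore-list entry is meant to route such layers to the unquantized path (this is not vLLM's actual is_layer_skipped implementation, and the example prefixes are invented):

```python
from typing import List


def layer_is_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # A module is left unquantized when its dotted name falls under an entry
    # in the quantization config's ignore list.
    return any(prefix == ignored or prefix.startswith(ignored + ".")
               for ignored in ignored_layers)


# With the vision tower on the ignore list, the snippet quoted above would
# return UnquantizedLinearMethod() for its linears instead of Fp8LinearMethod().
print(layer_is_skipped("vision_model.transformer.layers.0.mlp.fc1",
                       ["vision_model"]))  # True
print(layer_is_skipped("language_model.layers.0.mlp.down_proj",
                       ["vision_model"]))  # False
```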

Contributor Author

That's good to know, thanks. In that case we only need the kv-scale remap and the 3D tensor changes; I will remove this one.

             layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                     requires_grad=False)
             if self.quant_config.activation_scheme == "static":
+                layer.input_scale.data[layer.input_scale.data == torch.finfo(
+                    torch.float32).min] = 1
                 layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                        requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -118,8 +118,10 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        batched = input.dim() > 2
+        inp_view = input.view(-1, input.shape[-1]) if batched else input
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            inp_view,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +140,10 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+            return (torch.narrow(
+                output, 0, 0, input.shape[0]) if not batched else output.view(
+                    input.shape[0], input.shape[1], weight.shape[1]))

         else:
             # Fallback for channelwise case, where we use unfused DQ
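
The intent of the two hunks above, reduced to the shape bookkeeping alone (quantization, the 17-row padding and torch._scaled_mm are omitted here; a plain matmul stands in for the fp8 kernel):

```python
import torch


def linear_keeping_batch_dim(input: torch.Tensor,
                             weight: torch.Tensor) -> torch.Tensor:
    # Flatten a 3-D (batch, seq, hidden) input to 2-D so the GEMM sees a
    # plain matrix, then restore the batch dimension on the way out.
    batched = input.dim() > 2
    inp_view = input.view(-1, input.shape[-1]) if batched else input
    output = inp_view @ weight  # stand-in for the padded fp8 scaled matmul
    if batched:
        return output.view(input.shape[0], input.shape[1], weight.shape[1])
    # For 2-D inputs, narrow away any padding rows (a no-op in this sketch).
    return torch.narrow(output, 0, 0, input.shape[0])


w = torch.randn(8, 16)
print(linear_keeping_batch_dim(torch.randn(4, 8), w).shape)     # (4, 16)
print(linear_keeping_batch_dim(torch.randn(2, 3, 8), w).shape)  # (2, 3, 16)
```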
9 changes: 9 additions & 0 deletions vllm/model_executor/models/mllama.py
@@ -1396,6 +1396,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                from vllm.model_executor.model_loader.weight_utils import (
+                    maybe_remap_kv_scale_name)
Member

Why not add this to the top import list?

+                orig_name = name
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    logger.debug("Missing name %s, orig name %s", name,
+                                 orig_name)
+                    continue

                 param = params_dict.pop(name)
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
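
For context, a sketch of what the remap fallback does: maybe_remap_kv_scale_name translates legacy checkpoint names for kv-cache scaling factors into the parameter names the model actually registers, and returns None when nothing matches so the weight can be skipped. The concrete rename below is an assumption for illustration, not the function's exact behaviour:

```python
from typing import Dict, Optional

import torch


def remap_kv_scale_name(name: str,
                        params_dict: Dict[str, torch.Tensor]) -> Optional[str]:
    # Illustrative only: map a checkpoint key like "...self_attn.kv_scale"
    # onto the attention submodule where the parameter lives, and give up
    # (return None) if no such parameter exists.
    candidate = name.replace("self_attn.kv_scale", "self_attn.attn.kv_scale")
    return candidate if candidate in params_dict else None


params = {"layers.0.self_attn.attn.kv_scale": torch.tensor([1.0])}
print(remap_kv_scale_name("layers.0.self_attn.kv_scale", params))
# layers.0.self_attn.attn.kv_scale
print(remap_kv_scale_name("layers.0.self_attn.unknown_scale", params))
# None
```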