
Commit

Update w8a8_utils.py
mgoin authored Oct 29, 2024
1 parent 1032cf3 commit ebd38af
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -110,11 +110,11 @@ def apply_fp8_linear(
 
         # Fused GEMM_DQ
         output = ops.cutlass_scaled_mm(qinput,
-                                   weight,
-                                   out_dtype=input.dtype,
-                                   scale_a=x_scale,
-                                   scale_b=weight_scale,
-                                   bias=bias)
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
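For context on the hunk above: the "Fused GEMM_DQ" comment means the CUTLASS kernel multiplies the FP8 operands and applies the dequantization scales (plus the optional bias) in a single pass. Below is a rough, unfused PyTorch sketch of that computation; the helper name is made up, and the exact scale shapes and broadcasting are my assumptions rather than anything stated in this commit.

```python
import torch

def scaled_mm_reference(qinput, weight, scale_a, scale_b, out_dtype, bias=None):
    """Unfused reference for a scaled matmul with dequantization.

    Hypothetical helper for illustration only; the fused kernel in the diff
    (ops.cutlass_scaled_mm) does the equivalent work in one kernel launch.
    Assumes scale_a is per-tensor or per-row (per-token) for qinput and
    scale_b is per-tensor or per-column (per-channel) for weight.
    """
    a = qinput.to(torch.float32) * scale_a   # dequantize activations
    b = weight.to(torch.float32) * scale_b   # dequantize weights
    out = a @ b
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

# Example with per-tensor scales; ordinary tensors stand in for FP8 data.
q = torch.randint(-8, 8, (4, 16)).float()
w = torch.randint(-8, 8, (16, 8)).float()
y = scaled_mm_reference(q, w, torch.tensor(0.1), torch.tensor(0.05), torch.bfloat16)
print(y.shape)  # torch.Size([4, 8])
```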
@@ -145,7 +145,8 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
 
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
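On the second hunk: the tuple check in the surrounding context exists because some older torch._scaled_mm builds returned an (output, amax) pair instead of a single tensor (my understanding, not stated in this diff), and the reformatted return statement trims the GEMM result back to the caller's row count before restoring the original shape. A small standalone sketch of that trim-and-reshape step, with made-up sizes:

```python
import torch

# torch.narrow(output, 0, 0, n) keeps only the first n rows (a view, no copy),
# so any rows added by padding before the GEMM are dropped, and
# .view(*output_shape) restores the caller's original batch dimensions.
output = torch.randn(32, 4096)        # padded GEMM result (sizes are made up)
input_2d = torch.randn(17, 2048)      # flattened input: 17 real rows
output_shape = [17, 4096]

result = torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
assert result.shape == (17, 4096)
```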
