Fix _quantize_affine_no_dtype_cast for FP8 types
sanchitintel committed Jan 14, 2025
1 parent c53a9d5 commit 837ee22
Showing 1 changed file with 4 additions and 1 deletion.
torchao/quantization/quant_primitives.py (4 additions, 1 deletion)

@@ -334,6 +334,7 @@ def _quantize_affine(
         zero_point,
         quant_min,
         quant_max,
+        output_dtype,
         zero_point_domain,
     ).to(output_dtype)

@@ -345,6 +346,7 @@ def _quantize_affine_no_dtype_cast(
     zero_point: Optional[torch.Tensor],
     quant_min: Union[int, float],
     quant_max: Union[int, float],
+    quant_dtype: Optional[torch.dtype],
     zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
 ) -> torch.Tensor:
     """

@@ -389,7 +391,7 @@ def _quantize_affine_no_dtype_cast(
         assert (
             zero_point is None
         ), "zero_point should be None when zero_point_domain is NONE"
-        if _is_float8_type(input.dtype):
+        if _is_float8_type(quant_dtype):
            quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
        else:
            quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)

@@ -661,6 +663,7 @@ def _do_fake_quantize_affine(
        zero_point,
        quant_min,
        quant_max,
+        quant_dtype,
        zero_point_domain.name,
    )
    dq = _dequantize_affine_no_dtype_check(
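The one behavioral change is the _is_float8_type check. It previously inspected input.dtype, but at quantization time the input tensor is still in high precision (e.g. float32 or bfloat16), so the FP8 branch, which clamps without rounding, could never be taken. Threading the target quant_dtype through the call chain lets the check see the intended output type instead. Below is a minimal sketch of the distinction; the _is_float8_type stand-in, scale, and clamp range are illustrative assumptions, not the commit's code, and running it requires a PyTorch build with FP8 dtypes.

import torch

# Hypothetical stand-in for torchao's _is_float8_type helper,
# which lives in torchao/quantization/quant_primitives.py.
def _is_float8_type(dtype: torch.dtype) -> bool:
    return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

# At quantization time the input is still a high-precision tensor.
input = torch.randn(4, 8, dtype=torch.float32)
quant_dtype = torch.float8_e4m3fn       # target FP8 dtype (assumed)
scale = torch.tensor(0.05)              # illustrative per-tensor scale
quant_min, quant_max = -448.0, 448.0    # float8_e4m3fn representable range

# Old check: never true for a float32 input, so FP8 targets were
# rounded to integers as if they were int dtypes.
assert not _is_float8_type(input.dtype)

# New check: true whenever the target dtype is FP8, so the
# clamp-only path (no torch.round) is taken.
assert _is_float8_type(quant_dtype)
quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)

The FP8 branch skips torch.round because FP8 quantized values are not integers; the subsequent .to(output_dtype) cast in _quantize_affine, visible in the first hunk above, performs the actual rounding to the nearest representable FP8 value.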
