Commit

this shouldn't work
drisspg committed Apr 12, 2024
1 parent c04d670 commit 8b62d1a
Showing 2 changed files with 85 additions and 1 deletion.
18 changes: 17 additions & 1 deletion test/test_fp8.py
@@ -1,7 +1,12 @@
import pytest
import torch

from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant
from transformer_nuggets.fp8.scaled_quant import (
dynamic_scaled_quant,
eager_dynamic_scaled_quant,
eager_scaled_quant,
scaled_quant,
)


@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@@ -36,5 +41,16 @@ def test_saturated(fp8_dtype):
)


@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_dynamic_quant(fp8_dtype):
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
a = torch.randn(2**12, 2**12, dtype=torch.float32, device="cuda") * 9.6
output = eager_dynamic_scaled_quant(a, fp8_dtype)
output_triton = dynamic_scaled_quant(a, fp8_dtype)
torch.testing.assert_close(output.to(torch.float32), output_triton.to(torch.float32))


if __name__ == "__main__":
pytest.main([__file__])
68 changes: 68 additions & 0 deletions transformer_nuggets/fp8/scaled_quant.py
@@ -32,6 +32,58 @@ def scaled_cast(
tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)


@triton.jit
def dynamic_scaled_cast(
inpt_ptr: torch.Tensor,
output_ptr: torch.Tensor,
abs_max_ptr: torch.Tensor,
numel: int,
XBLOCK: tl.constexpr,
float8_dtype: tl.constexpr,
max_val: tl.constexpr,
):
"""Quantize tensor to fp8 using current global absmax"""
offset = tl.program_id(0) * XBLOCK
index = offset + tl.arange(0, XBLOCK)[:]
index = tl.max_contiguous(tl.multiple_of(index, XBLOCK), XBLOCK)
mask = index < numel
inpt = tl.load(inpt_ptr + (index), mask=mask)
block_max = tl.max(tl.abs(inpt))
tl.atomic_max(abs_max_ptr, block_max)
# TODO Need a global barrier to ensure all blocks have updated abs_max
# Yet the test passes...?
scale = max_val / (tl.load(abs_max_ptr) + 1e-12)
scaled_inpt = inpt * scale
# Saturated casting
    scaled_inpt = tl.where(scaled_inpt > max_val, max_val, scaled_inpt)
    scaled_inpt = tl.where(scaled_inpt < -1 * max_val, -1 * max_val, scaled_inpt)
tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)


def dynamic_scaled_quant(
inpt_tensor: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
"""Quantize tensor to fp8 using dynamic scale calculated from abs_max
It will do saturated casting
Args:
inpt_tensor: Input tensor to quantize
fp8_dtype: FP8 datatype to quantize to
"""
assert inpt_tensor.is_contiguous(), "Input tensor must be contiguous"

out_tensor = torch.empty_like(inpt_tensor, dtype=fp8_dtype, device="cuda")
numel = inpt_tensor.numel()
grid = lambda meta: (triton.cdiv(numel, meta["XBLOCK"]),)
tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
max_val = torch.finfo(fp8_dtype).max
    # Scratch for the running global abs_max; zero-initialized so atomic_max starts from a defined value
    abs_max_scratch = torch.zeros((), dtype=inpt_tensor.dtype, device="cuda")
dynamic_scaled_cast[grid](
inpt_tensor, out_tensor, abs_max_scratch, numel, 4096, tl_dtype, max_val, num_warps=8
)
return out_tensor
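
A minimal usage sketch of this wrapper (assuming a CUDA device, and mirroring the new test above; the shape and the 9.6 multiplier are purely illustrative):

import torch
from transformer_nuggets.fp8.scaled_quant import dynamic_scaled_quant, eager_dynamic_scaled_quant

x = torch.randn(4096, 4096, device="cuda", dtype=torch.float32) * 9.6
fp8_triton = dynamic_scaled_quant(x, torch.float8_e4m3fn)       # Triton kernel path
fp8_eager = eager_dynamic_scaled_quant(x, torch.float8_e4m3fn)  # eager reference
torch.testing.assert_close(fp8_triton.to(torch.float32), fp8_eager.to(torch.float32))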


def scaled_quant(
inpt_tensor: torch.Tensor,
scale: torch.Tensor,
@@ -92,3 +144,19 @@ def eager_scaled_quant(
)
_ = torch.max(torch.abs(out))
return out.to(fp8_dtype)


def eager_dynamic_scaled_quant(
a: torch.Tensor,
fp8_dtype: torch.dtype,
) -> torch.Tensor:
"""Quantize tensor to fp8 using the current amax value to generate scale
Args:
a: Input tensor to quantize
fp8_dtype: FP8 datatype to quantize to
"""
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated

scale = tensor_to_scale(a, fp8_dtype)
a = a * scale
return to_fp8_saturated(a, fp8_dtype)
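
For clarity, here is roughly what that eager reference boils down to in plain PyTorch, under the assumption that tensor_to_scale computes fp8_max / amax and to_fp8_saturated clamps before casting (a sketch, not the float8_experimental implementation). Note that the global amax is computed in a separate pass, which is exactly what the Triton kernel's missing global barrier would otherwise have to guarantee:

def _plain_torch_dynamic_quant(a: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical reference sketch: scale from the global amax, then saturate and cast
    max_val = torch.finfo(fp8_dtype).max
    amax = a.abs().max()
    scale = max_val / torch.clamp(amax, min=1e-12)
    scaled = a * scale
    scaled = torch.clamp(scaled, min=-max_val, max=max_val)
    return scaled.to(fp8_dtype)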
