From 8b62d1ad2b1e2d473e3c5b2350468463e05cab23 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Fri, 12 Apr 2024 14:05:47 -0700
Subject: [PATCH] this shouldn't work

---
 test/test_fp8.py                        | 18 ++++-
 transformer_nuggets/fp8/scaled_quant.py | 72 ++++++++++++++++++++++++++
 2 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/test/test_fp8.py b/test/test_fp8.py
index 73ae2b9..ff6850e 100644
--- a/test/test_fp8.py
+++ b/test/test_fp8.py
@@ -1,7 +1,12 @@
 import pytest
 import torch
 
-from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant
+from transformer_nuggets.fp8.scaled_quant import (
+    dynamic_scaled_quant,
+    eager_dynamic_scaled_quant,
+    eager_scaled_quant,
+    scaled_quant,
+)
 
 
 @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@@ -36,5 +41,16 @@ def test_saturated(fp8_dtype):
     )
 
 
+@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+def test_dynamic_quant(fp8_dtype):
+    torch.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    a = torch.randn(2**12, 2**12, dtype=torch.float32, device="cuda") * 9.6
+    output = eager_dynamic_scaled_quant(a, fp8_dtype)
+    output_triton = dynamic_scaled_quant(a, fp8_dtype)
+    torch.testing.assert_close(output.to(torch.float32), output_triton.to(torch.float32))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

diff --git a/transformer_nuggets/fp8/scaled_quant.py b/transformer_nuggets/fp8/scaled_quant.py
index 9ff6e96..b88c8a9 100644
--- a/transformer_nuggets/fp8/scaled_quant.py
+++ b/transformer_nuggets/fp8/scaled_quant.py
@@ -32,6 +32,62 @@ def scaled_cast(
     tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)
 
 
+@triton.jit
+def dynamic_scaled_cast(
+    inpt_ptr: torch.Tensor,
+    output_ptr: torch.Tensor,
+    abs_max_ptr: torch.Tensor,
+    numel: int,
+    XBLOCK: tl.constexpr,
+    float8_dtype: tl.constexpr,
+    max_val: tl.constexpr,
+):
+    """Quantize tensor to fp8 using the current global absmax."""
+    offset = tl.program_id(0) * XBLOCK
+    index = offset + tl.arange(0, XBLOCK)[:]
+    index = tl.max_contiguous(tl.multiple_of(index, XBLOCK), XBLOCK)
+    mask = index < numel
+    inpt = tl.load(inpt_ptr + (index), mask=mask)
+    block_max = tl.max(tl.abs(inpt))
+    tl.atomic_max(abs_max_ptr, block_max)
+    # TODO: need a global barrier to ensure all blocks have updated abs_max.
+    # Yet the test passes...?
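+    # NOTE: without a grid-wide barrier, a block may read abs_max_ptr before
+    # every block has published its block_max, so early blocks can compute a
+    # scale from a stale (smaller) abs max; the saturated cast below clamps
+    # values that overshoot max_val, which can mask the race in practice.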
+    scale = max_val / (tl.load(abs_max_ptr) + 1e-12)
+    scaled_inpt = inpt * scale
+    # Saturated casting: clamp into the representable fp8 range before the cast
+    scaled_inpt = tl.where(scaled_inpt > max_val, max_val, scaled_inpt)
+    scaled_inpt = tl.where(scaled_inpt < -max_val, -max_val, scaled_inpt)
+    tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)
+
+
+def dynamic_scaled_quant(
+    inpt_tensor: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
+) -> torch.Tensor:
+    """Quantize tensor to fp8 using a dynamic scale computed from its abs max.
+    Performs saturated casting.
+
+    Args:
+        inpt_tensor: Input tensor to quantize
+        fp8_dtype: FP8 datatype to quantize to
+    """
+    assert inpt_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+    out_tensor = torch.empty_like(inpt_tensor, dtype=fp8_dtype, device="cuda")
+    numel = inpt_tensor.numel()
+    grid = lambda meta: (triton.cdiv(numel, meta["XBLOCK"]),)
+    tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
+    max_val = torch.finfo(fp8_dtype).max
+    abs_max_scratch = torch.zeros((), dtype=inpt_tensor.dtype, device="cuda")  # must start at 0 for atomic_max
+    dynamic_scaled_cast[grid](
+        inpt_tensor, out_tensor, abs_max_scratch, numel, 4096, tl_dtype, max_val, num_warps=8
+    )
+    return out_tensor
+
+
 def scaled_quant(
     inpt_tensor: torch.Tensor,
     scale: torch.Tensor,
@@ -92,3 +148,19 @@ def eager_scaled_quant(
     )
     _ = torch.max(torch.abs(out))
     return out.to(fp8_dtype)
+
+
+def eager_dynamic_scaled_quant(
+    a: torch.Tensor,
+    fp8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """Quantize tensor to fp8 using the current amax value to generate the scale.
+    Args:
+        a: Input tensor to quantize
+        fp8_dtype: FP8 datatype to quantize to
+    """
+    from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+
+    scale = tensor_to_scale(a, fp8_dtype)
+    a = a * scale
+    return to_fp8_saturated(a, fp8_dtype)
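
Note: if the missing barrier does turn out to matter, a race-free fallback is
to split the amax reduction into its own pass, so every block quantizes against
the fully reduced value. A minimal sketch in plain torch (two_pass_dynamic_quant
is an illustrative name, not part of this patch or the repo):

    import torch

    def two_pass_dynamic_quant(
        a: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
    ) -> torch.Tensor:
        # Pass 1: fully reduce the global abs max before quantizing; this is
        # the ordering guarantee the fused kernel cannot provide on its own.
        max_val = torch.finfo(fp8_dtype).max
        scale = max_val / a.abs().max().clamp(min=1e-12)
        # Pass 2: scale, saturate (clamp into the fp8 range), and cast.
        return (a * scale).clamp(-max_val, max_val).to(fp8_dtype)

This trades a second full read of the input for deterministic results, and
should match eager_dynamic_scaled_quant up to rounding.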