From 0a2d2aed61fb8f2f6e27c906ae8847d57f59cd09 Mon Sep 17 00:00:00 2001
From: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Date: Fri, 12 Apr 2024 16:01:32 -0700
Subject: [PATCH] Dynamic scaling triton kernel (#28)

---
 benchmarks/fp8_dynamic_cast.py          | 146 ++++++++++++++++++++++++
 test/test_fp8.py                        |  18 ++-
 transformer_nuggets/fp8/scaled_quant.py |  82 +++++++++++++
 3 files changed, 245 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/fp8_dynamic_cast.py

diff --git a/benchmarks/fp8_dynamic_cast.py b/benchmarks/fp8_dynamic_cast.py
new file mode 100644
index 0000000..1f258c0
--- /dev/null
+++ b/benchmarks/fp8_dynamic_cast.py
@@ -0,0 +1,146 @@
+import itertools
+
+from dataclasses import dataclass
+from typing import List
+
+import torch
+
+from tabulate import tabulate
+from tqdm import tqdm
+
+from transformer_nuggets.fp8.scaled_quant import dynamic_scaled_quant, eager_dynamic_scaled_quant
+from transformer_nuggets.utils import benchmark_torch_function_in_microseconds
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    numel: int
+    high_precision_dtype: torch.dtype
+    low_precision_dtype: torch.dtype
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    triton_time: float
+    pytorch_time: float
+    compiled_pytorch_time: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    sizes = [2**21, 2**22, 2**23, 2**24]
+    high_precision_dtypes = [torch.float32]
+    low_precision_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
+    configs = []
+    for size, high_precision_dtype, low_precision_dtype in itertools.product(
+        sizes, high_precision_dtypes, low_precision_dtypes
+    ):
+        configs.append(
+            ExperimentConfig(
+                numel=size,
+                high_precision_dtype=high_precision_dtype,
+                low_precision_dtype=low_precision_dtype,
+            )
+        )
+    return configs
+
+
+def correctness_check(hp_tensor, triton_tensor, config):
+    # Correctness check:
+    nuggets_out = dynamic_scaled_quant(
+        triton_tensor,
+        config.low_precision_dtype,
+    ).to(config.high_precision_dtype)
+
+    eager_out = eager_dynamic_scaled_quant(
+        hp_tensor,
+        config.low_precision_dtype,
+    ).to(config.high_precision_dtype)
+
+    print(f"Deviation between Triton and Nuggets: {torch.abs(nuggets_out - eager_out).max()}")
+    max_dev_index = torch.abs(nuggets_out - eager_out).argmax().item()
+    print(f"nuggets_out tensor value: {nuggets_out.flatten()[max_dev_index]:.4f}")
+    print(f"eager_out tensor value: {eager_out.flatten()[max_dev_index]:.4f}")
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    high_precision_tensor = torch.randn(
+        config.numel, dtype=config.high_precision_dtype, device=device
+    )
+    triton_hp_tensor = high_precision_tensor.clone()
+
+    # Triton does different rounding as far as I can tell
+    if False:
+        correctness_check(high_precision_tensor, triton_hp_tensor, config)
+
+    triton_time = benchmark_torch_function_in_microseconds(
+        dynamic_scaled_quant,
+        triton_hp_tensor,
+        config.low_precision_dtype,
+    )
+    pytorch_time = benchmark_torch_function_in_microseconds(
+        eager_dynamic_scaled_quant,
+        high_precision_tensor,
+        config.low_precision_dtype,
+    )
+    compiled_pytorch_fn = torch.compile(eager_dynamic_scaled_quant, fullgraph=True)
+    compiled_pytorch_time = benchmark_torch_function_in_microseconds(
+        compiled_pytorch_fn,
+        high_precision_tensor,
+        config.low_precision_dtype,
+    )
+    return ExperimentResult(
+        triton_time=triton_time,
+        pytorch_time=pytorch_time,
+        compiled_pytorch_time=compiled_pytorch_time,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "numel",
+        "high_precision_dtype",
+        "low_precision_dtype",
+        "triton_time",
+        "pytorch_time",
+        "compiled_pytorch_time",
+    ]
+    rows = []
+    for experiment in experiments:
+        rows.append(
+            [
+                experiment.config.numel,
+                experiment.config.high_precision_dtype,
+                experiment.config.low_precision_dtype,
+                experiment.result.triton_time,
+                experiment.result.pytorch_time,
+                experiment.result.compiled_pytorch_time,
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_fp8.py b/test/test_fp8.py
index 73ae2b9..ff6850e 100644
--- a/test/test_fp8.py
+++ b/test/test_fp8.py
@@ -1,7 +1,12 @@
 import pytest
 import torch
 
-from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant
+from transformer_nuggets.fp8.scaled_quant import (
+    dynamic_scaled_quant,
+    eager_dynamic_scaled_quant,
+    eager_scaled_quant,
+    scaled_quant,
+)
 
 
 @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@@ -36,5 +41,16 @@ def test_saturated(fp8_dtype):
     )
 
 
+@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+def test_dynamic_quant(fp8_dtype):
+    torch.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    a = torch.randn(2**12, 2**12, dtype=torch.float32, device="cuda") * 9.6
+    output = eager_dynamic_scaled_quant(a, fp8_dtype)
+    output_triton = dynamic_scaled_quant(a, fp8_dtype)
+    torch.testing.assert_close(output.to(torch.float32), output_triton.to(torch.float32))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/transformer_nuggets/fp8/scaled_quant.py b/transformer_nuggets/fp8/scaled_quant.py
index 9ff6e96..4537600 100644
--- a/transformer_nuggets/fp8/scaled_quant.py
+++ b/transformer_nuggets/fp8/scaled_quant.py
@@ -92,3 +92,85 @@ def eager_scaled_quant(
         )
         _ = torch.max(torch.abs(out))
     return out.to(fp8_dtype)
+
+
+# ----------- Dynamic Scaled Quantization ------------
+
+
+@triton.jit
+def dynamic_scaled_cast(
+    inpt_ptr: torch.Tensor,
+    output_ptr: torch.Tensor,
+    abs_max_ptr: torch.Tensor,
+    spin_lock: torch.Tensor,
+    numel: int,
+    XBLOCK: tl.constexpr,
+    float8_dtype: tl.constexpr,
+    max_val: tl.constexpr,
+):
+    """Quantize tensor to fp8 using current global absmax"""
+    n_blocks = tl.num_programs(0)
+    offset = tl.program_id(0) * XBLOCK
+    index = offset + tl.arange(0, XBLOCK)[:]
+    index = tl.max_contiguous(tl.multiple_of(index, XBLOCK), XBLOCK)
+    mask = index < numel
+    inpt = tl.load(inpt_ptr + (index), mask=mask)
+    block_max = tl.max(tl.abs(inpt))
+    tl.atomic_max(abs_max_ptr, block_max)
+    # Spinlock global barrier
+    tl.atomic_add(spin_lock, 1)
+    while tl.load(spin_lock) < n_blocks:
+        pass
+    scale = max_val / (tl.clamp(tl.load(abs_max_ptr), -1e12, float("inf")))
+    scaled_inpt = inpt * scale
+    # Saturated casting
+    scaled_inpt = tl.clamp(scaled_inpt, -1 * max_val, max_val)
+    tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)
+
+
+def dynamic_scaled_quant(
+    inpt_tensor: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
+) -> torch.Tensor:
+    """Quantize tensor to fp8 using dynamic scale calculated from abs_max
+    It will do saturated casting
+
+    Args:
+        inpt_tensor: Input tensor to quantize
+        fp8_dtype: FP8 datatype to quantize to
+    """
+    assert inpt_tensor.is_contiguous(), "Input tensor must be contiguous"
+
+    out_tensor = torch.empty_like(inpt_tensor, dtype=fp8_dtype, device="cuda")
+    numel = inpt_tensor.numel()
+    grid = lambda meta: (triton.cdiv(numel, meta["XBLOCK"]),)
+    tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
+    max_val = torch.finfo(fp8_dtype).max
+    abs_max_scratch = torch.empty((), dtype=inpt_tensor.dtype, device="cuda")
+    spin_lock = torch.zeros((), dtype=torch.int32, device="cuda")
+    dynamic_scaled_cast[grid](
+        inpt_tensor,
+        out_tensor,
+        abs_max_scratch,
+        spin_lock,
+        numel,
+        4096,
+        tl_dtype,
+        max_val,
+    )
+    return out_tensor
+
+
+def eager_dynamic_scaled_quant(
+    a: torch.Tensor,
+    fp8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """Quantize tensor to fp8 using the current amax value to generate scale
+    Args:
+        a: Input tensor to quantize
+        fp8_dtype: FP8 datatype to quantize to
+    """
+    from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+
+    scale = tensor_to_scale(a, fp8_dtype)
+    a = a * scale
+    return to_fp8_saturated(a, fp8_dtype)
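
Usage sketch (not part of the patch): the new API can be exercised the same way test_dynamic_quant above does, assuming this patch is applied, transformer_nuggets is importable, a CUDA device is available, and float8_experimental is installed for the eager path:

    import torch

    from transformer_nuggets.fp8.scaled_quant import dynamic_scaled_quant, eager_dynamic_scaled_quant

    # Any contiguous high-precision CUDA tensor works; dynamic_scaled_quant asserts contiguity.
    x = torch.randn(2**12, 2**12, dtype=torch.float32, device="cuda")

    # Triton path: reduces the global abs-max across blocks, scales, saturates, and casts to fp8.
    y_triton = dynamic_scaled_quant(x, torch.float8_e4m3fn)

    # Eager reference path built on float8_experimental.
    y_eager = eager_dynamic_scaled_quant(x, torch.float8_e4m3fn)

    # Compare in float32, as the new test does.
    torch.testing.assert_close(y_triton.to(torch.float32), y_eager.to(torch.float32))

Note that both functions return only the quantized fp8 tensor; the abs-max scratch buffer and the derived scale stay internal, so callers that need the scale later must recompute it.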
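
For readers without float8_experimental, the arithmetic both paths implement can be sketched in plain PyTorch. This is an illustrative sketch, not part of the patch: the name dynamic_scaled_quant_reference and the 1e-12 floor are my own additions, and the Triton kernel additionally distributes the abs-max reduction across blocks via tl.atomic_max plus a spin-lock barrier before applying the same scale, clamp, and cast sequence:

    import torch

    def dynamic_scaled_quant_reference(a: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
        # Scale so the largest observed magnitude maps to the fp8 dtype's max representable value.
        max_val = torch.finfo(fp8_dtype).max
        amax = a.abs().max().clamp(min=1e-12)  # small floor so an all-zero input does not divide by zero
        scale = max_val / amax
        # Saturated cast: clamp before converting so out-of-range values cannot overflow.
        scaled = (a * scale).clamp(-max_val, max_val)
        return scaled.to(fp8_dtype)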