Dynamic scaling triton kernel (#28)
drisspg authored Apr 12, 2024
1 parent c04d670 commit 0a2d2ae
Showing 3 changed files with 245 additions and 1 deletion.
146 changes: 146 additions & 0 deletions benchmarks/fp8_dynamic_cast.py
@@ -0,0 +1,146 @@
import itertools

from dataclasses import dataclass
from typing import List

import torch

from tabulate import tabulate
from tqdm import tqdm

from transformer_nuggets.fp8.scaled_quant import dynamic_scaled_quant, eager_dynamic_scaled_quant
from transformer_nuggets.utils import benchmark_torch_function_in_microseconds

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    numel: int
    high_precision_dtype: torch.dtype
    low_precision_dtype: torch.dtype


@dataclass(frozen=True)
class ExperimentResult:
    triton_time: float
    pytorch_time: float
    compiled_pytorch_time: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult

def get_configs() -> List[ExperimentConfig]:
    sizes = [2**21, 2**22, 2**23, 2**24]
    high_precision_dtypes = [torch.float32]
    low_precision_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
    configs = []
    for size, high_precision_dtype, low_precision_dtype in itertools.product(
        sizes, high_precision_dtypes, low_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                numel=size,
                high_precision_dtype=high_precision_dtype,
                low_precision_dtype=low_precision_dtype,
            )
        )
    return configs


def correctness_check(hp_tensor, triton_tensor, config):
    # Correctness check: Triton kernel vs. eager reference
    triton_out = dynamic_scaled_quant(
        triton_tensor,
        config.low_precision_dtype,
    ).to(config.high_precision_dtype)

    eager_out = eager_dynamic_scaled_quant(
        hp_tensor,
        config.low_precision_dtype,
    ).to(config.high_precision_dtype)

    deviation = torch.abs(triton_out - eager_out)
    print(f"Max deviation between Triton and eager: {deviation.max()}")
    max_dev_index = deviation.argmax().item()
    print(f"triton_out tensor value: {triton_out.flatten()[max_dev_index]:.4f}")
    print(f"eager_out tensor value: {eager_out.flatten()[max_dev_index]:.4f}")


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    high_precision_tensor = torch.randn(
        config.numel, dtype=config.high_precision_dtype, device=device
    )
    triton_hp_tensor = high_precision_tensor.clone()

    # Triton does different rounding as far as I can tell
    if False:
        correctness_check(high_precision_tensor, triton_hp_tensor, config)

    triton_time = benchmark_torch_function_in_microseconds(
        dynamic_scaled_quant,
        triton_hp_tensor,
        config.low_precision_dtype,
    )
    pytorch_time = benchmark_torch_function_in_microseconds(
        eager_dynamic_scaled_quant,
        high_precision_tensor,
        config.low_precision_dtype,
    )
    compiled_pytorch_fn = torch.compile(eager_dynamic_scaled_quant, fullgraph=True)
    compiled_pytorch_time = benchmark_torch_function_in_microseconds(
        compiled_pytorch_fn,
        high_precision_tensor,
        config.low_precision_dtype,
    )
    return ExperimentResult(
        triton_time=triton_time,
        pytorch_time=pytorch_time,
        compiled_pytorch_time=compiled_pytorch_time,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "numel",
        "high_precision_dtype",
        "low_precision_dtype",
        "triton_time (us)",
        "pytorch_time (us)",
        "compiled_pytorch_time (us)",
    ]
    rows = []
    for experiment in experiments:
        rows.append(
            [
                experiment.config.numel,
                experiment.config.high_precision_dtype,
                experiment.config.low_precision_dtype,
                experiment.result.triton_time,
                experiment.result.pytorch_time,
                experiment.result.compiled_pytorch_time,
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
18 changes: 17 additions & 1 deletion test/test_fp8.py
@@ -1,7 +1,12 @@
import pytest
import torch

from transformer_nuggets.fp8.scaled_quant import eager_scaled_quant, scaled_quant
from transformer_nuggets.fp8.scaled_quant import (
    dynamic_scaled_quant,
    eager_dynamic_scaled_quant,
    eager_scaled_quant,
    scaled_quant,
)


@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@@ -36,5 +41,16 @@ def test_saturated(fp8_dtype):
    )


@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_dynamic_quant(fp8_dtype):
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    a = torch.randn(2**12, 2**12, dtype=torch.float32, device="cuda") * 9.6
    output = eager_dynamic_scaled_quant(a, fp8_dtype)
    output_triton = dynamic_scaled_quant(a, fp8_dtype)
    torch.testing.assert_close(output.to(torch.float32), output_triton.to(torch.float32))


if __name__ == "__main__":
    pytest.main([__file__])
82 changes: 82 additions & 0 deletions transformer_nuggets/fp8/scaled_quant.py
@@ -92,3 +92,85 @@ def eager_scaled_quant(
    )
    _ = torch.max(torch.abs(out))
    return out.to(fp8_dtype)


# ----------- Dynamic Scaled Quantization ------------


@triton.jit
def dynamic_scaled_cast(
    inpt_ptr: torch.Tensor,
    output_ptr: torch.Tensor,
    abs_max_ptr: torch.Tensor,
    spin_lock: torch.Tensor,
    numel: int,
    XBLOCK: tl.constexpr,
    float8_dtype: tl.constexpr,
    max_val: tl.constexpr,
):
    """Quantize tensor to fp8 using the current global absmax"""
    n_blocks = tl.num_programs(0)
    offset = tl.program_id(0) * XBLOCK
    index = offset + tl.arange(0, XBLOCK)[:]
    index = tl.max_contiguous(tl.multiple_of(index, XBLOCK), XBLOCK)
    mask = index < numel
    inpt = tl.load(inpt_ptr + (index), mask=mask)
    # Each block folds its local absmax into the single global absmax slot
    block_max = tl.max(tl.abs(inpt))
    tl.atomic_max(abs_max_ptr, block_max)
    # Spinlock global barrier: wait until every block has contributed its max
    # before any block reads the (now final) global absmax
    tl.atomic_add(spin_lock, 1)
    while tl.load(spin_lock) < n_blocks:
        pass
    # Clamp the absmax away from zero so an all-zero input does not divide by zero
    scale = max_val / (tl.clamp(tl.load(abs_max_ptr), 1e-12, float("inf")))
    scaled_inpt = inpt * scale
    # Saturated casting
    scaled_inpt = tl.clamp(scaled_inpt, -1 * max_val, max_val)
    tl.store(output_ptr + (index), scaled_inpt.to(float8_dtype), mask=mask)
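

# Note: the spin-lock barrier above only terminates if every block in the grid is
# resident on the GPU at the same time (blocks that have not yet been scheduled can
# never increment the counter). Keeping XBLOCK large (4096 in the wrapper below)
# keeps the grid small, which makes full residency more likely.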


def dynamic_scaled_quant(
    inpt_tensor: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn
) -> torch.Tensor:
    """Quantize tensor to fp8 using a dynamic scale calculated from its abs_max.
    It will do saturated casting.

    Args:
        inpt_tensor: Input tensor to quantize
        fp8_dtype: FP8 datatype to quantize to
    """
    assert inpt_tensor.is_contiguous(), "Input tensor must be contiguous"

    out_tensor = torch.empty_like(inpt_tensor, dtype=fp8_dtype, device="cuda")
    numel = inpt_tensor.numel()
    grid = lambda meta: (triton.cdiv(numel, meta["XBLOCK"]),)
    tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
    max_val = torch.finfo(fp8_dtype).max
    # Scratch slot for the global absmax; start at zero so atomic_max compares
    # against a defined value rather than uninitialized memory
    abs_max_scratch = torch.zeros((), dtype=inpt_tensor.dtype, device="cuda")
    spin_lock = torch.zeros((), dtype=torch.int32, device="cuda")
    dynamic_scaled_cast[grid](
        inpt_tensor,
        out_tensor,
        abs_max_scratch,
        spin_lock,
        numel,
        4096,
        tl_dtype,
        max_val,
    )
    return out_tensor
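

# Example usage (sketch): cast a tensor with a dynamically computed scale. Note that
# only the scaled fp8 values are returned; the scale itself (max_val / amax) is not,
# so callers that need to dequantize have to recompute the amax from the input.
#
#   x = torch.randn(2**20, device="cuda", dtype=torch.float32)
#   x_fp8 = dynamic_scaled_quant(x, torch.float8_e4m3fn)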


def eager_dynamic_scaled_quant(
    a: torch.Tensor,
    fp8_dtype: torch.dtype,
) -> torch.Tensor:
    """Quantize tensor to fp8 using the current amax value to generate the scale.

    Args:
        a: Input tensor to quantize
        fp8_dtype: FP8 datatype to quantize to
    """
    from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated

    scale = tensor_to_scale(a, fp8_dtype)
    a = a * scale
    return to_fp8_saturated(a, fp8_dtype)
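

# This eager path leans on float8_experimental for the scale computation and the
# saturated cast. Up to rounding differences in the Triton kernel (noted in
# benchmarks/fp8_dynamic_cast.py), the two paths are expected to produce matching
# fp8 values, which is what test_dynamic_quant checks.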
