diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py
new file mode 100644
index 0000000000..9256123e11
--- /dev/null
+++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -0,0 +1,102 @@
+import pandas as pd
+import torch
+from tqdm import tqdm
+from triton.testing import do_bench
+
+from torchao.float8.float8_utils import compute_error
+from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
+from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
+from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+    fp8_blockwise_act_quant,
+    fp8_blockwise_weight_quant,
+)
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int4_weight,
+    quantize_,
+)
+
+
+def benchmark_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+def get_rowwise_problem(m: int, n: int, k: int):
+    dev = torch.device("cuda")
+    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
+    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
+    B = torch.randint(
+        -128, 127, size=(n, 4 * k // 8), dtype=torch.int8, device=dev
+    )
+    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
+    C = None
+
+    return A, A_scale, B, B_scale, C
+def get_blockwise_problem(m: int, n: int, k: int, block_size: int):
+    assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size"
+    dev = torch.device("cuda")
+    A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn)
+    A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev)
+    B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn)
+    B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev)
+
+    return A, A_scale, B, B_scale
+def benchmark(m: int, k: int, n: int, block_size: int):
+    # Speed benchmark
+    dev = torch.device("cuda")
+    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
+    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)
+
+    A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k)
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+    )
+
+    A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size)
+    blockwise_fp8_gemm_time = benchmark_microseconds(
+        blockwise_fp8_gemm, A, A_scale, B, B_scale
+    )
+
+    # Precision benchmark
+    lin = torch.nn.Linear(k, n, False, dev, torch.half)
+    A = torch.randn((m, k), dtype=torch.half, device=dev)
+    W = lin.weight
+    output = A @ W.T
+
+    A_q, A_s = fp8_blockwise_act_quant(A, block_size)
+    W_q, W_s = fp8_blockwise_weight_quant(W, block_size)
+    output_blockwise_quant = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)
+
+    quantize_(lin, int8_dynamic_activation_int4_weight())
+    output_rowwise_quant = lin(A)
+
+    error_rowwise_quant = compute_error(output, output_rowwise_quant)
+    error_blockwise_quant = compute_error(output, output_blockwise_quant)
+
+    return {
+        "m": m,
+        "k": k,
+        "n": n,
+        "fp16_latency (us)": fp16_time,
+        "rowwise_scaled_linear_cutlass_s8s4 latency (us)": rowwise_scaled_linear_cutlass_s8s4_time,
+        "rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
+        "blockwise_fp8_gemm latency (us)": blockwise_fp8_gemm_time,
+        "blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time,
+        "error_rowwise_quant (dB)":
error_rowwise_quant, + "error_blockwise_quant (dB)": error_blockwise_quant + } + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + block_size_vals = (128, 128, 128, 128) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k, block_size in zip(n_vals, k_vals, block_size_vals): + results.append(benchmark(m, k, n, block_size)) + + df = pd.DataFrame(results) + df.to_csv("blockwise_scaled_linear_triton_results.csv", index=False) + print(df.to_markdown(index=False)) \ No newline at end of file diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 3d48853754..53ce02aacb 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -124,10 +124,13 @@ def run( if scaling_granularity == ScalingGranularity.TENSORWISE: scale_a = torch.tensor([1.0], device=device) scale_b = torch.tensor([1.0], device=device) - else: - assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported" + elif scaling_granularity == ScalingGranularity.AXISWISE: scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + else: + assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported" + scale_a = torch.ones(M, N, device=device) + scale_b = torch.ones(M, N, device=device) def do_matmul(A, B): nonlocal scale_a diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 2b3f631d8c..0ea90aefcb 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -354,14 +354,30 @@ def run( m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) - # get the lw recipe scaling gpu kernel time + # get the float8 dynamic blockwise scaling gpu kernel time + torch._dynamo.reset() + config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_BLOCKWISE) + m_fp8_dyn_blk = convert_to_float8_training(copy.deepcopy(m_orig), config=config) + m_fp8_dyn_blk = torch.compile(m_fp8_dyn_blk) + fp8_dyn_blk_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_blk, x) + + # get the lw_axs recipe scaling gpu kernel time # TODO(future PR): enable below once basic performance issues # are fixed # torch._dynamo.reset() # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) - # m_fp8_lw = convert_to_float8_training(m_orig, config=config) - # m_fp8_lw = torch.compile(m_fp8_lw) - # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) + # m_fp8_lw_axs = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw_axs = torch.compile(m_fp8_lw_axs) + # fp8_lw_axs_time_actual_s = get_gpu_kernel_time(m_fp8_lw_axs, x) + + # get the lw_blk recipe scaling gpu kernel time + # TODO(future PR): enable below once basic performance issues + # are fixed + # torch._dynamo.reset() + # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP) + # m_fp8_lw_blk = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw_blk = torch.compile(m_fp8_lw_blk) + # fp8_lw_blk_time_actual_s = get_gpu_kernel_time(m_fp8_lw_blk, x) results.append( [ @@ -382,6 +398,7 @@ def run( fp8_dyn_time_actual_s, fp8_del_time_actual_s, fp8_dyn_axs_time_actual_s, + fp8_dyn_blk_time_actual_s, # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, bf16_time_actual_s / fp8_del_time_actual_s, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..7fabe98b6e 100644 --- a/test/float8/test_base.py +++ 
b/test/float8/test_base.py @@ -43,6 +43,7 @@ from torchao.float8.float8_python_api import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, + get_maybe_blockwise_size, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -178,6 +179,22 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim): sqnr = compute_error(a, a_dq) assert sqnr >= 25.0 + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("blockwise_size", [4]) + def test_blockwise_dynamic_cast(self, shape, blockwise_size): + a = torch.randn(*shape, dtype=torch.bfloat16) + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + def test_axiswise_reshape(self): a = torch.randn(3, 5, 7, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() @@ -272,6 +289,48 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): sqnr = compute_error(c_ref, c_fp8_compute) assert sqnr >= 25.0 + @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @pytest.mark.parametrize( + "a_granularity,b_granularity", + [ + (ScalingGranularity.BLOCKWISE, ScalingGranularity.BLOCKWISE), + (ScalingGranularity.BLOCKWISE, ScalingGranularity.TENSORWISE), + (ScalingGranularity.TENSORWISE, ScalingGranularity.BLOCKWISE), + ], + ) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + def test_blockwise_gemm(self, a_shape, a_granularity, b_granularity): + a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=a_granularity, + blockwise_size=get_maybe_blockwise_size(8, a_granularity), + ) + a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=b_granularity, + blockwise_size=get_maybe_blockwise_size(8, b_granularity), + ) + + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + a = a.reshape(-1, a_shape[-1]) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + assert sqnr >= 25.0 + class TestFloat8Linear: def _test_linear_impl( @@ -417,7 +476,9 @@ def test_linear_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c42ab8ee77..c21dd456fe 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -223,7 +223,9 @@ def test_inductor_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @unittest.skipIf( diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 
311964d831..01a8822294 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -199,7 +199,9 @@ def test_encoder_fw_bw_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @pytest.mark.skipif( diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py new file mode 100644 index 0000000000..9aa1c50a18 --- /dev/null +++ b/test/prototype/test_blockwise_triton.py @@ -0,0 +1,51 @@ +import pytest +import torch + +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + fp8_blockwise_act_quant, + fp8_blockwise_weight_dequant, + fp8_blockwise_weight_quant, +) + +ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("_, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK) +def test_quant_dequant(_, N, K): + x = torch.randn(N, K).cuda() + qx, s = fp8_blockwise_weight_quant(x, block_size=128) + x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128) + error = torch.norm(x - x_reconstructed) / torch.norm(x) + print(f"Relative Error: {error.item():.6f}") + + assert error < 0.05, "Quant-Dequant error is too high" + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("M, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK) +def test_blockwise_fp8_gemm(M, N, K): + A = torch.randn(M, K).cuda() + B = torch.randn(N, K).cuda() + + C = A @ B.T + + A_q, A_s = fp8_blockwise_act_quant(A, block_size=128) + B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128) + + C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) + print(C_q, C) + error = torch.norm(C - C_q) / torch.norm(C) + print(f"Relative Error: {error.item():.6f}") + + assert error < 0.05, "Quantize gemm error is too high" + + +# test_quant_dequant() +# test_blockwise_fp8_gemm() \ No newline at end of file diff --git a/torchao/float8/config.py b/torchao/float8/config.py index c7f32cd3fa..d76a0af88d 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -44,13 +44,17 @@ class ScalingGranularity(enum.Enum): # Scaling factors computed along one axis of the tensor, reducing it to # size 1. 
     AXISWISE = "axiswise"
+    # Scaling factors computed along a block of the tensor
+    BLOCKWISE = "blockwise"
 
     def short_str(self):
         if self is ScalingGranularity.TENSORWISE:
             return "ten"
-        else:
-            assert self is ScalingGranularity.AXISWISE
+        elif self is ScalingGranularity.AXISWISE:
             return "axs"
+        else:
+            assert self is ScalingGranularity.BLOCKWISE
+            return "blk"
 
 
 @dataclass
@@ -90,6 +94,7 @@ class CastConfig:
 
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
+    blockwise_size: Optional[int] = None
     static_scale: Optional[torch.Tensor] = None
     target_dtype: Optional[torch.dtype] = None
 
@@ -106,6 +111,13 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        if self.scaling_granularity is ScalingGranularity.BLOCKWISE:
+            assert (
+                self.scaling_type is not ScalingType.DISABLED
+            ), "blockwise scaling is not supported for disabled scaling type"
+            assert (
+                self.blockwise_size is not None
+            ), "blockwise_size must be specified for blockwise scaling"
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
@@ -311,14 +323,20 @@ def __post_init__(self):
 class Float8LinearRecipeName(enum.Enum):
     ALL_TENSORWISE = "all_tensorwise"
     ALL_AXISWISE = "all_axiswise"
+    ALL_BLOCKWISE = "all_blockwise"
     LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp"
+    LW_BLOCKWISE_WITH_GW_HP = "lw_blockwise_with_gw_hp"
 
 
 def recipe_name_to_linear_config(
     recipe_name: Float8LinearRecipeName,
+    blockwise_size: Optional[int] = None,
 ) -> Float8LinearConfig:
     """
-    Input: `Float8LinearRecipeName` value
+    Input:
+        `Float8LinearRecipeName` value
+        `blockwise_size`: Optional[int] - if specified, blockwise scaling will be enabled with this size.
+ Output: a `Float8LinearConfig` configured to implement the recipe """ @@ -338,6 +356,30 @@ def recipe_name_to_linear_config( cast_config_grad_output=cc_go, ) + elif recipe_name is Float8LinearRecipeName.ALL_BLOCKWISE: + # dynamic blockwise scaling with the CUTLASS blockwise kernel + assert ( + blockwise_size is not None + ), "Blockwise scaling must be specified with blockwise_size" + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + ) + elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: # lw's recipe for a modification on all-axiswise: # @@ -377,5 +419,56 @@ def recipe_name_to_linear_config( cast_config_grad_output_for_grad_weight=cc_go_gw, ) + elif recipe_name is Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP: + # lw's recipe for a modification on all-blockwise: + # + # output_hp = input_fp8_blockwise @ weight_t_blockwise + # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise + # grad_weight_hp = input_t_hp @ grad_output_hp + # + # key characteristics: + # * increased accuracy for grad_weight + # * `input`, `weight` and `grad_output` now only need to be scaled + # blockwise across all dims compared to vanilla all-blockwise, + # which is more amenable to fast kernels + # * the e4m3 dtype is used across the board, including for gradients + # + # output_hp = input_fp8_blockwise @ weight_t_blockwise + assert ( + blockwise_size is not None + ), "Blockwise scaling must be specified with blockwise_size" + + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + + # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + target_dtype=e4m3_dtype, + ) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + + # grad_weight_hp = input_t_hp @ grad_output_hp + cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig( + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype + ) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + cast_config_input_for_grad_weight=cc_i_gw, + cast_config_weight_for_grad_input=cc_w_gi, + cast_config_grad_output_for_grad_weight=cc_go_gw, + ) + else: raise AssertionError(f"unknown recipe_name {recipe_name}") diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6b3c0f06df..13e2e0b015 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -16,6 +16,7 @@ from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, + get_maybe_blockwise_size, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -96,6 +97,10 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_input.blockwise_size, + 
c.cast_config_input.scaling_granularity, + ), ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -112,6 +117,10 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_weight.blockwise_size, + c.cast_config_weight.scaling_granularity, + ), ) # the reshapes are needed in order to make the shapes compatible with @@ -151,6 +160,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_grad_output.blockwise_size, + c.cast_config_grad_output.scaling_granularity, + ), ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -159,9 +172,9 @@ def backward(ctx, grad_output): elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: weight_t_maybe_fp8_dim0 = weight_hp_t else: - if ( - c.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE + if c.cast_config_weight_for_grad_input.scaling_granularity in ( + ScalingGranularity.AXISWISE, + ScalingGranularity.BLOCKWISE, ): # workaround from https://github.com/pytorch/pytorch/issues/141881 # to avoid saving float8 weight from forward to backward when @@ -181,6 +194,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_weight_for_grad_input.blockwise_size, + c.cast_config_weight_for_grad_input.scaling_granularity, + ), ) grad_input = torch.mm( @@ -216,6 +233,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_grad_output_for_grad_weight.blockwise_size, + c.cast_config_grad_output_for_grad_weight.scaling_granularity, + ), ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -233,6 +254,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_input_for_grad_weight.blockwise_size, + c.cast_config_input_for_grad_weight.scaling_granularity, + ), ) grad_weight = torch.mm( @@ -303,8 +328,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: autocast_dtype = torch.get_autocast_gpu_dtype() input = input.to(autocast_dtype) - has_any_axiswise_scaling = any( - cc.scaling_granularity is ScalingGranularity.AXISWISE + has_any_axiswise_or_blockwise_scaling = any( + cc.scaling_granularity + in (ScalingGranularity.AXISWISE, ScalingGranularity.BLOCKWISE) for cc in [ self.config.cast_config_input, self.config.cast_config_weight, @@ -319,7 +345,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # TODO(future PR): check for axiswise scaling for input, weight, # grad_output separately instead of together - if not has_any_axiswise_scaling: + if not has_any_axiswise_or_blockwise_scaling: # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, # weight_scale should be saved. 
             weight_scale = _get_weight_scale(
diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py
index 2af4160de4..0151c7686e 100644
--- a/torchao/float8/float8_ops.py
+++ b/torchao/float8/float8_ops.py
@@ -84,6 +84,9 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
     ]
 )
 def float8_transpose(aten_op, args, kwargs=None):
+    assert (
+        args[0]._blockwise_size is None
+    ), "Transposition is not yet supported for blockwise fp8 quantized tensors."
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     if args[0]._scale.ndim > 1:
         new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
@@ -118,6 +121,9 @@ def float8_view(aten_op, args, kwargs=None):
         return float8_desugar_op(aten_op, args, kwargs)
     t, new_shape = args[0], args[1]
+    assert (
+        t._blockwise_size is None
+    ), "View is not yet supported for blockwise fp8 quantized tensors."
     # for now, only support reshaping to [-1, dim] or [dim, -1]
     axiswise_dim = t._axiswise_dim
     if len(new_shape) == 2:
@@ -253,6 +259,10 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
         b_data = b_data.t().contiguous().t()
     b_scale = b._scale
+    assert (
+        a._blockwise_size == b._blockwise_size
+    ), "Blockwise sizes must match for tensors a and b."
+
     # Today, torch._scaled_mm only supports both operands using the
     # same granularity. The code below checks for cases where one
     # operand is scaled axiswise and one tensorwise. If this case is found,
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index 0c27e4f3fc..1e0540e438 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    blockwise_size: Optional[int] = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic(
           the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        blockwise_size: if blockwise granularity is used, defines the block size
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -59,6 +61,7 @@ def hp_tensor_to_float8_dynamic(
         device_mesh,
         scaling_granularity,
         axiswise_dim,
+        blockwise_size,
     )
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
@@ -67,6 +70,7 @@ def hp_tensor_to_float8_dynamic(
         linear_mm_config,
         gemm_input_role,
         axiswise_dim,
+        blockwise_size,
     )
@@ -151,6 +155,21 @@ def get_maybe_axiswise_dim(
     return None
+
+
+def get_maybe_blockwise_size(
+    blockwise_size: int,
+    scaling_granularity: ScalingGranularity,
+) -> Optional[int]:
+    """
+    Convenience function which takes in a blockwise size, which is only relevant
+    for blockwise scaling, and a scaling granularity. The output is a pass-through
+    if the scaling granularity is blockwise, and None otherwise. This is done to
+    keep the logic for choosing the blockwise size out of the scaling function.
+ """ + if scaling_granularity is ScalingGranularity.BLOCKWISE: + return blockwise_size + return None + + def _maybe_initialize_amaxes_scales_for_float8_cast( x, cur_amax, diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index fe2498e2b0..47db0a07b2 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -10,6 +10,7 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( + blockify_tensor, to_fp8_saturated, ) @@ -136,6 +137,7 @@ def forward( linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ): """ This function will apply the scaling, and then convert to a Float8Tensor @@ -150,7 +152,11 @@ def forward( # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically # upcasted to `float32` to multiply with the scale # In order to match numerics between eager and compile, we upcast manually here. - tensor_scaled = tensor.to(torch.float32) * scale + if blockwise_size: + tensor_scaled = blockify_tensor(tensor, blockwise_size) * scale + tensor_scaled = tensor_scaled.view(tensor.shape) + else: + tensor_scaled = tensor.to(torch.float32) * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): @@ -168,6 +174,7 @@ def forward( linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, axiswise_dim=axiswise_dim, + blockwise_size=blockwise_size, ) return DTensor.from_local( inner_float8_tensor, @@ -185,11 +192,12 @@ def forward( linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, axiswise_dim=axiswise_dim, + blockwise_size=blockwise_size, ) @staticmethod def backward(ctx, g): - return g, None, None, None, None, None + return g, None, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -202,7 +210,13 @@ class _FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return tensor._data.to(tensor._orig_dtype) / tensor._scale + if tensor._blockwise_size: + t = tensor._data.to(tensor._orig_dtype) + return (blockify_tensor(t, tensor._blockwise_size) / tensor._scale).view( + tensor.shape + ) + else: + return tensor._data.to(tensor._orig_dtype) / tensor._scale @staticmethod def backward(ctx, g): @@ -216,6 +230,7 @@ def hp_tensor_and_scale_to_float8( linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, @@ -233,9 +248,16 @@ def hp_tensor_and_scale_to_float8( gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear axiswise_dim: for rowwise scaling, contains the axis scaled across + blockwise_size: for blockwise scaling, contains the block size """ return _ToFloat8ConstrFunc.apply( - hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim + hp_tensor, + s, + float8_dtype, + linear_mm_config, + gemm_input_role, + axiswise_dim, + blockwise_size, ) @@ -262,6 +284,8 @@ class Float8Tensor(torch.Tensor): tensor. * `_axiswise_dim`: for axiswise scaling only, contains the axis scales across. Only values of 0 or -1 are supported. + * `_blockwise_size`: for blockwise scaling only, contains the block size. + If None, the tensor is not blockwise scaled. 
Intended usage of this abstraction: 1. to bundle raw data + fp8 metadata together for easy passing through @@ -278,6 +302,7 @@ class Float8Tensor(torch.Tensor): _linear_mm_config: LinearMMConfig _gemm_input_role: GemmInputRole _axiswise_dim: Optional[int] + _blockwise_size: Optional[int] __slots__ = [ "_data", "_scale", @@ -285,6 +310,7 @@ class Float8Tensor(torch.Tensor): "_linear_mm_config", "_gemm_input_role", "_axiswise_dim", + "_blockwise_size", ] def __new__( @@ -295,6 +321,7 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ): self = torch.Tensor._make_wrapper_subclass( cls, @@ -315,11 +342,15 @@ def __new__( self._gemm_input_role = gemm_input_role assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}" self._axiswise_dim = axiswise_dim + assert ( + isinstance(blockwise_size, int) or blockwise_size is None + ), f"unsupported blockwise_size {blockwise_size}" + self._blockwise_size = blockwise_size return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}, blockwise_size={self._blockwise_size}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { @@ -327,6 +358,7 @@ def __tensor_flatten__(self): "_linear_mm_config": self._linear_mm_config, "_gemm_input_role": self._gemm_input_role, "_axiswise_dim": self._axiswise_dim, + "_blockwise_size": self._blockwise_size, } return ["_data", "_scale"], ctx @@ -340,6 +372,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride metadata["_linear_mm_config"], metadata["_gemm_input_role"], metadata["_axiswise_dim"], + metadata["_blockwise_size"], ) def to_original_precision(self): diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 6a93a612fa..2d4e606c0b 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -95,13 +95,23 @@ def tensor_to_amax( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + block_size: Optional[int] = None, ) -> torch.Tensor: if scaling_granularity is ScalingGranularity.TENSORWISE: amax = torch.max(torch.abs(x)) - else: - assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + elif scaling_granularity is ScalingGranularity.AXISWISE: assert axiswise_dim is not None, "unsupported" amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) + else: + assert scaling_granularity is ScalingGranularity.BLOCKWISE, "unsupported" + assert ( + block_size is not None + ), "block_size must be provided for BLOCKWISE scaling" + assert ( + x.shape[-1] % block_size == 0 + ), "x last dimension must be a multiple of block_size" + block_tensor = blockify_tensor(x, block_size) + amax = torch.amax(torch.abs(block_tensor), dim=-1, keepdim=True) # If the user asked for distributed reduction, do it. 
# If the user did not ask for it, assume that it will @@ -117,6 +127,48 @@ def tensor_to_amax( return amax +@torch.no_grad() +def blockify_tensor( + x: torch.Tensor, + block_size: int | torch.Tensor = 128, +) -> torch.Tensor: + """Blockify a tensor given a block_size for each dimension. + + Args: + x: The tensor to blockify. + block_size: The block size. + + Returns: + torch.Tensor: The blockified tensor. + """ + # This is suppose to give the implementation for multi-dimensional blockification + # but for now, this function only works for last dimension blockification + # TODO: implement blockification for multi-dimensional tensors + # dims = x.shape + # n = len(dims) + # if isinstance(block_size, int): + # ones = torch.ones(n - 1) + # block_size = torch.cat((ones, torch.Tensor([block_size]))) + # assert len(dims) == len( + # block_size + # ), "The tensor and the block sizes must have the same number of dimensions" + # assert all( + # d % b == 0 for d, b in zip(dims, block_size) + # ), "Each dimension of the tensor must be divisible by the corresponding block size" + # new_shape = torch.Tensor( + # [d // b for d, b in zip(dims, block_size)] + list(block_size) + # ).to(dtype=torch.int) + # perm = [ + # 2 * i - i // n * (2 * n - 1) for i in range(2 * n) + # ] # get a sequence of even numbers then odd (ex: [0, 2, 4, 1, 3, 5]) + # x = x.view(new_shape[perm].tolist()) + # x = x.permute(*perm) + # return x + block_shape = list(x.shape[:-1]) + [x.shape[-1] // block_size] + [block_size] + block_tensor = x.view(block_shape) + return block_tensor + + @torch.no_grad() def tensor_to_scale( x: torch.Tensor, @@ -125,6 +177,7 @@ def tensor_to_scale( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ) -> torch.Tensor: amax = tensor_to_amax( x, @@ -132,6 +185,7 @@ def tensor_to_scale( device_mesh, scaling_granularity, axiswise_dim, + blockwise_size, ) return amax_to_scale(amax, float8_dtype) diff --git a/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py b/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py new file mode 100644 index 0000000000..2f97bdb115 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py @@ -0,0 +1,58 @@ +import torch +import triton +import triton.language as tl +from triton import Config + +fp8_gemm_configs = [ + Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) + for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] +] + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def blockwise_fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, + a_s_ptr, b_s_ptr, + M, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = 
tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def blockwise_fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + blockwise_fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c \ No newline at end of file diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8/blockwise_linear.py new file mode 100644 index 0000000000..ed909ffa07 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_linear.py @@ -0,0 +1,57 @@ +from typing import Optional + +import torch +from torch import nn + +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + fp8_blockwise_act_quant, +) + + +def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size = 128) -> torch.Tensor: + x, scale = fp8_blockwise_act_quant(x, block_size) + y = blockwise_fp8_gemm(x, scale, weight, weight.scale) + if bias is not None: + y += bias + return y + +class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + block_size (int): Block size for quantization. Defaults to 128. + """ + dtype = torch.bfloat16 + + def __init__(self, in_features: int, out_features: int, bias: bool = False, block_size = 128, dtype = None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + else: + self.register_parameter("scale", None) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. 
+ """ + return linear(x, self.weight, self.bias) \ No newline at end of file diff --git a/torchao/prototype/blockwise_fp8/blockwise_quantization.py b/torchao/prototype/blockwise_fp8/blockwise_quantization.py new file mode 100644 index 0000000000..04c21ba04f --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_quantization.py @@ -0,0 +1,132 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fp8_blockwise_quant_act_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def fp8_blockwise_act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) + fp8_blockwise_quant_act_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + +@triton.jit +def fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. 
+ y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + +def fp8_blockwise_weight_quant(x: torch.Tensor, block_size: int = 128): + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.dim() == 2, 'Input tensor must have 2 dimensions' + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, \ + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + M, N = x.size() + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + fp8_blockwise_quant_weight_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + +@triton.jit +def fp8_blockwise_dequant_weight_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def fp8_blockwise_weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + fp8_blockwise_dequant_weight_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y \ No newline at end of file
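
For reference (not part of the patch): a minimal usage sketch of the prototype blockwise path added above, mirroring test_blockwise_fp8_gemm. The shapes and the CUDA device below are illustrative assumptions.

    import torch
    from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
        fp8_blockwise_act_quant,
        fp8_blockwise_weight_quant,
    )

    # Illustrative sizes; the activation's last dim and both weight dims
    # must be divisible by block_size.
    M, N, K, block_size = 16, 4096, 2048, 128
    A = torch.randn(M, K, device="cuda")  # high-precision activations
    B = torch.randn(N, K, device="cuda")  # high-precision weight, stored as (N, K)

    # Activations get one scale per (1 x block_size) block, the weight one scale
    # per (block_size x block_size) block; the Triton FP8 GEMM then approximates A @ B.T.
    A_q, A_s = fp8_blockwise_act_quant(A, block_size)
    B_q, B_s = fp8_blockwise_weight_quant(B, block_size)
    C = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)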