From 55dab5fd93e424d6edabecdec7387d72a73e34d1 Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Sat, 1 Feb 2025 14:25:28 +0100 Subject: [PATCH 1/8] Feat: blockwise fp8 quantizer - first implementation of the DeepSeek blockwise quantizer (not fully fonctionnal) - amax has been unpdated - 2 more quantisation recipes has been added - a couple of things here and there to make it consistent --- torchao/float8/config.py | 65 +++++++++++++++++++++++++- torchao/float8/float8_scaling_utils.py | 18 +++++++ torchao/float8/float8_tensor.py | 20 +++++++- torchao/float8/float8_utils.py | 57 +++++++++++++++++++++- 4 files changed, 154 insertions(+), 6 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index c7f32cd3fa..9de48353d9 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -44,13 +44,17 @@ class ScalingGranularity(enum.Enum): # Scaling factors computed along one axis of the tensor, reducing it to # size 1. AXISWISE = "axiswise" + # Scaling factors computed along a block of the tensor + BLOCKWISE = "blockwise" def short_str(self): if self is ScalingGranularity.TENSORWISE: return "ten" - else: - assert self is ScalingGranularity.AXISWISE + elif self is ScalingGranularity.AXISWISE: return "axs" + else: + assert self is ScalingGranularity.BLOCKWISE + return "blk" @dataclass @@ -106,6 +110,10 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" + if self.scaling_granularity is ScalingGranularity.BLOCKWISE: + assert ( + self.scaling_type is ScalingType.DISABLED + ), "blockwise scaling is not supported for disabled scaling type" assert self.target_dtype is None or ( self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1 ), "must specify a 8-bit floating-point dtype" @@ -311,7 +319,9 @@ def __post_init__(self): class Float8LinearRecipeName(enum.Enum): ALL_TENSORWISE = "all_tensorwise" ALL_AXISWISE = "all_axiswise" + ALL_BLOCKWISE = "all_blockwise" LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp" + LW_BLOCKWISE_WITH_GW_HP = "lw_blockwise_with_gw_hp" def recipe_name_to_linear_config( @@ -337,6 +347,18 @@ def recipe_name_to_linear_config( cast_config_weight=cc_w, cast_config_grad_output=cc_go, ) + + elif recipe_name is Float8LinearRecipeName.ALL_BLOCKWISE: + # dynamic blockwise scaling with the CUTLASS blockwise kernel + cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: # lw's recipe for a modification on all-axiswise: @@ -376,6 +398,45 @@ def recipe_name_to_linear_config( cast_config_weight_for_grad_input=cc_w_gi, cast_config_grad_output_for_grad_weight=cc_go_gw, ) + + elif recipe_name is Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP: + # lw's recipe for a modification on all-blockwise: + # + # output_hp = input_fp8_blockwise @ weight_t_blockwise + # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise + # grad_weight_hp = input_t_hp @ grad_output_hp + # + # key characteristics: + # * increased accuracy for grad_weight + # * `input`, `weight` and `grad_output` now only need to be scaled + # blockwise across all dims compared to vanilla all-blockwise, + # which is more amenable 
to fast kernels + # * the e4m3 dtype is used across the board, including for gradients + # + # output_hp = input_fp8_blockwise @ weight_t_blockwise + cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + + # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, target_dtype=e4m3_dtype + ) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + + # grad_weight_hp = input_t_hp @ grad_output_hp + cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig( + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype + ) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + cast_config_input_for_grad_weight=cc_i_gw, + cast_config_weight_for_grad_input=cc_w_gi, + cast_config_grad_output_for_grad_weight=cc_go_gw, + ) else: raise AssertionError(f"unknown recipe_name {recipe_name}") diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 0c27e4f3fc..35a54400e7 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic( the 3 fwd/bwd gemms of linear scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across + blockwise_size: if blockwise granularity is used, defines the block size """ scale = tensor_to_scale( hp_tensor, @@ -59,6 +61,7 @@ def hp_tensor_to_float8_dynamic( device_mesh, scaling_granularity, axiswise_dim, + blockwise_size, ) return hp_tensor_and_scale_to_float8( hp_tensor, @@ -67,6 +70,7 @@ def hp_tensor_to_float8_dynamic( linear_mm_config, gemm_input_role, axiswise_dim, + blockwise_size, ) @@ -150,6 +154,20 @@ def get_maybe_axiswise_dim( return axiswise_dim return None +def get_maybe_blockwise_size( + blockwise_size: int, + scaling_granularity: ScalingGranularity, +) -> Optional[int]: + """ + Convenience function which takes in an blockwise size which is only relevant + for blockwise scaing, and a scaling type. The output is pass-through + if scaling type is blockwise, and None otherwise. This is done to keep the + logic from choosing the blockwise size out of the scaling function. 
+ """ + if scaling_granularity is ScalingGranularity.BLOCKWISE: + return blockwise_size + return None + def _maybe_initialize_amaxes_scales_for_float8_cast( x, diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index fe2498e2b0..39d7753e85 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -10,6 +10,8 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( + deblockify_tensor, + blockify_tensor, to_fp8_saturated, ) @@ -136,6 +138,7 @@ def forward( linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ): """ This function will apply the scaling, and then convert to a Float8Tensor @@ -168,6 +171,7 @@ def forward( linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, axiswise_dim=axiswise_dim, + blockwise_size=blockwise_size, ) return DTensor.from_local( inner_float8_tensor, @@ -185,6 +189,7 @@ def forward( linear_mm_config=linear_mm_config, gemm_input_role=gemm_input_role, axiswise_dim=axiswise_dim, + blockwise_size=blockwise_size, ) @staticmethod @@ -216,6 +221,7 @@ def hp_tensor_and_scale_to_float8( linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ): """ Given a high precision tensor `hp_tensor` and a precalculated scale `s`, @@ -233,9 +239,10 @@ def hp_tensor_and_scale_to_float8( gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in the 3 fwd/bwd gemms of linear axiswise_dim: for rowwise scaling, contains the axis scaled across + blockwise_size: for blockwise scaling, contains the block size """ return _ToFloat8ConstrFunc.apply( - hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim + hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim, blockwise_size ) @@ -262,6 +269,8 @@ class Float8Tensor(torch.Tensor): tensor. * `_axiswise_dim`: for axiswise scaling only, contains the axis scales across. Only values of 0 or -1 are supported. + * `_blockwise_size`: for blockwise scaling only, contains the block size. + If None, the tensor is not blockwise scaled. Intended usage of this abstraction: 1. 
to bundle raw data + fp8 metadata together for easy passing through @@ -278,6 +287,7 @@ class Float8Tensor(torch.Tensor): _linear_mm_config: LinearMMConfig _gemm_input_role: GemmInputRole _axiswise_dim: Optional[int] + _blockwise_size: Optional[int] __slots__ = [ "_data", "_scale", @@ -285,6 +295,7 @@ class Float8Tensor(torch.Tensor): "_linear_mm_config", "_gemm_input_role", "_axiswise_dim", + "_blockwise_size", ] def __new__( @@ -295,6 +306,7 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None ): self = torch.Tensor._make_wrapper_subclass( cls, @@ -315,11 +327,13 @@ def __new__( self._gemm_input_role = gemm_input_role assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}" self._axiswise_dim = axiswise_dim + assert isinstance(blockwise_size, int), f"unsupported blockwise_size {blockwise_size}" + self._blockwise_size = blockwise_size return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}, axiswise_dim={self._axiswise_dim}, blockwise_size={self._blockwise_size}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { @@ -327,6 +341,7 @@ def __tensor_flatten__(self): "_linear_mm_config": self._linear_mm_config, "_gemm_input_role": self._gemm_input_role, "_axiswise_dim": self._axiswise_dim, + "_blockwise_size": self._blockwise_size, } return ["_data", "_scale"], ctx @@ -340,6 +355,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride metadata["_linear_mm_config"], metadata["_gemm_input_role"], metadata["_axiswise_dim"], + metadata["_blockwise_size"], ) def to_original_precision(self): diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 6a93a612fa..c880b0f96a 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -95,13 +95,21 @@ def tensor_to_amax( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + block_size: Optional[int] = None, ) -> torch.Tensor: if scaling_granularity is ScalingGranularity.TENSORWISE: amax = torch.max(torch.abs(x)) - else: - assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported" + elif scaling_granularity is ScalingGranularity.AXISWISE: assert axiswise_dim is not None, "unsupported" amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) + else: + assert scaling_granularity is ScalingGranularity.BLOCKWISE, "unsupported" + assert block_size is not None, "block_size must be provided for BLOCKWISE scaling" + assert x.shape[-1] % block_size == 0, "x last dimension must be a multiple of block_size" + block_shape = list(x.shape[:-1]) + [x.shape[-1]//block_size] + [block_size] + block_tensor = x.view(block_shape) + amax = torch.amax(torch.abs(block_tensor), dim=-1) + amax.repeat_interleave(block_size, dim=-1) # If the user asked for distributed reduction, do it. 
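For intuition, the BLOCKWISE branch above folds the last dimension into groups of `block_size` elements and takes one amax per group. A minimal standalone sketch of that round trip (illustrative only; it uses the same e4m3 target dtype and the multiply-to-quantize / divide-to-dequantize convention seen elsewhere in this patch):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
block_size = 4                                    # last dim (8) must be divisible by it
blocks = x.view(4, 8 // block_size, block_size)   # (M, K // block_size, block_size)
amax = blocks.abs().amax(dim=-1, keepdim=True).float()
scale = torch.finfo(torch.float8_e4m3fn).max / amax.clamp(min=1e-12)
x_fp8 = (blocks.float() * scale).to(torch.float8_e4m3fn)   # one scale per block
x_dq = (x_fp8.float() / scale).view(4, 8).to(x.dtype)      # dequantized copy
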
# If the user did not ask for it, assume that it will @@ -116,6 +124,49 @@ def tensor_to_amax( return amax +@torch.no_grad() +def blockify_tensor( + x: torch.Tensor, + block_size: int | torch.Tensor = 128, +) -> torch.Tensor: + """Blockify a tensor given a block_size for each dimension. + + Args: + x: The tensor to blockify. + block_size: The block size. + + Returns: + torch.Tensor: The blockified tensor. + """ + dims = x.shape + n = len(dims) + if isinstance(block_size, int): + ones = torch.ones(n-1) + block_size = torch.cat((ones, torch.Tensor([block_size]))) + assert len(dims) == len(block_size), "The tensor and the block sizes must have the same number of dimensions" + assert all(d % b == 0 for d, b in zip(dims, block_size)), "Each dimension of the tensor must be divisible by the corresponding block size" + new_shape = torch.Tensor([d // b for d, b in zip(dims, block_size)] + list(block_size)).to(dtype=torch.int) + perm = [2*i - i//n*(2*n-1) for i in range(2*n)] # get a sequence of even numbers then odd (ex: [0, 2, 4, 1, 3, 5]) + x = x.view(new_shape[perm].tolist()) + x = x.permute(*perm) + return x + +@torch.no_grad() +def deblockify_tensor( + x: torch.Tensor, + block_size: int | torch.Tensor = 128, +) -> torch.Tensor: + """Unblockify a tensor given a block_size for each dimension. + + Args: + x: The tensor to unblockify. + block_size: The block size. + + Returns: + torch.Tensor: The unblockified tensor. + """ + pass + @torch.no_grad() def tensor_to_scale( @@ -125,6 +176,7 @@ def tensor_to_scale( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + blockwise_size: Optional[int] = None, ) -> torch.Tensor: amax = tensor_to_amax( x, @@ -132,6 +184,7 @@ def tensor_to_scale( device_mesh, scaling_granularity, axiswise_dim, + blockwise_size, ) return amax_to_scale(amax, float8_dtype) From 5ab1eb2c83226cdc133c7e67ef1fdc6554d538b3 Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Sat, 1 Feb 2025 17:40:57 +0100 Subject: [PATCH 2/8] Feat: fp8 linear layer with blockwise quantization --- torchao/float8/config.py | 25 ++++++++++++++++++------- torchao/float8/float8_linear.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 9de48353d9..d081556c39 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -94,6 +94,7 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE + blockwise_size: Optional[int] = None static_scale: Optional[torch.Tensor] = None target_dtype: Optional[torch.dtype] = None @@ -114,6 +115,9 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DISABLED ), "blockwise scaling is not supported for disabled scaling type" + assert ( + self.blockwise_size is not None + ), "blockwise_size must be specified for blockwise scaling" assert self.target_dtype is None or ( self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1 ), "must specify a 8-bit floating-point dtype" @@ -326,9 +330,13 @@ class Float8LinearRecipeName(enum.Enum): def recipe_name_to_linear_config( recipe_name: Float8LinearRecipeName, + blockwise_size: Optional[int] = None, ) -> Float8LinearConfig: """ - Input: `Float8LinearRecipeName` value + Input: + `Float8LinearRecipeName` value + `blockwise_size`: Optional[int] - if specified, blockwise scaling will be enabled with this size. 
+ Output: a `Float8LinearConfig` configured to implement the recipe """ @@ -350,9 +358,10 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_BLOCKWISE: # dynamic blockwise scaling with the CUTLASS blockwise kernel - cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + assert blockwise_size is not None, "Blockwise scaling must be specified with blockwise_size" + cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) + cc_go = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) return Float8LinearConfig( cast_config_input=cc_i, @@ -414,12 +423,14 @@ def recipe_name_to_linear_config( # * the e4m3 dtype is used across the board, including for gradients # # output_hp = input_fp8_blockwise @ weight_t_blockwise - cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE) + assert blockwise_size is not None, "Blockwise scaling must be specified with blockwise_size" + + cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise cc_go = CastConfig( - scaling_granularity=ScalingGranularity.BLOCKWISE, target_dtype=e4m3_dtype + scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size, target_dtype=e4m3_dtype ) cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6b3c0f06df..24fff37e52 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -16,6 +16,7 @@ from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, + get_maybe_blockwise_size, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -96,6 +97,9 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_input.blockwise_size, c.cast_config_input.scaling_granularity + ), ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -112,6 +116,9 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_weight.blockwise_size, c.cast_config_weight.scaling_granularity + ), ) # the reshapes are needed in order to make the shapes compatible with @@ -151,6 +158,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_grad_output.blockwise_size, + c.cast_config_grad_output.scaling_granularity, + ), ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -161,7 +172,10 @@ def backward(ctx, grad_output): else: if ( c.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE + in ( + ScalingGranularity.AXISWISE, + ScalingGranularity.BLOCKWISE, + ) ): # workaround from 
https://github.com/pytorch/pytorch/issues/141881 # to avoid saving float8 weight from forward to backward when @@ -181,6 +195,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_weight_for_grad_input.blockwise_size, + c.cast_config_weight_for_grad_input.scaling_granularity, + ), ) grad_input = torch.mm( @@ -216,6 +234,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_grad_output_for_grad_weight.blockwise_size, + c.cast_config_grad_output_for_grad_weight.scaling_granularity, + ), ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -233,6 +255,10 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), + blockwise_size=get_maybe_blockwise_size( + c.cast_config_input_for_grad_weight.blockwise_size, + c.cast_config_input_for_grad_weight.scaling_granularity, + ), ) grad_weight = torch.mm( @@ -304,7 +330,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(autocast_dtype) has_any_axiswise_scaling = any( - cc.scaling_granularity is ScalingGranularity.AXISWISE + cc.scaling_granularity in (ScalingGranularity.AXISWISE, ScalingGranularity.BLOCKWISE) for cc in [ self.config.cast_config_input, self.config.cast_config_weight, From aa7dc878a0efa1ba0b082ba1046acb72ed6e16cd Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Tue, 4 Feb 2025 19:49:22 +0100 Subject: [PATCH 3/8] Feat: adding assertions in the ops file --- torchao/float8/float8_linear.py | 4 ++-- torchao/float8/float8_ops.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 24fff37e52..0c642103d2 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -329,7 +329,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: autocast_dtype = torch.get_autocast_gpu_dtype() input = input.to(autocast_dtype) - has_any_axiswise_scaling = any( + has_any_axiswise_or_blockwise_scaling = any( cc.scaling_granularity in (ScalingGranularity.AXISWISE, ScalingGranularity.BLOCKWISE) for cc in [ self.config.cast_config_input, @@ -345,7 +345,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # TODO(future PR): check for axiswise scaling for input, weight, # grad_output separately instead of together - if not has_any_axiswise_scaling: + if not has_any_axiswise_or_blockwise_scaling: # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, # weight_scale should be saved. weight_scale = _get_weight_scale( diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 2af4160de4..7b12bd934d 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -84,6 +84,7 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): ] ) def float8_transpose(aten_op, args, kwargs=None): + assert args[0]._blockwise_size is None, "Transposition is not yet supported for blockwise fp8 quantized tensors." 
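+    # NOTE: blockwise scales are tied to blocks of the last dim of `_data`, so
+    # transposing the data without re-deriving `_scale` would likely leave the two
+    # out of sync; hence the guard above until a proper transpose is implemented.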
new_data = aten_op(args[0]._data, *args[1:], **kwargs) if args[0]._scale.ndim > 1: new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) @@ -118,6 +119,7 @@ def float8_view(aten_op, args, kwargs=None): return float8_desugar_op(aten_op, args, kwargs) t, new_shape = args[0], args[1] + assert t._blockwise_size is None, "View is not yet supported for blockwise fp8 quantized tensors." # for now, only support reshaping to [-1, dim] or [dim, -1] axiswise_dim = t._axiswise_dim if len(new_shape) == 2: @@ -253,6 +255,8 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): b_data = b_data.t().contiguous().t() b_scale = b._scale + assert a._blockwise_size == b._blockwise_size, "Blockwise sizes must match for tensors a and b." + # Today, torch._scaled_mm only supports both operands using the # same granularity. The code below checks for cases where one # operand is scaled axiswise and one tensorwise. If this case is found, From 167fdce888cb7d73f7852d062781002bf1bc7240 Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Wed, 5 Feb 2025 09:42:06 +0100 Subject: [PATCH 4/8] Feat: adding some tests for blockwise fp8 quant --- benchmarks/float8/bench_matmul.py | 7 ++- benchmarks/float8/float8_roofline.py | 25 ++++++++-- test/float8/test_base.py | 60 ++++++++++++++++++++++++ test/float8/test_compile.py | 2 + test/float8/test_numerics_integration.py | 2 + 5 files changed, 90 insertions(+), 6 deletions(-) diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 3d48853754..53ce02aacb 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -124,10 +124,13 @@ def run( if scaling_granularity == ScalingGranularity.TENSORWISE: scale_a = torch.tensor([1.0], device=device) scale_b = torch.tensor([1.0], device=device) - else: - assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported" + elif scaling_granularity == ScalingGranularity.AXISWISE: scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + else: + assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported" + scale_a = torch.ones(M, N, device=device) + scale_b = torch.ones(M, N, device=device) def do_matmul(A, B): nonlocal scale_a diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 2b3f631d8c..0ea90aefcb 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -354,14 +354,30 @@ def run( m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) - # get the lw recipe scaling gpu kernel time + # get the float8 dynamic blockwise scaling gpu kernel time + torch._dynamo.reset() + config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_BLOCKWISE) + m_fp8_dyn_blk = convert_to_float8_training(copy.deepcopy(m_orig), config=config) + m_fp8_dyn_blk = torch.compile(m_fp8_dyn_blk) + fp8_dyn_blk_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_blk, x) + + # get the lw_axs recipe scaling gpu kernel time # TODO(future PR): enable below once basic performance issues # are fixed # torch._dynamo.reset() # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) - # m_fp8_lw = convert_to_float8_training(m_orig, config=config) - # m_fp8_lw = torch.compile(m_fp8_lw) - # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) + # m_fp8_lw_axs = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw_axs = torch.compile(m_fp8_lw_axs) + # fp8_lw_axs_time_actual_s = 
get_gpu_kernel_time(m_fp8_lw_axs, x) + + # get the lw_blk recipe scaling gpu kernel time + # TODO(future PR): enable below once basic performance issues + # are fixed + # torch._dynamo.reset() + # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP) + # m_fp8_lw_blk = convert_to_float8_training(m_orig, config=config) + # m_fp8_lw_blk = torch.compile(m_fp8_lw_blk) + # fp8_lw_blk_time_actual_s = get_gpu_kernel_time(m_fp8_lw_blk, x) results.append( [ @@ -382,6 +398,7 @@ def run( fp8_dyn_time_actual_s, fp8_del_time_actual_s, fp8_dyn_axs_time_actual_s, + fp8_dyn_blk_time_actual_s, # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, bf16_time_actual_s / fp8_del_time_actual_s, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..a5ed8d16d9 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -43,6 +43,7 @@ from torchao.float8.float8_python_api import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( get_maybe_axiswise_dim, + get_maybe_blockwise_size, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -178,6 +179,22 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim): sqnr = compute_error(a, a_dq) assert sqnr >= 25.0 + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) + @pytest.mark.parametrize("blockwise_size", [4]) + def test_blockwise_dynamic_cast(self, shape, blockwise_size): + a = torch.randn(*shape, dtype=torch.bfloat16) + linear_mm_config = LinearMMConfig() + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + a_dq = a_fp8.to_original_precision() + sqnr = compute_error(a, a_dq) + assert sqnr >= 25.0 + def test_axiswise_reshape(self): a = torch.randn(3, 5, 7, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() @@ -272,6 +289,47 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): sqnr = compute_error(c_ref, c_fp8_compute) assert sqnr >= 25.0 + @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) + @pytest.mark.parametrize( + "a_granularity,b_granularity", + [ + (ScalingGranularity.BLOCKWISE, ScalingGranularity.BLOCKWISE), + (ScalingGranularity.BLOCKWISE, ScalingGranularity.TENSORWISE), + (ScalingGranularity.TENSORWISE, ScalingGranularity.BLOCKWISE), + ], + ) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") + def test_blockwise_gemm(self, a_shape, a_granularity, b_granularity): + a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") + b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + + linear_mm_config = LinearMMConfig() + + a_fp8 = hp_tensor_to_float8_dynamic( + a, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.INPUT, + scaling_granularity=a_granularity, + blockwise_size=get_maybe_blockwise_size(8, a_granularity), + ) + a_fp8 = a_fp8.reshape(-1, a_shape[-1]) + + b_fp8 = hp_tensor_to_float8_dynamic( + b, + e4m3_dtype, + linear_mm_config, + gemm_input_role=GemmInputRole.WEIGHT, + scaling_granularity=b_granularity, + blockwise_size=get_maybe_blockwise_size(8, b_granularity), + ) + + c_fp8_compute = torch.mm(a_fp8, b_fp8.t()) + a = a.reshape(-1, a_shape[-1]) + c_ref = torch.mm(a, b.t()) + sqnr = compute_error(c_ref, c_fp8_compute) + assert sqnr >= 25.0 class TestFloat8Linear: def _test_linear_impl( @@ -417,7 +475,9 @@ 
def test_linear_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c42ab8ee77..c21dd456fe 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -223,7 +223,9 @@ def test_inductor_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @unittest.skipIf( diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 311964d831..01a8822294 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -199,7 +199,9 @@ def test_encoder_fw_bw_from_config_params( "recipe_name", [ Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.ALL_BLOCKWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP, ], ) @pytest.mark.skipif( From 9e9d16ea3186e1205465e6a5c7600f515aa6d646 Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Wed, 5 Feb 2025 14:16:00 +0100 Subject: [PATCH 5/8] Fix: fixes for the blockwise_fp8_quantization --- test/float8/test_base.py | 1 + torchao/float8/config.py | 49 ++++++++++++------ torchao/float8/float8_linear.py | 18 +++---- torchao/float8/float8_ops.py | 12 +++-- torchao/float8/float8_scaling_utils.py | 1 + torchao/float8/float8_tensor.py | 29 ++++++++--- torchao/float8/float8_utils.py | 69 +++++++++++++------------- torchao/quantization/quant_api.py | 4 +- 8 files changed, 115 insertions(+), 68 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index a5ed8d16d9..7fabe98b6e 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -331,6 +331,7 @@ def test_blockwise_gemm(self, a_shape, a_granularity, b_granularity): sqnr = compute_error(c_ref, c_fp8_compute) assert sqnr >= 25.0 + class TestFloat8Linear: def _test_linear_impl( self, diff --git a/torchao/float8/config.py b/torchao/float8/config.py index d081556c39..d76a0af88d 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -333,7 +333,7 @@ def recipe_name_to_linear_config( blockwise_size: Optional[int] = None, ) -> Float8LinearConfig: """ - Input: + Input: `Float8LinearRecipeName` value `blockwise_size`: Optional[int] - if specified, blockwise scaling will be enabled with this size. 
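Concretely, the new `blockwise_size` argument threads through to every `CastConfig` in the blockwise recipes, and the asserts above require callers to pass it whenever a blockwise recipe is selected. A usage sketch (the block size of 128 is only an example; imports follow this PR's module layout):

import torch
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

config = recipe_name_to_linear_config(
    Float8LinearRecipeName.ALL_BLOCKWISE, blockwise_size=128
)
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()
m = convert_to_float8_training(m, config=config)
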
@@ -355,13 +355,24 @@ def recipe_name_to_linear_config( cast_config_weight=cc_w, cast_config_grad_output=cc_go, ) - + elif recipe_name is Float8LinearRecipeName.ALL_BLOCKWISE: # dynamic blockwise scaling with the CUTLASS blockwise kernel - assert blockwise_size is not None, "Blockwise scaling must be specified with blockwise_size" - cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) + assert ( + blockwise_size is not None + ), "Blockwise scaling must be specified with blockwise_size" + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) return Float8LinearConfig( cast_config_input=cc_i, @@ -407,7 +418,7 @@ def recipe_name_to_linear_config( cast_config_weight_for_grad_input=cc_w_gi, cast_config_grad_output_for_grad_weight=cc_go_gw, ) - + elif recipe_name is Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP: # lw's recipe for a modification on all-blockwise: # @@ -423,23 +434,33 @@ def recipe_name_to_linear_config( # * the e4m3 dtype is used across the board, including for gradients # # output_hp = input_fp8_blockwise @ weight_t_blockwise - assert blockwise_size is not None, "Blockwise scaling must be specified with blockwise_size" + assert ( + blockwise_size is not None + ), "Blockwise scaling must be specified with blockwise_size" + + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + ) - cc_i = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size) - # grad_input_hp = grad_output_fp8_blockwise @ weight_fp8_tensorwise cc_go = CastConfig( - scaling_granularity=ScalingGranularity.BLOCKWISE, blockwise_size=blockwise_size, target_dtype=e4m3_dtype + scaling_granularity=ScalingGranularity.BLOCKWISE, + blockwise_size=blockwise_size, + target_dtype=e4m3_dtype, ) cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) - + # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) cc_go_gw = CastConfig( scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype ) - + return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 0c642103d2..13e2e0b015 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -98,7 +98,8 @@ def forward( -1, c.cast_config_input.scaling_granularity ), blockwise_size=get_maybe_blockwise_size( - c.cast_config_input.blockwise_size, c.cast_config_input.scaling_granularity + c.cast_config_input.blockwise_size, + c.cast_config_input.scaling_granularity, ), ) @@ -117,7 +118,8 @@ def forward( 0, c.cast_config_weight.scaling_granularity ), blockwise_size=get_maybe_blockwise_size( - c.cast_config_weight.blockwise_size, c.cast_config_weight.scaling_granularity + 
c.cast_config_weight.blockwise_size, + c.cast_config_weight.scaling_granularity, ), ) @@ -170,12 +172,9 @@ def backward(ctx, grad_output): elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: weight_t_maybe_fp8_dim0 = weight_hp_t else: - if ( - c.cast_config_weight_for_grad_input.scaling_granularity - in ( - ScalingGranularity.AXISWISE, - ScalingGranularity.BLOCKWISE, - ) + if c.cast_config_weight_for_grad_input.scaling_granularity in ( + ScalingGranularity.AXISWISE, + ScalingGranularity.BLOCKWISE, ): # workaround from https://github.com/pytorch/pytorch/issues/141881 # to avoid saving float8 weight from forward to backward when @@ -330,7 +329,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input = input.to(autocast_dtype) has_any_axiswise_or_blockwise_scaling = any( - cc.scaling_granularity in (ScalingGranularity.AXISWISE, ScalingGranularity.BLOCKWISE) + cc.scaling_granularity + in (ScalingGranularity.AXISWISE, ScalingGranularity.BLOCKWISE) for cc in [ self.config.cast_config_input, self.config.cast_config_weight, diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 7b12bd934d..0151c7686e 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -84,7 +84,9 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): ] ) def float8_transpose(aten_op, args, kwargs=None): - assert args[0]._blockwise_size is None, "Transposition is not yet supported for blockwise fp8 quantized tensors." + assert ( + args[0]._blockwise_size is None + ), "Transposition is not yet supported for blockwise fp8 quantized tensors." new_data = aten_op(args[0]._data, *args[1:], **kwargs) if args[0]._scale.ndim > 1: new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) @@ -119,7 +121,9 @@ def float8_view(aten_op, args, kwargs=None): return float8_desugar_op(aten_op, args, kwargs) t, new_shape = args[0], args[1] - assert t._blockwise_size is None, "View is not yet supported for blockwise fp8 quantized tensors." + assert ( + t._blockwise_size is None + ), "View is not yet supported for blockwise fp8 quantized tensors." # for now, only support reshaping to [-1, dim] or [dim, -1] axiswise_dim = t._axiswise_dim if len(new_shape) == 2: @@ -255,7 +259,9 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): b_data = b_data.t().contiguous().t() b_scale = b._scale - assert a._blockwise_size == b._blockwise_size, "Blockwise sizes must match for tensors a and b." + assert ( + a._blockwise_size == b._blockwise_size + ), "Blockwise sizes must match for tensors a and b." # Today, torch._scaled_mm only supports both operands using the # same granularity. 
The code below checks for cases where one diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 35a54400e7..1e0540e438 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -154,6 +154,7 @@ def get_maybe_axiswise_dim( return axiswise_dim return None + def get_maybe_blockwise_size( blockwise_size: int, scaling_granularity: ScalingGranularity, diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 39d7753e85..452adb61d2 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -10,7 +10,6 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( - deblockify_tensor, blockify_tensor, to_fp8_saturated, ) @@ -153,7 +152,11 @@ def forward( # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically # upcasted to `float32` to multiply with the scale # In order to match numerics between eager and compile, we upcast manually here. - tensor_scaled = tensor.to(torch.float32) * scale + if blockwise_size: + tensor_scaled = blockify_tensor(tensor, blockwise_size) * scale + tensor_scaled = tensor_scaled.view(tensor.shape) + else: + tensor_scaled = tensor.to(torch.float32) * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) if isinstance(bits_fp8, DTensor): @@ -194,7 +197,7 @@ def forward( @staticmethod def backward(ctx, g): - return g, None, None, None, None, None + return g, None, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -207,7 +210,11 @@ class _FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return tensor._data.to(tensor._orig_dtype) / tensor._scale + if tensor._blockwise_size: + t = tensor._data.to(tensor._orig_dtype) + return (blockify_tensor(t, tensor._blockwise_size) / tensor._scale).view(tensor.shape) + else: + return tensor._data.to(tensor._orig_dtype) / tensor._scale @staticmethod def backward(ctx, g): @@ -242,7 +249,13 @@ def hp_tensor_and_scale_to_float8( blockwise_size: for blockwise scaling, contains the block size """ return _ToFloat8ConstrFunc.apply( - hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role, axiswise_dim, blockwise_size + hp_tensor, + s, + float8_dtype, + linear_mm_config, + gemm_input_role, + axiswise_dim, + blockwise_size, ) @@ -306,7 +319,7 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, - blockwise_size: Optional[int] = None + blockwise_size: Optional[int] = None, ): self = torch.Tensor._make_wrapper_subclass( cls, @@ -327,7 +340,9 @@ def __new__( self._gemm_input_role = gemm_input_role assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}" self._axiswise_dim = axiswise_dim - assert isinstance(blockwise_size, int), f"unsupported blockwise_size {blockwise_size}" + assert isinstance( + blockwise_size, int + ) or blockwise_size is None, f"unsupported blockwise_size {blockwise_size}" self._blockwise_size = blockwise_size return self diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index c880b0f96a..2d4e606c0b 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -104,12 +104,14 @@ def tensor_to_amax( amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True) else: assert scaling_granularity is ScalingGranularity.BLOCKWISE, "unsupported" - assert block_size is not None, "block_size must be 
provided for BLOCKWISE scaling" - assert x.shape[-1] % block_size == 0, "x last dimension must be a multiple of block_size" - block_shape = list(x.shape[:-1]) + [x.shape[-1]//block_size] + [block_size] - block_tensor = x.view(block_shape) - amax = torch.amax(torch.abs(block_tensor), dim=-1) - amax.repeat_interleave(block_size, dim=-1) + assert ( + block_size is not None + ), "block_size must be provided for BLOCKWISE scaling" + assert ( + x.shape[-1] % block_size == 0 + ), "x last dimension must be a multiple of block_size" + block_tensor = blockify_tensor(x, block_size) + amax = torch.amax(torch.abs(block_tensor), dim=-1, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -124,6 +126,7 @@ def tensor_to_amax( return amax + @torch.no_grad() def blockify_tensor( x: torch.Tensor, @@ -138,34 +141,32 @@ def blockify_tensor( Returns: torch.Tensor: The blockified tensor. """ - dims = x.shape - n = len(dims) - if isinstance(block_size, int): - ones = torch.ones(n-1) - block_size = torch.cat((ones, torch.Tensor([block_size]))) - assert len(dims) == len(block_size), "The tensor and the block sizes must have the same number of dimensions" - assert all(d % b == 0 for d, b in zip(dims, block_size)), "Each dimension of the tensor must be divisible by the corresponding block size" - new_shape = torch.Tensor([d // b for d, b in zip(dims, block_size)] + list(block_size)).to(dtype=torch.int) - perm = [2*i - i//n*(2*n-1) for i in range(2*n)] # get a sequence of even numbers then odd (ex: [0, 2, 4, 1, 3, 5]) - x = x.view(new_shape[perm].tolist()) - x = x.permute(*perm) - return x - -@torch.no_grad() -def deblockify_tensor( - x: torch.Tensor, - block_size: int | torch.Tensor = 128, -) -> torch.Tensor: - """Unblockify a tensor given a block_size for each dimension. - - Args: - x: The tensor to unblockify. - block_size: The block size. - - Returns: - torch.Tensor: The unblockified tensor. 
- """ - pass + # This is suppose to give the implementation for multi-dimensional blockification + # but for now, this function only works for last dimension blockification + # TODO: implement blockification for multi-dimensional tensors + # dims = x.shape + # n = len(dims) + # if isinstance(block_size, int): + # ones = torch.ones(n - 1) + # block_size = torch.cat((ones, torch.Tensor([block_size]))) + # assert len(dims) == len( + # block_size + # ), "The tensor and the block sizes must have the same number of dimensions" + # assert all( + # d % b == 0 for d, b in zip(dims, block_size) + # ), "Each dimension of the tensor must be divisible by the corresponding block size" + # new_shape = torch.Tensor( + # [d // b for d, b in zip(dims, block_size)] + list(block_size) + # ).to(dtype=torch.int) + # perm = [ + # 2 * i - i // n * (2 * n - 1) for i in range(2 * n) + # ] # get a sequence of even numbers then odd (ex: [0, 2, 4, 1, 3, 5]) + # x = x.view(new_shape[perm].tolist()) + # x = x.permute(*perm) + # return x + block_shape = list(x.shape[:-1]) + [x.shape[-1] // block_size] + [block_size] + block_tensor = x.view(block_shape) + return block_tensor @torch.no_grad() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bbe9b1cb6b..7154957a21 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -450,7 +450,9 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs): +def _get_linear_subclass_inserter( + constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs +): """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ From 6c9246a969e79ad70b23c476a281a4fb126493c1 Mon Sep 17 00:00:00 2001 From: "DESKTOP-7VDQ2GB\\Chris" Date: Wed, 5 Feb 2025 16:01:17 +0100 Subject: [PATCH 6/8] linting --- torchao/float8/float8_tensor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 452adb61d2..47db0a07b2 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -212,7 +212,9 @@ class _FromFloat8ConstrFunc(torch.autograd.Function): def forward(ctx, tensor): if tensor._blockwise_size: t = tensor._data.to(tensor._orig_dtype) - return (blockify_tensor(t, tensor._blockwise_size) / tensor._scale).view(tensor.shape) + return (blockify_tensor(t, tensor._blockwise_size) / tensor._scale).view( + tensor.shape + ) else: return tensor._data.to(tensor._orig_dtype) / tensor._scale @@ -340,9 +342,9 @@ def __new__( self._gemm_input_role = gemm_input_role assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}" self._axiswise_dim = axiswise_dim - assert isinstance( - blockwise_size, int - ) or blockwise_size is None, f"unsupported blockwise_size {blockwise_size}" + assert ( + isinstance(blockwise_size, int) or blockwise_size is None + ), f"unsupported blockwise_size {blockwise_size}" self._blockwise_size = blockwise_size return self From 89c6ed05a61c78dcd07b1e14deeed5edfe1fd7ba Mon Sep 17 00:00:00 2001 From: Degnel Date: Fri, 21 Feb 2025 09:07:41 +0100 Subject: [PATCH 7/8] Feat/test: quant/dequant weight/act + test --- ...enchmark_blockwise_scaled_linear_triton.py | 121 +++++++++++++++ benchmarks/test.ipynb | 
144 ++++++++++++++++++ .../blockwise_fp8_gemm_triton.py | 58 +++++++ .../blockwise_fp8/blockwise_linear.py | 52 +++++++ .../blockwise_fp8/blockwise_quantization.py | 131 ++++++++++++++++ 5 files changed, 506 insertions(+) create mode 100644 benchmarks/benchmark_blockwise_scaled_linear_triton.py create mode 100644 benchmarks/test.ipynb create mode 100644 torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py create mode 100644 torchao/prototype/blockwise_fp8/blockwise_linear.py create mode 100644 torchao/prototype/blockwise_fp8/blockwise_quantization.py diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py new file mode 100644 index 0000000000..07786e3c70 --- /dev/null +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -0,0 +1,121 @@ +import pandas as pd +import torch +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 + + +def benchmark_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + +def get_rowwise_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): + assert A_nbits in (4, 8) and B_nbits in (4, 8) + + dev = torch.device("cuda") + A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randint( + -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev + ) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + C = None + + return A, A_scale, B, B_scale, C + +def get_blockwise_problem(m: int, n: int, k: int, block_size: int): + assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size" + dev = torch.device("cuda") + A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn) + A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev) + B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn) + B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev) + + return A, A_scale, B, B_scale + +def benchmark(m: int, k: int, n: int, block_size: int): + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((n, k), dtype=torch.half, device=dev) + fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) + + A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k, 8, 8) + rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C + ) + + A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size) + blockwise_fp8_gemm_time = benchmark_microseconds( + blockwise_fp8_gemm, A, A_scale, B, B_scale + ) + + # Add precision tests + # On prend 2 sets de matrices aléatoires + # On les quantise en int8/int4 rowwise + # On les quantise en en float8 blockwise + # + + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (ms)": fp16_time, + "rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, + "rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, + "blockwise_fp8_gemm latency (ms)": blockwise_fp8_gemm_time, + "blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time, + } + + +from torchao.prototype.blockwise_fp8.blockwise_quantization import 
fp8_blockwise_weight_quant, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm + +def test_quant_dequant(): + torch.manual_seed(0) + x = torch.randn(256, 256).cuda() + qx, s = fp8_blockwise_weight_quant(x, block_size=128) + x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128) + + error = torch.norm(x - x_reconstructed) / torch.norm(x) + print(f"Relative Error: {error.item():.6f}") + + assert error < 0.05, "Quant-Dequant error too high!" + +def test_blockwise_fp8_gemm(): + torch.manual_seed(0) + M, N, K = 256, 256, 128 + A = torch.randn(M, K).cuda() + B = torch.randn(N, K).cuda() + + C = A @ B.T + + A_q, A_s = fp8_blockwise_act_quant(A, block_size=128) + B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128) + + C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) + + error = torch.norm(C - C_q) / torch.norm(C) + print(f"Relative Error: {error.item():.6f}") + + assert error < 0.05, "Quantized GEMM error is too high!" + + +test_quant_dequant() +test_blockwise_fp8_gemm() + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + block_size_vals = (128, 128, 128, 128) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k, block_size in zip(n_vals, k_vals, block_size_vals): + results.append(benchmark(m, k, n, block_size)) + + df = pd.DataFrame(results) + df.to_csv("blockwise_scaled_linear_triton_time_results.csv", index=False) + print(df.to_markdown(index=False)) \ No newline at end of file diff --git a/benchmarks/test.ipynb b/benchmarks/test.ipynb new file mode 100644 index 0000000000..a979b3d6cd --- /dev/null +++ b/benchmarks/test.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "AffineQuantizedTensor.from_hp_to_intx_static() missing 1 required positional argument: 'target_dtype'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[34], line 31\u001b[0m\n\u001b[1;32m 29\u001b[0m block_size \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 30\u001b[0m target_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mint8\n\u001b[0;32m---> 31\u001b[0m qa \u001b[38;5;241m=\u001b[39m \u001b[43mAffineQuantizedTensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_hp_to_intx_static\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mMappingType\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mASYMMETRIC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquant_min\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquant_max\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m7\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 32\u001b[0m block_size \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 33\u001b[0m qb \u001b[38;5;241m=\u001b[39m 
AffineQuantizedTensor\u001b[38;5;241m.\u001b[39mfrom_hp_to_intx_static(B, MappingType\u001b[38;5;241m.\u001b[39mASYMMETRIC, block_size, target_dtype)\n", + "\u001b[0;31mTypeError\u001b[0m: AffineQuantizedTensor.from_hp_to_intx_static() missing 1 required positional argument: 'target_dtype'" + ] + } + ], + "source": [ + "import sys\n", + "import os\n", + "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n", + "\n", + "if parent_dir not in sys.path:\n", + " sys.path.append(parent_dir)\n", + "\n", + "from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor\n", + "import torch\n", + "from torchao.quantization.quant_primitives import MappingType\n", + "from torchao.ops import rowwise_scaled_linear_cutlass_s8s4\n", + "\n", + "A = torch.Tensor([\n", + " [1.0, 2.0, 3.0],\n", + " [4.0, 5.0, 6.0],\n", + " [7.0, 8.0, 9.0]]\n", + ")\n", + "\n", + "B = torch.Tensor([\n", + " [-2.0, 1.0],\n", + " [3.0, -4.0],\n", + " [6.0, -8.0]]\n", + ")\n", + "\n", + "C = torch.Tensor(\n", + " [20.0, -18.0]\n", + ")\n", + "\n", + "block_size = [1, 0]\n", + "target_dtype = torch.int8\n", + "qa = AffineQuantizedTensor.from_hp_to_intx(A, MappingType.ASYMMETRIC, block_size, target_dtype, quant_min=-8, quant_max=7)\n", + "block_size = [0, 1]\n", + "qb = AffineQuantizedTensor.from_hp_to_intx(B, MappingType.ASYMMETRIC, block_size, target_dtype)\n", + "\n", + "print(qa)\n", + "print(qb)\n", + "\n", + "print(rowwise_scaled_linear_cutlass_s8s4(qa.tensor_impl.data, qa.tensor_impl.scale, qb.tensor_impl.data, qb.tensor_impl.scale, None))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "CompilationError", + "evalue": "at 1:0:\ndef fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n^\nValueError(\"type fp8e4nv not supported in this architecture. 
The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')\")", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mCompilationError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[36], line 15\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRelative Error: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merror\u001b[38;5;241m.\u001b[39mitem()\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m error \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0.05\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuant-Dequant error too high!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 15\u001b[0m \u001b[43mtest_quant_dequant\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[36], line 6\u001b[0m, in \u001b[0;36mtest_quant_dequant\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 5\u001b[0m x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m256\u001b[39m, \u001b[38;5;241m256\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m qx, s \u001b[38;5;241m=\u001b[39m \u001b[43mfp8_blockwise_weight_quant\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m x_reconstructed \u001b[38;5;241m=\u001b[39m fp8_blockwise_weight_dequant(qx, s, block_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m)\n\u001b[1;32m 9\u001b[0m error \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(x \u001b[38;5;241m-\u001b[39m x_reconstructed) \u001b[38;5;241m/\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(x)\n", + "File \u001b[0;32m~/Documents/ao/torchao/prototype/blockwise_fp8/blockwise_quantization.py:78\u001b[0m, in \u001b[0;36mfp8_blockwise_weight_quant\u001b[0;34m(x, block_size)\u001b[0m\n\u001b[1;32m 76\u001b[0m s \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mnew_empty(M \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m block_size, N \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m block_size, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 77\u001b[0m grid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m meta: (triton\u001b[38;5;241m.\u001b[39mcdiv(M, meta[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBLOCK_SIZE\u001b[39m\u001b[38;5;124m'\u001b[39m]), triton\u001b[38;5;241m.\u001b[39mcdiv(N, meta[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBLOCK_SIZE\u001b[39m\u001b[38;5;124m'\u001b[39m]))\n\u001b[0;32m---> 78\u001b[0m \u001b[43mfp8_blockwise_quant_weight_kernel\u001b[49m\u001b[43m[\u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mBLOCK_SIZE\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mblock_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m y, s\n", + "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/runtime/jit.py:330\u001b[0m, in \u001b[0;36mKernelInterface.__getitem__..\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, grid) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 325\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;124;03m A JIT function is launched with: fn[grid](*args, **kwargs).\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[38;5;124;03m Hence JITFunction.__getitem__ returns a callable proxy that\u001b[39;00m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;124;03m memorizes the grid.\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 330\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarmup\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/runtime/jit.py:623\u001b[0m, in \u001b[0;36mJITFunction.run\u001b[0;34m(self, grid, warmup, *args, **kwargs)\u001b[0m\n\u001b[1;32m 621\u001b[0m \u001b[38;5;66;03m# compile the kernel\u001b[39;00m\n\u001b[1;32m 622\u001b[0m src \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mASTSource(\u001b[38;5;28mself\u001b[39m, signature, constants, configs[\u001b[38;5;241m0\u001b[39m])\n\u001b[0;32m--> 623\u001b[0m kernel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 624\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 625\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 626\u001b[0m \u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__dict__\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 627\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache[device][key] \u001b[38;5;241m=\u001b[39m kernel\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_hook(key, signature, device, constants, options, configs, warmup, before\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File 
\u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/compiler/compiler.py:273\u001b[0m, in \u001b[0;36mcompile\u001b[0;34m(src, target, options)\u001b[0m\n\u001b[1;32m 271\u001b[0m module_map \u001b[38;5;241m=\u001b[39m backend\u001b[38;5;241m.\u001b[39mget_module_map()\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 273\u001b[0m module \u001b[38;5;241m=\u001b[39m \u001b[43msrc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_ir\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcodegen_fns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 275\u001b[0m filter_traceback(e)\n", + "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/compiler/compiler.py:100\u001b[0m, in \u001b[0;36mASTSource.make_ir\u001b[0;34m(self, options, codegen_fns, module_map, context)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mmake_ir\u001b[39m(\u001b[38;5;28mself\u001b[39m, options, codegen_fns, module_map, context):\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mast_to_ttir\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcontext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcodegen_fns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcodegen_fns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodule_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodule_map\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mCompilationError\u001b[0m: at 1:0:\ndef fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n^\nValueError(\"type fp8e4nv not supported in this architecture. 
The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')\")" + ] + } + ], + "source": [ + "from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_weight_quant, fp8_blockwise_weight_dequant\n", + "from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm\n", + "\n", + "def test_quant_dequant():\n", + " torch.manual_seed(0)\n", + " x = torch.randn(256, 256)\n", + " qx, s = fp8_blockwise_weight_quant(x, block_size=128)\n", + " x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128)\n", + "\n", + " error = torch.norm(x - x_reconstructed) / torch.norm(x)\n", + " print(f\"Relative Error: {error.item():.6f}\")\n", + "\n", + " assert error < 0.05, \"Quant-Dequant error too high!\"\n", + "\n", + "def test_blockwise_fp8_gemm():\n", + " torch.manual_seed(0)\n", + " M, N, K = 256, 256, 256\n", + " A = torch.randn(M, K)\n", + " B = torch.randn(K, N)\n", + "\n", + " C = A @ B\n", + "\n", + " A_q, A_s = fp8_blockwise_weight_quant(A, block_size=128)\n", + " B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128)\n", + "\n", + " C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)\n", + "\n", + " error = torch.norm(C - C_q) / torch.norm(C)\n", + " print(f\"Relative Error: {error.item():.6f}\")\n", + "\n", + " assert error < 0.05, \"Quantized GEMM error is too high!\"\n", + "\n", + "\n", + "test_quant_dequant()\n", + "test_blockwise_fp8_gemm()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py b/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py new file mode 100644 index 0000000000..2f97bdb115 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py @@ -0,0 +1,58 @@ +import torch +import triton +import triton.language as tl +from triton import Config + +fp8_gemm_configs = [ + Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) + for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] +] + +@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) +@triton.jit +def blockwise_fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, + a_s_ptr, b_s_ptr, + M, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs 
+= BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + +def blockwise_fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + blockwise_fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + return c \ No newline at end of file diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8/blockwise_linear.py new file mode 100644 index 0000000000..7c070e3123 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_linear.py @@ -0,0 +1,52 @@ +import torch +from torch import nn +from typing import Optional +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_act_quant + +def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size = 128) -> torch.Tensor: + x, scale = fp8_blockwise_act_quant(x, block_size) + y = blockwise_fp8_gemm(x, scale, weight, weight.scale) + if bias is not None: + y += bias + return y + +class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + block_size (int): Block size for quantization. Defaults to 128. + """ + dtype = torch.bfloat16 + + def __init__(self, in_features: int, out_features: int, bias: bool = False, block_size = 128, dtype = None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + else: + self.register_parameter("scale", None) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. 
+ """ + return linear(x, self.weight, self.bias) \ No newline at end of file diff --git a/torchao/prototype/blockwise_fp8/blockwise_quantization.py b/torchao/prototype/blockwise_fp8/blockwise_quantization.py new file mode 100644 index 0000000000..c64df6b261 --- /dev/null +++ b/torchao/prototype/blockwise_fp8/blockwise_quantization.py @@ -0,0 +1,131 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def fp8_blockwise_quant_act_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) + + +def fp8_blockwise_act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) + fp8_blockwise_quant_act_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + return y, s + +@triton.jit +def fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. 
+ y = (x / s).to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y, mask=mask) + tl.store(s_ptr + pid_m * n + pid_n, s) + + +def fp8_blockwise_weight_quant(x: torch.Tensor, block_size: int = 128): + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.dim() == 2, 'Input tensor must have 2 dimensions' + assert x.size(0) % block_size == 0 and x.size(1) % block_size == 0, \ + f"Both dimensions of x must be divisible by block_size (block_size={block_size})" + M, N = x.size() + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(M // block_size, N // block_size, dtype=torch.float32) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + fp8_blockwise_quant_weight_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size) + return y, s + +@triton.jit +def fp8_blockwise_dequant_weight_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) + + +def fp8_blockwise_weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. 
+ """ + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + fp8_blockwise_dequant_weight_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + return y \ No newline at end of file From 91b368d1c9ae36517377dfd7a33e309b2525db20 Mon Sep 17 00:00:00 2001 From: Degnel Date: Sat, 22 Feb 2025 10:57:29 +0100 Subject: [PATCH 8/8] linting --- ...enchmark_blockwise_scaled_linear_triton.py | 83 ++++------ benchmarks/test.ipynb | 144 ------------------ test/prototype/test_blockwise_triton.py | 51 +++++++ .../blockwise_fp8/blockwise_linear.py | 11 +- .../blockwise_fp8/blockwise_quantization.py | 3 +- 5 files changed, 93 insertions(+), 199 deletions(-) delete mode 100644 benchmarks/test.ipynb create mode 100644 test/prototype/test_blockwise_triton.py diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py index 07786e3c70..9256123e11 100644 --- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py +++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py @@ -3,21 +3,28 @@ from tqdm import tqdm from triton.testing import do_bench -from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.float8.float8_utils import compute_error from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + fp8_blockwise_act_quant, + fp8_blockwise_weight_quant, +) +from torchao.quantization.quant_api import ( + int8_dynamic_activation_int4_weight, + quantize_, +) def benchmark_microseconds(f, *args): return do_bench(lambda: f(*args), return_mode="median") * 1e3 -def get_rowwise_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): - assert A_nbits in (4, 8) and B_nbits in (4, 8) - +def get_rowwise_problem(m: int, n: int, k: int): dev = torch.device("cuda") - A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) A_scale = torch.randn((m,), dtype=torch.half, device=dev) B = torch.randint( - -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev + -128, 127, size=(n, 4 * k // 8), dtype=torch.int8, device=dev ) B_scale = torch.randn((n,), dtype=torch.half, device=dev) C = None @@ -35,12 +42,13 @@ def get_blockwise_problem(m: int, n: int, k: int, block_size: int): return A, A_scale, B, B_scale def benchmark(m: int, k: int, n: int, block_size: int): + # Speed benchmark dev = torch.device("cuda") A_ref = torch.randn((m, k), dtype=torch.half, device=dev) B_ref = torch.randn((n, k), dtype=torch.half, device=dev) fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) - A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k, 8, 8) + A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k) rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C ) @@ -50,12 +58,21 @@ def benchmark(m: int, k: int, n: int, block_size: int): blockwise_fp8_gemm, A, A_scale, B, B_scale ) - # Add precision tests - # On prend 2 sets de matrices aléatoires - # On les quantise en int8/int4 rowwise - # 
On les quantise en en float8 blockwise - # + # Precision benchmark + lin = torch.nn.Linear(k, n, False, dev, torch.half) + A = torch.randn((m, k), dtype=torch.half, device=dev) + W = lin.weight + output = A @ W.T + + A_q, A_s = fp8_blockwise_act_quant(A, block_size) + W_q, W_s = fp8_blockwise_weight_quant(W, block_size) + output_blockwise_quant = blockwise_fp8_gemm(A_q, A_s, W_q, W_s) + + quantize_(lin, int8_dynamic_activation_int4_weight()) + output_rowwise_quant = lin(A) + error_rowwise_quant = compute_error(output, output_rowwise_quant) + error_blockwise_quant = compute_error(output, output_blockwise_quant) return { "m": m, @@ -66,46 +83,10 @@ def benchmark(m: int, k: int, n: int, block_size: int): "rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, "blockwise_fp8_gemm latency (ms)": blockwise_fp8_gemm_time, "blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time, + "error_rowwise_quant (dB)": error_rowwise_quant, + "error_blockwise_quant (dB)": error_blockwise_quant } - -from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_weight_quant, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant -from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm - -def test_quant_dequant(): - torch.manual_seed(0) - x = torch.randn(256, 256).cuda() - qx, s = fp8_blockwise_weight_quant(x, block_size=128) - x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128) - - error = torch.norm(x - x_reconstructed) / torch.norm(x) - print(f"Relative Error: {error.item():.6f}") - - assert error < 0.05, "Quant-Dequant error too high!" - -def test_blockwise_fp8_gemm(): - torch.manual_seed(0) - M, N, K = 256, 256, 128 - A = torch.randn(M, K).cuda() - B = torch.randn(N, K).cuda() - - C = A @ B.T - - A_q, A_s = fp8_blockwise_act_quant(A, block_size=128) - B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128) - - C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) - - error = torch.norm(C - C_q) / torch.norm(C) - print(f"Relative Error: {error.item():.6f}") - - assert error < 0.05, "Quantized GEMM error is too high!" 
- - -test_quant_dequant() -test_blockwise_fp8_gemm() - - if __name__ == "__main__": k_vals = (8192, 8192, 8192, 28672) n_vals = (8192, 10240, 57344, 8192) @@ -117,5 +98,5 @@ def test_blockwise_fp8_gemm(): results.append(benchmark(m, k, n, block_size)) df = pd.DataFrame(results) - df.to_csv("blockwise_scaled_linear_triton_time_results.csv", index=False) + df.to_csv("blockwise_scaled_linear_triton_results.csv", index=False) print(df.to_markdown(index=False)) \ No newline at end of file diff --git a/benchmarks/test.ipynb b/benchmarks/test.ipynb deleted file mode 100644 index a979b3d6cd..0000000000 --- a/benchmarks/test.ipynb +++ /dev/null @@ -1,144 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "AffineQuantizedTensor.from_hp_to_intx_static() missing 1 required positional argument: 'target_dtype'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[34], line 31\u001b[0m\n\u001b[1;32m 29\u001b[0m block_size \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 30\u001b[0m target_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mint8\n\u001b[0;32m---> 31\u001b[0m qa \u001b[38;5;241m=\u001b[39m \u001b[43mAffineQuantizedTensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_hp_to_intx_static\u001b[49m\u001b[43m(\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mMappingType\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mASYMMETRIC\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquant_min\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mquant_max\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m7\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 32\u001b[0m block_size \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 33\u001b[0m qb \u001b[38;5;241m=\u001b[39m AffineQuantizedTensor\u001b[38;5;241m.\u001b[39mfrom_hp_to_intx_static(B, MappingType\u001b[38;5;241m.\u001b[39mASYMMETRIC, block_size, target_dtype)\n", - "\u001b[0;31mTypeError\u001b[0m: AffineQuantizedTensor.from_hp_to_intx_static() missing 1 required positional argument: 'target_dtype'" - ] - } - ], - "source": [ - "import sys\n", - "import os\n", - "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n", - "\n", - "if parent_dir not in sys.path:\n", - " sys.path.append(parent_dir)\n", - "\n", - "from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor\n", - "import torch\n", - "from torchao.quantization.quant_primitives import MappingType\n", - "from torchao.ops import rowwise_scaled_linear_cutlass_s8s4\n", - "\n", - "A = torch.Tensor([\n", - " [1.0, 2.0, 3.0],\n", - " [4.0, 5.0, 6.0],\n", - " [7.0, 8.0, 9.0]]\n", - ")\n", - "\n", - "B = torch.Tensor([\n", - " [-2.0, 1.0],\n", - " [3.0, -4.0],\n", - " [6.0, -8.0]]\n", - ")\n", - "\n", - "C = torch.Tensor(\n", - " [20.0, -18.0]\n", - ")\n", - "\n", - "block_size = [1, 0]\n", - "target_dtype = torch.int8\n", - "qa = 
AffineQuantizedTensor.from_hp_to_intx(A, MappingType.ASYMMETRIC, block_size, target_dtype, quant_min=-8, quant_max=7)\n", - "block_size = [0, 1]\n", - "qb = AffineQuantizedTensor.from_hp_to_intx(B, MappingType.ASYMMETRIC, block_size, target_dtype)\n", - "\n", - "print(qa)\n", - "print(qb)\n", - "\n", - "print(rowwise_scaled_linear_cutlass_s8s4(qa.tensor_impl.data, qa.tensor_impl.scale, qb.tensor_impl.data, qb.tensor_impl.scale, None))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "ename": "CompilationError", - "evalue": "at 1:0:\ndef fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n^\nValueError(\"type fp8e4nv not supported in this architecture. The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')\")", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mCompilationError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[36], line 15\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRelative Error: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merror\u001b[38;5;241m.\u001b[39mitem()\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m error \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0.05\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQuant-Dequant error too high!\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 15\u001b[0m \u001b[43mtest_quant_dequant\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[36], line 6\u001b[0m, in \u001b[0;36mtest_quant_dequant\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 5\u001b[0m x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m256\u001b[39m, \u001b[38;5;241m256\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m qx, s \u001b[38;5;241m=\u001b[39m \u001b[43mfp8_blockwise_weight_quant\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblock_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m x_reconstructed \u001b[38;5;241m=\u001b[39m fp8_blockwise_weight_dequant(qx, s, block_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m)\n\u001b[1;32m 9\u001b[0m error \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(x \u001b[38;5;241m-\u001b[39m x_reconstructed) \u001b[38;5;241m/\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(x)\n", - "File \u001b[0;32m~/Documents/ao/torchao/prototype/blockwise_fp8/blockwise_quantization.py:78\u001b[0m, in \u001b[0;36mfp8_blockwise_weight_quant\u001b[0;34m(x, block_size)\u001b[0m\n\u001b[1;32m 76\u001b[0m s \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mnew_empty(M \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m block_size, N \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m block_size, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 77\u001b[0m grid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m meta: (triton\u001b[38;5;241m.\u001b[39mcdiv(M, 
meta[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBLOCK_SIZE\u001b[39m\u001b[38;5;124m'\u001b[39m]), triton\u001b[38;5;241m.\u001b[39mcdiv(N, meta[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mBLOCK_SIZE\u001b[39m\u001b[38;5;124m'\u001b[39m]))\n\u001b[0;32m---> 78\u001b[0m \u001b[43mfp8_blockwise_quant_weight_kernel\u001b[49m\u001b[43m[\u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mM\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mBLOCK_SIZE\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mblock_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m y, s\n", - "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/runtime/jit.py:330\u001b[0m, in \u001b[0;36mKernelInterface.__getitem__..\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 324\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, grid) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m T:\n\u001b[1;32m 325\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;124;03m A JIT function is launched with: fn[grid](*args, **kwargs).\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[38;5;124;03m Hence JITFunction.__getitem__ returns a callable proxy that\u001b[39;00m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;124;03m memorizes the grid.\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 330\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarmup\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/runtime/jit.py:623\u001b[0m, in \u001b[0;36mJITFunction.run\u001b[0;34m(self, grid, warmup, *args, **kwargs)\u001b[0m\n\u001b[1;32m 621\u001b[0m \u001b[38;5;66;03m# compile the kernel\u001b[39;00m\n\u001b[1;32m 622\u001b[0m src \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mASTSource(\u001b[38;5;28mself\u001b[39m, signature, constants, configs[\u001b[38;5;241m0\u001b[39m])\n\u001b[0;32m--> 623\u001b[0m kernel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 624\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 625\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 626\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__dict__\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 627\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache[device][key] \u001b[38;5;241m=\u001b[39m kernel\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_hook(key, signature, device, constants, options, configs, warmup, before\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", - "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/compiler/compiler.py:273\u001b[0m, in \u001b[0;36mcompile\u001b[0;34m(src, target, options)\u001b[0m\n\u001b[1;32m 271\u001b[0m module_map \u001b[38;5;241m=\u001b[39m backend\u001b[38;5;241m.\u001b[39mget_module_map()\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 273\u001b[0m module \u001b[38;5;241m=\u001b[39m \u001b[43msrc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_ir\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcodegen_fns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_map\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 275\u001b[0m filter_traceback(e)\n", - "File \u001b[0;32m~/Documents/ao/venv/lib/python3.12/site-packages/triton/compiler/compiler.py:100\u001b[0m, in \u001b[0;36mASTSource.make_ir\u001b[0;34m(self, options, codegen_fns, module_map, context)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mmake_ir\u001b[39m(\u001b[38;5;28mself\u001b[39m, options, codegen_fns, module_map, context):\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mast_to_ttir\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcontext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcodegen_fns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcodegen_fns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodule_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodule_map\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mCompilationError\u001b[0m: at 1:0:\ndef fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n^\nValueError(\"type fp8e4nv not supported in this architecture. 
The supported fp8 dtypes are ('fp8e4b15', 'fp8e5')\")" - ] - } - ], - "source": [ - "from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_weight_quant, fp8_blockwise_weight_dequant\n", - "from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm\n", - "\n", - "def test_quant_dequant():\n", - " torch.manual_seed(0)\n", - " x = torch.randn(256, 256)\n", - " qx, s = fp8_blockwise_weight_quant(x, block_size=128)\n", - " x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128)\n", - "\n", - " error = torch.norm(x - x_reconstructed) / torch.norm(x)\n", - " print(f\"Relative Error: {error.item():.6f}\")\n", - "\n", - " assert error < 0.05, \"Quant-Dequant error too high!\"\n", - "\n", - "def test_blockwise_fp8_gemm():\n", - " torch.manual_seed(0)\n", - " M, N, K = 256, 256, 256\n", - " A = torch.randn(M, K)\n", - " B = torch.randn(K, N)\n", - "\n", - " C = A @ B\n", - "\n", - " A_q, A_s = fp8_blockwise_weight_quant(A, block_size=128)\n", - " B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128)\n", - "\n", - " C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)\n", - "\n", - " error = torch.norm(C - C_q) / torch.norm(C)\n", - " print(f\"Relative Error: {error.item():.6f}\")\n", - "\n", - " assert error < 0.05, \"Quantized GEMM error is too high!\"\n", - "\n", - "\n", - "test_quant_dequant()\n", - "test_blockwise_fp8_gemm()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/test/prototype/test_blockwise_triton.py b/test/prototype/test_blockwise_triton.py new file mode 100644 index 0000000000..9aa1c50a18 --- /dev/null +++ b/test/prototype/test_blockwise_triton.py @@ -0,0 +1,51 @@ +import pytest +import torch + +from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm +from torchao.prototype.blockwise_fp8.blockwise_quantization import ( + fp8_blockwise_act_quant, + fp8_blockwise_weight_dequant, + fp8_blockwise_weight_quant, +) + +ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("_, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK) +def test_quant_dequant(_, N, K): + x = torch.randn(N, K).cuda() + qx, s = fp8_blockwise_weight_quant(x, block_size=128) + x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128) + error = torch.norm(x - x_reconstructed) / torch.norm(x) + print(f"Relative Error: {error.item():.6f}") + + assert error < 0.05, "Quant-Dequant error is too high" + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("M, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK) +def test_blockwise_fp8_gemm(M, N, K): + A = torch.randn(M, K).cuda() + B = torch.randn(N, K).cuda() + + C = A @ B.T + + A_q, A_s = fp8_blockwise_act_quant(A, block_size=128) + B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128) + + C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) + print(C_q, C) + error = torch.norm(C - C_q) / torch.norm(C) + print(f"Relative Error: 
{error.item():.6f}")
+
+    assert error < 0.05, "Quantized GEMM error is too high"
+
+
+# test_quant_dequant()
+# test_blockwise_fp8_gemm()
\ No newline at end of file
diff --git a/torchao/prototype/blockwise_fp8/blockwise_linear.py b/torchao/prototype/blockwise_fp8/blockwise_linear.py
index 7c070e3123..ed909ffa07 100644
--- a/torchao/prototype/blockwise_fp8/blockwise_linear.py
+++ b/torchao/prototype/blockwise_fp8/blockwise_linear.py
@@ -1,8 +1,13 @@
+from typing import Optional
+
 import torch
 from torch import nn
-from typing import Optional
-from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
-from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_act_quant
+
+from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
+from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+    fp8_blockwise_act_quant,
+)
+
 
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size = 128) -> torch.Tensor:
     x, scale = fp8_blockwise_act_quant(x, block_size)
diff --git a/torchao/prototype/blockwise_fp8/blockwise_quantization.py b/torchao/prototype/blockwise_fp8/blockwise_quantization.py
index c64df6b261..04c21ba04f 100644
--- a/torchao/prototype/blockwise_fp8/blockwise_quantization.py
+++ b/torchao/prototype/blockwise_fp8/blockwise_quantization.py
@@ -61,7 +61,8 @@ def fp8_blockwise_quant_weight_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.
     mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
     x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
     s = tl.max(tl.abs(x)) / 448.
-    y = (x / s).to(y_ptr.dtype.element_ty)
+    y = x / s
+    y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y, mask=mask)
     tl.store(s_ptr + pid_m * n + pid_n, s)
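
For quick reference, a minimal usage sketch of the prototype API added in this series (the shapes here are hypothetical; both weight dimensions and the activation's last dimension must be divisible by block_size, and the Triton kernels require a CUDA device whose architecture supports the fp8e4nv type, per the notebook traceback above):

    import torch

    from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
    from torchao.prototype.blockwise_fp8.blockwise_quantization import (
        fp8_blockwise_act_quant,
        fp8_blockwise_weight_quant,
    )

    M, N, K, block_size = 256, 512, 384, 128  # hypothetical problem sizes
    A = torch.randn(M, K, device="cuda", dtype=torch.half)  # activations
    B = torch.randn(N, K, device="cuda", dtype=torch.half)  # weight, (out_features, in_features)

    A_q, A_s = fp8_blockwise_act_quant(A, block_size)     # one scale per 1 x block_size slice
    B_q, B_s = fp8_blockwise_weight_quant(B, block_size)  # one scale per block_size x block_size tile
    C = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)            # approximates A @ B.T

This mirrors what test_blockwise_fp8_gemm in test/prototype/test_blockwise_triton.py exercises.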