diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index 175ab03f3c..6a092d5f38 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -170,7 +170,6 @@ class Float8LinearConfig:
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
     # `grad_weight`
-    # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
@@ -317,21 +316,10 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
@@ -359,14 +347,6 @@ def recipe_name_to_linear_config(
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
         cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
 
-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
@@ -374,9 +354,6 @@ def recipe_name_to_linear_config(
             cast_config_input_for_grad_weight=cc_i_gw,
             cast_config_weight_for_grad_input=cc_w_gi,
             cast_config_grad_output_for_grad_weight=cc_go_gw,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )
 
     else:
diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py
index 6608dba958..402ce2eb0f 100644
--- a/torchao/float8/float8_python_api.py
+++ b/torchao/float8/float8_python_api.py
@@ -37,19 +37,25 @@ def addmm_float8_unwrapped(
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()
 
-    if output_dtype == torch.float32 and bias is not None:
+    post_inverse_scale = None
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
+        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
+        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
+        # manually afterwards (hoping Inductor will be able to fuse it).
+        post_inverse_scale = a_inverse_scale * b_inverse_scale
+        a_inverse_scale = a_inverse_scale.new_ones(())
+        b_inverse_scale = a_inverse_scale.new_ones(())
+
+    post_bias = None
+    if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
-            a_data,
-            b_data,
-            scale_a=a_inverse_scale,
-            scale_b=b_inverse_scale,
-            scale_result=output_scale,
-            out_dtype=output_dtype,
-            use_fast_accum=use_fast_accum,
-        )
-        output += bias
-        return output
+        post_bias = bias
+        bias = None
+
     output = torch._scaled_mm(
         a_data,
         b_data,
@@ -60,4 +66,10 @@ def addmm_float8_unwrapped(
         out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
+
+    if post_inverse_scale is not None:
+        output *= post_inverse_scale
+    if post_bias is not None:
+        output += post_bias
+
     return output
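For reference, the fallback in `addmm_float8_unwrapped` is numerically sound because rowwise and columnwise scales commute with the matmul: for `a_inverse_scale` of shape `(M, 1)` and `b_inverse_scale` of shape `(1, N)`, scaling the operands before multiplying gives the same result as multiplying the unscaled product by the broadcasted product of the two scales. Below is a minimal sketch of that identity in plain PyTorch on CPU (no fp8 hardware or `torch._scaled_mm` needed; the variable names mirror the diff, but the shapes and values are purely illustrative):

```python
import torch

M, K, N = 4, 8, 3
a_data = torch.randn(M, K, dtype=torch.float64)
b_data = torch.randn(K, N, dtype=torch.float64)
a_inverse_scale = torch.rand(M, 1, dtype=torch.float64) + 0.5  # rowwise, (M, 1)
b_inverse_scale = torch.rand(1, N, dtype=torch.float64) + 0.5  # columnwise, (1, N)

# What the rowwise kernel computes: scales applied inside the matmul.
scaled_inside = (a_data * a_inverse_scale) @ (b_data * b_inverse_scale)

# What the fallback computes: a unit-scale matmul followed by one elementwise
# multiply with the (M, N) broadcasted product of the two inverse scales.
post_inverse_scale = a_inverse_scale * b_inverse_scale
scaled_after = (a_data @ b_data) * post_inverse_scale

torch.testing.assert_close(scaled_inside, scaled_after)
```

This is why the patch only needs to swap unit scalars in for `a_inverse_scale`/`b_inverse_scale`, stash their product in `post_inverse_scale`, and apply it (along with any deferred fp32 bias) after the tensorwise cuBLAS gemm returns: the result matches the slow rowwise kernel up to floating-point rounding.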