[float8] Allow specifying arbitrary dtype for each tensor
ghstack-source-id: aa9f551c8d274f349c4298932fc95c88040abb09
Pull Request resolved: #1326
lw committed Nov 22, 2024
1 parent 8e36b11 commit a749a5f
Showing 11 changed files with 208 additions and 111 deletions.
test/float8/test_base.py (4 changes: 2 additions & 2 deletions)
@@ -25,6 +25,8 @@

from torchao.float8.config import (
CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
recipe_name_to_linear_config,
@@ -51,8 +53,6 @@
)
from torchao.float8.float8_utils import (
compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
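As the two hunks above show, the e4m3_dtype and e5m2_dtype aliases now come from torchao.float8.config rather than torchao.float8.float8_utils. A minimal sketch of the updated import, using only names visible in this diff:

# new home of the fp8 dtype aliases after this commit
from torchao.float8.config import e4m3_dtype, e5m2_dtype

# previous location, removed by this commit:
# from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype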
test/float8/test_compile.py (2 changes: 1 addition & 1 deletion)
@@ -21,6 +21,7 @@
import torch.nn as nn
from torchao.float8.config import (
CastConfig,
+    e4m3_dtype,
Float8LinearConfig,
ScalingType,
Float8LinearRecipeName,
@@ -41,7 +42,6 @@
GemmInputRole,
ScaledMMConfig,
)
-from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config

from torch._dynamo.test_case import TestCase as DynamoTestCase
test/float8/test_dtensor.py (8 changes: 4 additions & 4 deletions)
@@ -27,8 +27,8 @@
from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

-from torchao.float8.config import CastConfig, ScalingType
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -40,7 +40,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
@@ -197,7 +197,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()
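The autograd helper used for the backward cast is renamed from NoopFwToFloat8E5M2BwDynamic to NoopFwToFloat8BwDynamic and now takes the target fp8 dtype as an explicit third argument instead of hard-coding e5m2. A hedged sketch of the new call shape, modeled on the test above; applying it to a plain tensor outside the DTensor test, and importing LinearMMConfig from torchao.float8.float8_tensor, are assumptions here:

import torch

from torchao.float8.config import e5m2_dtype
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import LinearMMConfig

out = torch.randn(16, 16, requires_grad=True)
# forward is a no-op; the gradient flowing back through `out` is cast
# dynamically to the dtype passed as the third argument (e5m2 here)
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), e5m2_dtype)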
torchao/float8/config.py (54 changes: 44 additions & 10 deletions)
@@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
+    dtype: torch.dtype = torch.uint8  # dummy dtype to satisfy typing

def short_str(self):
return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,10 @@ def __post_init__(self):
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"
+        if self.scaling_type is not ScalingType.DISABLED:
+            assert (
+                self.dtype.is_floating_point and self.dtype.itemsize == 1
+            ), "must specify a 8-bit floating-point dtype"


@dataclass(frozen=True)
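CastConfig now carries the target fp8 dtype itself; the torch.uint8 default is only a placeholder to satisfy typing, and the new __post_init__ check rejects it whenever the cast is not disabled. A minimal sketch of what the validation accepts, based only on the hunks above:

from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype

# an 8-bit floating-point dtype passes the new check
cc_input = CastConfig(dtype=e4m3_dtype)

# a disabled cast may keep the uint8 placeholder, since the check is skipped
cc_disabled = CastConfig(scaling_type=ScalingType.DISABLED)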
@@ -124,6 +129,12 @@ def __post_init__(self):
self.e5m2_dtype = torch.float8_e5m2fnuz


+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype


@dataclass(frozen=True)
class Float8GemmConfig:
"""
@@ -158,13 +169,13 @@ class Float8LinearConfig:
# 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`.
#
# `input`
-    cast_config_input: CastConfig = CastConfig()
+    cast_config_input: CastConfig = CastConfig(dtype=e4m3_dtype)
cast_config_input_for_grad_weight: Optional[CastConfig] = None
# `weight`
-    cast_config_weight: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig(dtype=e4m3_dtype)
cast_config_weight_for_grad_input: Optional[CastConfig] = None
# `grad_output`
-    cast_config_grad_output: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig(dtype=e5m2_dtype)
cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None

#
@@ -279,6 +290,15 @@ def __post_init__(self):
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name in [
+            (cc_i, cc_i_gw, "input"),
+            (cc_w, cc_w_gi, "weight"),
+            (cc_go, cc_go_gw, "grad_output"),
+        ]:
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
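Float8LinearConfig now spells out a dtype for each operand's cast config (e4m3 for input and weight, e5m2 for grad_output by default), and __post_init__ additionally asserts that each operand is cast to a single dtype across the two matmuls it feeds. A hedged sketch of a user-built config exercising the new fields; the dtype choices are illustrative, and leaving the *_for_grad_* configs unset is assumed to fall back to the base configs as the docstring above describes:

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    e4m3_dtype,
    e5m2_dtype,
)

config = Float8LinearConfig(
    cast_config_input=CastConfig(dtype=e4m3_dtype),
    cast_config_weight=CastConfig(dtype=e4m3_dtype),
    cast_config_grad_output=CastConfig(dtype=e5m2_dtype),
)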

@@ -315,9 +335,15 @@ def recipe_name_to_linear_config(

elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e5m2_dtype
+        )

return Float8LinearConfig(
cast_config_input=cc_i,
@@ -339,12 +365,20 @@
# which is more amenable to fast kernels

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
+        cc_w_gi = CastConfig(
+            scaling_granularity=ScalingGranularity.TENSORWISE, dtype=e4m3_dtype
+        )

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
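The recipe helpers now pass an explicit dtype for every cast as well, so a recipe-derived config carries the same per-tensor dtype information. A small sketch of inspecting that, using only names visible in this diff:

from torchao.float8.config import (
    Float8LinearRecipeName,
    e4m3_dtype,
    recipe_name_to_linear_config,
)

# per the ALL_AXISWISE hunk above, all three casts use axiswise scaling,
# with e4m3 for input/weight and e5m2 for grad_output
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert config.cast_config_input.dtype == e4m3_dtype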
(diffs for the remaining 7 changed files were not loaded on this page)
