[float8] Allow specifying arbitrary dtype for each tensor
ghstack-source-id: 5cc007cd363033e95314a20b70f7ded4144b7842
ghstack-comment-id: 2517857809
Pull Request resolved: #1378
lw committed Dec 4, 2024
1 parent da1560b commit aa0587f
Showing 11 changed files with 220 additions and 132 deletions.
test/float8/test_base.py (4 changes: 2 additions & 2 deletions)
@@ -30,6 +30,8 @@
     Float8LinearRecipeName,
     ScalingGranularity,
     ScalingType,
+    e4m3_dtype,
+    e5m2_dtype,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
@@ -53,8 +55,6 @@
 from torchao.float8.float8_utils import (
     FP8_TYPES,
     compute_error,
-    e4m3_dtype,
-    e5m2_dtype,
     fp8_tensor_statistics,
     tensor_to_scale,
 )
test/float8/test_compile.py (2 changes: 1 addition & 1 deletion)
@@ -30,6 +30,7 @@
     Float8LinearConfig,
     Float8LinearRecipeName,
     ScalingType,
+    e4m3_dtype,
     recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
@@ -47,7 +48,6 @@
     LinearMMConfig,
     ScaledMMConfig,
 )
-from torchao.float8.float8_utils import e4m3_dtype
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
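Taken together, the two test diffs above simply move the e4m3_dtype / e5m2_dtype constants out of torchao.float8.float8_utils and into the config module (see the config.py diff below). A minimal import sketch, assuming nothing beyond what this commit shows:

# Minimal sketch: the fp8 dtype constants now come from torchao.float8.config.
# On ROCm MI300-class GPUs they resolve to the fnuz variants (see Float8TypeConfig below);
# elsewhere they are torch.float8_e4m3fn / torch.float8_e5m2.
from torchao.float8.config import e4m3_dtype, e5m2_dtype

print(e4m3_dtype)  # torch.float8_e4m3fn on most hardware
print(e5m2_dtype)  # torch.float8_e5m2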
test/float8/test_dtensor.py (8 changes: 4 additions & 4 deletions)
@@ -31,9 +31,9 @@
 from tqdm import tqdm

 from torchao.float8 import Float8LinearConfig
-from torchao.float8.config import CastConfig, ScalingType
+from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype
 from torchao.float8.float8_linear_utils import convert_to_float8_training
-from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -45,7 +45,7 @@
     Float8RowwiseParallel,
     PrepareFloat8ModuleInput,
 )
-from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
+from torchao.float8.float8_utils import tensor_to_scale
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 from torchao.testing.float8.dtensor_utils import ToyModel

@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
+    out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
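The rename from NoopFwToFloat8E5M2BwDynamic to NoopFwToFloat8BwDynamic reflects that the backward cast dtype is no longer hard-wired to e5m2; the test now passes it explicitly as a third argument to apply(). A minimal sketch of that call pattern, assuming only the interface visible in the diff above and a PyTorch build where fp8 casts are supported on the current device:

import torch

from torchao.float8.config import e5m2_dtype
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.randn(16, 32, requires_grad=True)

# Forward is a no-op; in backward the incoming gradient is dynamically cast to
# the fp8 dtype passed as the third argument (previously always e5m2).
out = NoopFwToFloat8BwDynamic.apply(x, LinearMMConfig(), e5m2_dtype)
out.sum().backward()  # assumes fp8 casting works on this device/build
print(type(x.grad))   # expected: a Float8Tensor wrapping e5m2 data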
torchao/float8/config.py (80 changes: 54 additions & 26 deletions)
@@ -53,6 +53,35 @@ def short_str(self):
         return "axs"


+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip and torch.cuda.is_available():
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
+# User defined type for using the individual F8 type based on config
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -62,9 +91,11 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None

     def short_str(self):
-        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
+        dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.dtype]
+        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"

     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
@@ -75,6 +106,9 @@ def __post_init__(self):
             assert (
                 self.scaling_type is ScalingType.DYNAMIC
             ), "only dynamic scaling type is supported for axiswise scaling granularity"
+        assert self.dtype is None or (
+            self.dtype.is_floating_point and self.dtype.itemsize == 1
+        ), "must specify a 8-bit floating-point dtype"


 @dataclass(frozen=True)
@@ -101,29 +135,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip and torch.cuda.is_available():
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -276,6 +287,20 @@ def __post_init__(self):
                 is_disabled_1 == is_disabled_2
             ), f"incompatible operand precision for {gemm_name}"

+        for cc1, cc2, operand_name, default_dtype in [
+            (cc_i, cc_i_gw, "input", e4m3_dtype),
+            (cc_w, cc_w_gi, "weight", e4m3_dtype),
+            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
+        ]:
+            # Override the dataclass being frozen
+            if cc1.dtype is None:
+                object.__setattr__(cc1, "dtype", default_dtype)
+            if cc2.dtype is None:
+                object.__setattr__(cc2, "dtype", default_dtype)
+            assert (
+                cc1.dtype == cc2.dtype
+            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+
         if self.use_fp8_all_gather_only:
             assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

@@ -334,18 +359,21 @@ def recipe_name_to_linear_config(
         # * `input`, `weight` and `grad_output` now only need to be scaled
         #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
+        # * the e4m3 dtype is used across the board, including for gradients

         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

         # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+        )
         cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

         # grad_weight_hp = input_t_hp @ grad_output_hp
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
-        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
+        cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

         return Float8LinearConfig(
             cast_config_input=cc_i,
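For readers of this diff, a minimal usage sketch of the new knob. CastConfig(dtype=...), the e4m3_dtype / e5m2_dtype constants, cast_config_input=, and convert_to_float8_training all appear in the changes above; the cast_config_weight= and cast_config_grad_output= keyword names and the config= argument are assumptions about the surrounding Float8LinearConfig API that this commit does not show.

import torch

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    e4m3_dtype,
    e5m2_dtype,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training

# Pick an fp8 dtype per tensor explicitly; with dtype=None, __post_init__ now
# fills in e4m3 for input/weight and e5m2 for grad_output (see the loop above).
config = Float8LinearConfig(
    cast_config_input=CastConfig(dtype=e4m3_dtype),
    cast_config_weight=CastConfig(dtype=e4m3_dtype),        # assumed kwarg name
    cast_config_grad_output=CastConfig(dtype=e5m2_dtype),   # assumed kwarg name
)

m = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
convert_to_float8_training(m, config=config)  # config= kwarg is an assumption

This mirrors the recipe change above, where grad_output in the axiswise recipe now uses e4m3 explicitly instead of the previously hard-wired e5m2.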