[float8] Allow specifying arbitrary dtype for each tensor #1378

Merged
merged 6 commits into from
Dec 4, 2024
test/float8/test_base.py (6 changes: 3 additions & 3 deletions)
@@ -30,6 +30,8 @@
Float8LinearRecipeName,
ScalingGranularity,
ScalingType,
e4m3_dtype,
e5m2_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
@@ -53,8 +55,6 @@
from torchao.float8.float8_utils import (
FP8_TYPES,
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
tensor_to_scale,
)
@@ -546,7 +546,7 @@ def test_repr(self):
config=config,
)
s = m.__repr__()
assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s
assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s

@unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
def test_inference_mode(self):
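The updated assertion reflects that each CastConfig's short_str() now carries its target dtype. A minimal sketch of that behavior, not part of the diff, assuming the module-level e4m3_dtype exported by torchao.float8.config as added further down:

# Hedged sketch: a CastConfig with an explicit target dtype reports it in its short string.
from torchao.float8.config import CastConfig, ScalingGranularity, ScalingType, e4m3_dtype

cc = CastConfig(
    scaling_type=ScalingType.DYNAMIC,
    scaling_granularity=ScalingGranularity.TENSORWISE,
    target_dtype=e4m3_dtype,
)
print(cc.short_str())  # "dyn_ten_e4m3", matching the "i:dyn_ten_e4m3" fragment asserted above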
test/float8/test_compile.py (2 changes: 1 addition & 1 deletion)
@@ -30,6 +30,7 @@
Float8LinearConfig,
Float8LinearRecipeName,
ScalingType,
e4m3_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
@@ -47,7 +48,6 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config


test/float8/test_dtensor.py (8 changes: 4 additions & 4 deletions)
@@ -31,9 +31,9 @@
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, ScalingType
from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -45,7 +45,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.float8_utils import tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel

@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()
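The renamed NoopFwToFloat8BwDynamic no longer hard-codes e5m2 for the backward cast; the target dtype is an explicit argument, as the updated call site shows. A hedged, standalone sketch of the same call pattern (illustrative only, not taken from the test; assumes a GPU with float8 support):

# Forward is a no-op; backward dynamically casts the incoming gradient to the
# dtype passed as the third argument (e5m2 here).
import torch
from torchao.float8.config import e5m2_dtype
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.randn(16, 16, device="cuda", requires_grad=True)
out = NoopFwToFloat8BwDynamic.apply(x, LinearMMConfig(), e5m2_dtype)
out.sum().backward()  # x.grad is produced via a dynamic cast to float8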
torchao/float8/config.py (82 changes: 56 additions & 26 deletions)
@@ -53,6 +53,35 @@ def short_str(self):
return "axs"


@dataclass
class Float8TypeConfig:
"""
Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.

Currently, ROCm only supports fnuz variants.
"""

# The preferred e4m3 type.
e4m3_dtype = torch.float8_e4m3fn

# The preferred e5m2 type.
e5m2_dtype = torch.float8_e5m2

def __post_init__(self):
if torch.version.hip and torch.cuda.is_available():
prop = torch.cuda.get_device_properties(0)
MI300_ARCH = ("gfx940", "gfx941", "gfx942")
if prop.gcnArchName.split(":")[0] in MI300_ARCH:
self.e4m3_dtype = torch.float8_e4m3fnuz
self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
type_config = Float8TypeConfig()
e4m3_dtype = type_config.e4m3_dtype
e5m2_dtype = type_config.e5m2_dtype


@dataclass(frozen=True)
class CastConfig:
"""
@@ -62,9 +91,11 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
target_dtype: Optional[torch.dtype] = None

def short_str(self):
return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype]
return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"

def __post_init__(self):
if self.scaling_type is ScalingType.STATIC:
@@ -75,6 +106,9 @@ def __post_init__(self):
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"
assert self.target_dtype is None or (
self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
), "must specify a 8-bit floating-point dtype"


@dataclass(frozen=True)
@@ -101,29 +135,6 @@ def __post_init__(self):
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


@dataclass
class Float8TypeConfig:
"""
Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.

Currently, ROCm only supports fnuz variants.
"""

# The preferred e4m3 type.
e4m3_dtype = torch.float8_e4m3fn

# The preferred e5m2 type.
e5m2_dtype = torch.float8_e5m2

def __post_init__(self):
if torch.version.hip and torch.cuda.is_available():
prop = torch.cuda.get_device_properties(0)
MI300_ARCH = ("gfx940", "gfx941", "gfx942")
if prop.gcnArchName.split(":")[0] in MI300_ARCH:
self.e4m3_dtype = torch.float8_e4m3fnuz
self.e5m2_dtype = torch.float8_e5m2fnuz


@dataclass(frozen=True)
class Float8GemmConfig:
"""
@@ -276,6 +287,20 @@ def __post_init__(self):
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

for cc1, cc2, operand_name, default_dtype in [
(cc_i, cc_i_gw, "input", e4m3_dtype),
(cc_w, cc_w_gi, "weight", e4m3_dtype),
(cc_go, cc_go_gw, "grad_output", e5m2_dtype),
]:
# Override the dataclass being frozen
if cc1.target_dtype is None:
object.__setattr__(cc1, "target_dtype", default_dtype)
if cc2.target_dtype is None:
object.__setattr__(cc2, "target_dtype", default_dtype)
assert (
cc1.target_dtype == cc2.target_dtype
), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

@@ -334,18 +359,23 @@ def recipe_name_to_linear_config(
# * `input`, `weight` and `grad_output` now only need to be scaled
# axiswise across a single dim compared to vanilla all-axiswise,
# which is more amenable to fast kernels
# * the e4m3 dtype is used across the board, including for gradients

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(
scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
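Taken together, the new target_dtype field and the __post_init__ defaulting preserve the previous behavior when no dtype is specified: input and weight default to the e4m3 flavor, grad_output to the e5m2 flavor, with the fnuz variants selected automatically on MI300-class ROCm. A hedged illustration, not part of the diff:

# Illustrative only: a default Float8LinearConfig leaves target_dtype unset and
# __post_init__ fills in the per-operand defaults shown above.
from torchao.float8.config import Float8LinearConfig, e4m3_dtype, e5m2_dtype

config = Float8LinearConfig()
assert config.cast_config_input.target_dtype == e4m3_dtype
assert config.cast_config_weight.target_dtype == e4m3_dtype
assert config.cast_config_grad_output.target_dtype == e5m2_dtype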