[float8] Allow specifying arbitrary dtype for each tensor
ghstack-source-id: d8300e2a07c087f3cd51b03e0e21125a83a29489
ghstack-comment-id: 2517857809
Pull Request resolved: #1378
lw committed Dec 4, 2024
1 parent 11b2a2f commit e8c5769
Showing 11 changed files with 194 additions and 108 deletions.
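At a high level, this commit replaces the hardcoded e4m3 (input, weight) and e5m2 (grad_output) choices with a per-cast `dtype` field on `CastConfig`, so each tensor can be cast to its own 8-bit float dtype. A minimal usage sketch, assuming the import paths and keyword names that appear in the hunks below (defaults not shown in this diff are assumptions):

```python
import torch

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, e4m3_dtype
from torchao.float8.float8_linear_utils import convert_to_float8_training

# Cast grad_output to e4m3 instead of the default e5m2. Both matmuls that
# consume grad_output must agree on its dtype (enforced in
# Float8LinearConfig.__post_init__ below), so the per-gemm variant is set too.
config = Float8LinearConfig(
    cast_config_grad_output=CastConfig(dtype=e4m3_dtype),
    cast_config_grad_output_for_grad_weight=CastConfig(dtype=e4m3_dtype),
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
convert_to_float8_training(model, config=config)  # swaps nn.Linear for Float8Linear
```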
4 changes: 2 additions & 2 deletions test/float8/test_base.py
@@ -26,6 +26,8 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
e5m2_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingGranularity,
@@ -53,8 +55,6 @@
from torchao.float8.float8_utils import (
FP8_TYPES,
compute_error,
e4m3_dtype,
e5m2_dtype,
fp8_tensor_statistics,
tensor_to_scale,
)
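The two test hunks above mirror the import move: `e4m3_dtype` and `e5m2_dtype` now live in `torchao.float8.config` instead of `torchao.float8.float8_utils`. Downstream code would update its imports accordingly (sketch):

```python
# Before this commit:
# from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype
# After this commit:
from torchao.float8.config import e4m3_dtype, e5m2_dtype
```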
2 changes: 1 addition & 1 deletion test/float8/test_compile.py
@@ -27,6 +27,7 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingType,
@@ -47,7 +48,6 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config


8 changes: 4 additions & 4 deletions test/float8/test_dtensor.py
@@ -31,9 +31,9 @@
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import CastConfig, ScalingType
from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -45,7 +45,7 @@
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.float8_utils import tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel

@@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig())
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()
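The rename in this hunk drops the hardcoded E5M2 from the autograd function's name: `NoopFwToFloat8BwDynamic` now receives the target float8 dtype as an explicit argument instead of always producing e5m2 gradients. A sketch of the updated call, assuming the default e5m2 behavior is wanted:

```python
import torch

from torchao.float8.config import e5m2_dtype
from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
from torchao.float8.float8_tensor import LinearMMConfig

out = torch.randn(16, 16, requires_grad=True)
# Forward is a no-op; backward dynamically scales and casts the incoming
# gradient to the dtype passed here (previously fixed to e5m2).
out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), e5m2_dtype)
```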
30 changes: 28 additions & 2 deletions torchao/float8/config.py
@@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None

def short_str(self):
return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}"
@@ -75,6 +76,9 @@ def __post_init__(self):
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"
assert self.dtype is None or (
self.dtype.is_floating_point and self.dtype.itemsize == 1
), "must specify a 8-bit floating-point dtype"


@dataclass(frozen=True)
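The new `dtype` field is validated in `CastConfig.__post_init__`: when set, it must be an 8-bit floating-point dtype; `None` is allowed and is filled in later by `Float8LinearConfig`. A quick sketch of the accepted and rejected cases:

```python
import torch

from torchao.float8.config import CastConfig, e4m3_dtype

CastConfig(dtype=e4m3_dtype)  # ok: 8-bit floating-point dtype
CastConfig(dtype=None)        # ok: defaulted later by Float8LinearConfig
try:
    CastConfig(dtype=torch.bfloat16)  # not 8-bit, expected to fail
except AssertionError as e:
    print(e)  # "must specify a 8-bit floating-point dtype"
```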
@@ -124,6 +128,12 @@ def __post_init__(self):
self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
type_config = Float8TypeConfig()
e4m3_dtype = type_config.e4m3_dtype
e5m2_dtype = type_config.e5m2_dtype


@dataclass(frozen=True)
class Float8GemmConfig:
"""
@@ -276,6 +286,20 @@ def __post_init__(self):
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

for cc1, cc2, operand_name, default_dtype in [
(cc_i, cc_i_gw, "input", e4m3_dtype),
(cc_w, cc_w_gi, "weight", e4m3_dtype),
(cc_go, cc_go_gw, "grad_output", e5m2_dtype),
]:
# Override the dataclass being frozen
if cc1.dtype is None:
object.__setattr__(cc1, "dtype", default_dtype)
if cc2.dtype is None:
object.__setattr__(cc2, "dtype", default_dtype)
assert (
cc1.dtype == cc2.dtype
), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
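The loop above fills in any cast whose `dtype` was left as `None`: input and weight default to e4m3, grad_output to e5m2, and the two gemms that consume a given operand must end up with the same dtype. A sketch of the resulting defaults, assuming `Float8LinearConfig` is constructed with all fields at their defaults:

```python
from torchao.float8.config import Float8LinearConfig, e4m3_dtype, e5m2_dtype

cfg = Float8LinearConfig()  # no dtype specified for any cast
assert cfg.cast_config_input.dtype == e4m3_dtype
assert cfg.cast_config_weight.dtype == e4m3_dtype
assert cfg.cast_config_grad_output.dtype == e5m2_dtype
```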

@@ -340,12 +364,14 @@ def recipe_name_to_linear_config(
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED, dtype=e4m3_dtype)

return Float8LinearConfig(
cast_config_input=cc_i,
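Note the knock-on change in the recipe above: because grad_output feeds two gemms whose casts must now share a dtype, even the disabled `cc_go_gw` cast pins `dtype=e4m3_dtype` to match the active `cc_go`. A sketch of the check this satisfies, assuming both per-gemm fields are passed explicitly:

```python
from torchao.float8.config import CastConfig, Float8LinearConfig, e4m3_dtype, e5m2_dtype

# Mismatched grad_output dtypes across the two consuming gemms are expected
# to trip the consistency assert added in __post_init__ above.
try:
    Float8LinearConfig(
        cast_config_grad_output=CastConfig(dtype=e4m3_dtype),
        cast_config_grad_output_for_grad_weight=CastConfig(dtype=e5m2_dtype),
    )
except AssertionError as e:
    print(e)  # "grad_output must be cast to the same dtype in both matmuls it's used in"
```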
83 changes: 44 additions & 39 deletions torchao/float8/float8_linear.py
@@ -15,9 +15,9 @@
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
NoopFwToFloat8E5M2BwStatic,
NoopFwToFloat8BwDelayed,
NoopFwToFloat8BwDynamic,
NoopFwToFloat8BwStatic,
_maybe_initialize_amaxes_scales_for_float8_cast,
get_maybe_axiswise_dim,
hp_tensor_to_float8_delayed,
@@ -32,8 +32,6 @@
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)
@@ -136,7 +134,7 @@ def forward(
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
input_hp,
e4m3_dtype,
c.cast_config_input.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -150,7 +148,7 @@
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -186,7 +184,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -204,7 +202,7 @@
# the entire tensor.
weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
weight_hp_t,
e4m3_dtype,
c.cast_config_weight_for_grad_input.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -236,7 +234,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
e5m2_dtype,
c.cast_config_grad_output_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -250,7 +248,7 @@
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
input_hp_reshaped,
e4m3_dtype,
c.cast_config_input_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -347,11 +345,9 @@ def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
device = self.weight.device
# TODO(future PR): dtype values below don't have the other float8
# flavors, fix it
default_input = torch.finfo(torch.float8_e4m3fn).max
default_weight = torch.finfo(torch.float8_e4m3fn).max
default_grad_output = torch.finfo(torch.float8_e5m2).max
default_input = torch.finfo(self.config.cast_config_input.dtype).max
default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

# Note: for now, create all the buffers if any are needed, to postpone
# the work to make the scale and amax syncing and history calculation
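The delayed-scaling buffer defaults above are now derived from the configured dtypes rather than hardcoded float8 flavors, which also retires the old TODO. With the default dtypes the init values work out to:

```python
import torch

# Default cast dtypes -> buffer init values used in create_buffers:
torch.finfo(torch.float8_e4m3fn).max  # 448.0    (input and weight)
torch.finfo(torch.float8_e5m2).max    # 57344.0  (grad_output)
```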
@@ -438,29 +434,32 @@ def cast_input_to_float8(
self.fp8_amax_history_input,
self.fp8_scale_input,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_input.dtype,
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
elif self.scaling_type_input is ScalingType.DYNAMIC:
input_fp8 = hp_tensor_to_float8_dynamic(
input,
e4m3_dtype,
self.config.cast_config_input.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is ScalingType.STATIC
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
input,
self.fp8_static_scale_input,
self.config.cast_config_input.dtype,
self.linear_mm_config,
)

return input_fp8
@@ -476,14 +475,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
return tensor_to_scale(weight, e4m3_dtype)
return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
return self.fp8_static_scale_weight
@@ -499,7 +498,7 @@ def cast_weight_to_float8_t(
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.config.cast_config_weight.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
@@ -514,23 +513,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
output = NoopFwToFloat8E5M2BwDelayed.apply(
output = NoopFwToFloat8BwDelayed.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
scale_fn_name,
self.is_amax_initialized,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
output = NoopFwToFloat8BwDynamic.apply(
output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output = NoopFwToFloat8BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
self.config.cast_config_grad_output.dtype,
)
return output

@@ -547,19 +552,16 @@ def float8_post_forward(self):
return

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = (
self.config.cast_config_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_input_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_weight_for_grad_input.scaling_granularity
is ScalingGranularity.AXISWISE
or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
is ScalingGranularity.AXISWISE
has_any_axiswise_scaling = any(
cc.scaling_granularity is ScalingGranularity.AXISWISE
for cc in [
self.config.cast_config_input,
self.config.cast_config_weight,
self.config.cast_config_grad_output,
self.config.cast_config_input_for_grad_weight,
self.config.cast_config_weight_for_grad_input,
self.config.cast_config_grad_output_for_grad_weight,
]
)

if not has_any_axiswise_scaling:
@@ -682,6 +684,7 @@ def from_float(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)
elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -692,6 +695,7 @@
new_mod.fp8_amax_history_weight,
new_mod.fp8_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
new_mod.is_amax_initialized,
)
)
@@ -702,6 +706,7 @@
new_mod.weight,
new_mod.fp8_static_scale_weight,
new_mod.linear_mm_config,
new_mod.config.cast_config_weight.dtype,
)
)

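Across float8_linear.py the pattern is uniform: every cast site reads its target dtype from the corresponding `CastConfig`, and the weight wrappers created in `from_float` now carry that dtype as well. A sketch of the dynamic-scaling activation cast using the helpers imported at the top of the file (a CPU-sized toy tensor is assumed here):

```python
import torch

from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

config = Float8LinearConfig()
input_hp = torch.randn(16, 16, dtype=torch.bfloat16)
input_fp8 = hp_tensor_to_float8_dynamic(
    input_hp,
    config.cast_config_input.dtype,  # e4m3 by default, now user-overridable
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.INPUT,
)
```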