This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add a Float8LinearInference module to support static, dynamic, and wo quant #287

Closed
wants to merge 9 commits
Changes from 5 commits
1 change: 1 addition & 0 deletions .github/workflows/python-app.yml
@@ -25,6 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install -e .
pip install -e .'[dev]'
pip install -e .'[test]'
1 change: 0 additions & 1 deletion benchmarks/profile_linear_float8.py
@@ -21,7 +21,6 @@
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
1 change: 0 additions & 1 deletion benchmarks/utils.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import collections
import json
import re


7 changes: 6 additions & 1 deletion float8_experimental/__init__.py
@@ -5,6 +5,11 @@
# LICENSE file in the root directory of this source tree.
# Lets define a few top level things here
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig

# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals

add_safe_globals([Float8Tensor, ScaledMMConfig])

__all__ = ["Float8Tensor", "Float8Linear"]
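The registration above is what enables the safer checkpoint-loading path. A minimal sketch of the intended flow, assuming a checkpoint that contains Float8Tensor weights (the file path is illustrative):

```python
# Importing the package runs add_safe_globals, so a checkpoint containing
# Float8Tensor / ScaledMMConfig can be unpickled with weights_only=True.
import torch
import float8_experimental  # noqa: F401  (import performs the registration)

state_dict = torch.load("fp8_checkpoint.pt", weights_only=True)  # illustrative path
```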
86 changes: 55 additions & 31 deletions float8_experimental/float8_dynamic_linear.py
@@ -62,8 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear):
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, x):
x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
@@ -73,47 +73,71 @@ def forward(self, x):
return y

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
def create_meta_class(
cls, in_features: int, out_features: int
) -> "Float8DynamicLinear":
with torch.device("meta"):
return cls(in_features=in_features, out_features=out_features, bias=False)

def set_mm_configs(self, emulate: bool) -> "Float8DynamicLinear":
self.forward_config = ScaledMMConfig(
emulate, not emulate, pad_inner_dim=config.pad_inner_dim
)
self.backward_config = ScaledMMConfig(
emulate, False, pad_inner_dim=config.pad_inner_dim
)
return self

def set_weight_and_bias(
self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter]
) -> "Float8DynamicLinear":
if config.enable_fsdp_fp8_all_gather:
self.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(weight, self.forward_config)
)
else:
self.weight = weight
self.bias = bias
return self

@classmethod
def from_float(
cls,
mod,
emulate: bool = False,
) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
"""
with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
"out_features": mod.out_features,
"bias": False,
}
new_mod = cls(**super_kwargs)

new_mod.forward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=not bool(emulate),
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
new_mod.backward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=False,
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
return (
cls.create_meta_class(mod.in_features, mod.out_features)
.set_mm_configs(emulate)
.set_weight_and_bias(mod.weight, mod.bias)
)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def cast_to_float8_e4m3fn(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
inpt_tensor: torch.Tensor,
mm_config: ScaledMMConfig,
reduce_amax: bool = False,
) -> Float8Tensor:
"""Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.

Args:
inpt_tensor: The input tensor to be cast.
mm_config: Configuration settings for the matrix multiplication
reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.

Returns:
Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.

Note:
If the input tensor is already in Float8 format, it is returned as is without re-casting.
"""
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
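For review context, a hedged sketch of how the refactored `from_float` is exercised. `emulate=True` follows the docstring's float32 emulation path so no fp8-capable GPU is required; the device and dtype choices are assumptions:

```python
import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

# from_float now chains the three new helpers:
#   create_meta_class(...)   -> module allocated on the meta device, bias=False
#   .set_mm_configs(emulate) -> forward/backward ScaledMMConfig
#   .set_weight_and_bias(mod.weight, mod.bias)
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=True)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = fp8_lin(x)  # input is cast to e4m3 dynamically inside forward
```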
57 changes: 40 additions & 17 deletions float8_experimental/float8_linear_utils.py
@@ -6,7 +6,7 @@
import copy
import logging
from enum import auto, Enum
from typing import Callable, List, Optional, Type
from typing import Callable, List, Optional, Union

import torch
import torch.distributed as dist
@@ -97,45 +97,51 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
)


def swap_linear_with_float8_linear(
def swap_linear_layers(
module: nn.Module,
module_cls: Type[nn.Module],
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
"""
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``).
Generic function to swap linear layers in a module with a new type of linear layer.

Note:
If applied to a root-level nn.Linear, the module will not be modified in place
and returned instead

Args:
module (torch.nn.Module): Module to modify.
module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
Linear submodules of these skipped modules will also be skipped.
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
skip_fqn_list: If specified, a list of module FQNs to skip.
linear_layer_filter: If specified, only the linear layers
that pass the filter function will be swapped.
from_float_kwargs: Additional keyword arguments for from_float_func.

Returns:
nn.Module: The modified module with swapped linear layers.
"""
module_names_to_skip = set(skip_fqn_list or [])

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
)
return module_cls.from_float(module, emulate=emulate)
return from_float_func(
module,
)

# Mark all modules to skip as visited
root_module = module
visited_modules = {root_module}

for module_name, module in root_module.named_modules():
if module_name in module_names_to_skip:
visited_modules.add(module)

# Run a post-order traversal to swap linears
def post_order_traversal(
module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
):
@@ -144,14 +150,15 @@ def post_order_traversal(
if child_module not in visited_modules:
visited_modules.add(child_module)
post_order_traversal(child_module, child_module_name, module)

if isinstance(module, nn.Linear) and (
linear_layer_filter is None or linear_layer_filter(module)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
float8linear_module = module_cls.from_float(module, emulate=emulate)
setattr(parent_module, module_name, float8linear_module)
new_linear_module = from_float_func(module)
setattr(parent_module, module_name, new_linear_module)

post_order_traversal(root_module, "", None)
# Without this explicit `del`, this set only gets deleted upon an explicit
@@ -160,6 +167,22 @@ def post_order_traversal(
return root_module


def swap_linear_with_float8_linear(
module: nn.Module,
module_cls: Union[Float8Linear, Float8DynamicLinear],
*,
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
return swap_linear_layers(
module,
lambda m: module_cls.from_float(m, emulate=emulate),
skip_fqn_list=skip_fqn_list,
linear_layer_filter=linear_layer_filter,
)


def get_float8_layers(model: torch.nn.Module):
"""Iterates through the model and returns all the Float8Linear layers.
Args:
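A sketch of the two entry points after this refactor; the toy model and the skipped FQN are illustrative:

```python
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import (
    swap_linear_layers,
    swap_linear_with_float8_linear,
)

# Existing API, now a thin wrapper around swap_linear_layers:
model_a = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 8))
model_a = swap_linear_with_float8_linear(model_a, Float8DynamicLinear, emulate=True)

# Generic form: any nn.Linear -> nn.Module callable can be plugged in, which is
# the hook a Float8LinearInference module could use.
model_b = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 8))
model_b = swap_linear_layers(
    model_b,
    lambda m: Float8DynamicLinear.from_float(m, emulate=True),
    skip_fqn_list=["2"],  # keep the final projection in high precision
)
```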
1 change: 1 addition & 0 deletions float8_experimental/float8_tensor.py
@@ -266,6 +266,7 @@ def to_float8(
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
amax_buffer: a buffer to store the amax value in prior to conversion
mm_config: Defines the configuration for the scaled_mm

Returns:
Float8Tensor: a float8 tensor
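To illustrate the newly documented mm_config argument, a sketch under the assumption that to_float8 is the Float8Tensor classmethod this hunk documents and that tensor_to_scale and e4m3_dtype are importable from float8_experimental.float8_utils, as elsewhere in the repo:

```python
import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

t = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
scale = tensor_to_scale(t, e4m3_dtype)
# Field names follow the ScaledMMConfig construction shown in the diff above.
mm_config = ScaledMMConfig(
    emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False
)
t_fp8 = Float8Tensor.to_float8(t, scale, e4m3_dtype, mm_config=mm_config)
```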
2 changes: 1 addition & 1 deletion float8_experimental/float8_utils.py
@@ -139,7 +139,7 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")


def compute_error(x: torch.Tensor, y: torch.Tensor):
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes the error between two tensors in dB.

For more details see:
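A small usage sketch of compute_error, which reports an SQNR-style error in dB; the perturbation below stands in for an fp8 round-trip:

```python
import torch
from float8_experimental.float8_utils import compute_error

ref = torch.randn(256, 256)
approx = ref + 1e-3 * torch.randn_like(ref)  # stand-in for an fp8 round-trip
sqnr_db = compute_error(ref, approx)  # error in dB; higher means closer to ref
```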