This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add static scaling option
drisspg committed Jun 21, 2024
1 parent d1eae9a commit 5d5a48e
Showing 2 changed files with 41 additions and 4 deletions.
5 changes: 5 additions & 0 deletions float8_experimental/__init__.py
@@ -7,4 +7,9 @@
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_tensor import Float8Tensor
 
+# Needed to load Float8Tensor with weights_only = True
+from torch.serialization import add_safe_globals
+
+add_safe_globals([Float8Tensor])
+
 __all__ = ["Float8Tensor", "Float8Linear"]
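
The add_safe_globals([Float8Tensor]) call above is what allows checkpoints holding Float8Tensor instances to be deserialized with torch.load(..., weights_only=True), whose unpickler only reconstructs allow-listed classes. A minimal sketch of that round trip, assuming the usual fp8_max / amax scale convention and that Float8Tensor.to_float8 accepts a default mm_config (the file name and tensor shape are illustrative, not from this commit):

import torch
import float8_experimental  # importing the package runs add_safe_globals([Float8Tensor])
from float8_experimental.float8_tensor import Float8Tensor

x = torch.randn(16, 16)
# Scale chosen so the largest activation maps onto the fp8 e4m3fn max value.
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

torch.save({"x_fp8": x_fp8}, "fp8_checkpoint.pt")

# Without the add_safe_globals registration in __init__.py, the weights_only
# unpickler would refuse to rebuild the Float8Tensor class here.
loaded = torch.load("fp8_checkpoint.pt", weights_only=True)
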
40 changes: 36 additions & 4 deletions float8_experimental/float8_dynamic_linear.py
@@ -62,8 +62,12 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
+        self.activation_scale: Optional[torch.Tensor] = None
+
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(
+            x, self.forward_config, activation_scale=self.activation_scale
+        )
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
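
The new activation_scale attribute is meant to hold a scale computed ahead of time, for example from calibration activations, so that forward no longer recomputes an amax for every incoming batch. A hedged sketch of how such a static scale could be derived, assuming the same fp8_max / amax convention as tensor_to_scale (the helper name and calibration loop below are illustrative, not part of this commit):

from typing import Iterable

import torch


def compute_static_activation_scale(
    calibration_batches: Iterable[torch.Tensor],
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Track the largest absolute activation value seen during calibration.
    amax = torch.tensor(0.0)
    for batch in calibration_batches:
        amax = torch.maximum(amax, batch.abs().max().float())
    # Avoid dividing by zero for a degenerate all-zero calibration set.
    amax = torch.clamp(amax, min=1e-12)
    # Assumed tensor_to_scale convention: the scale maps amax onto the fp8 max.
    return torch.finfo(float8_dtype).max / amax

The resulting tensor can then be assigned to module.activation_scale, or passed through from_float as shown in the next hunk, so that cast_to_float8_e4m3fn skips the per-batch amax computation at inference time.
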
@@ -86,7 +90,11 @@ def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+        cls,
+        mod,
+        emulate: bool = False,
+        static_quantize_weight: bool = False,
+        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -96,6 +104,8 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             static_quantize_weight (bool): whether to quantize the weight statically, this is useful
                 for inference where weights are not updated.
+            activation_scale (torch.Tensor): The scale of the input to this linear module, used
+                for inference when a statically known scale is available.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -116,16 +126,38 @@ if static_quantize_weight:
         if static_quantize_weight:
             new_mod.static_quantize_weight()
 
+        new_mod.activation_scale = activation_scale
         new_mod.bias = mod.bias
         return new_mod
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    reduce_amax: bool = False,
+    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
+    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+
+    Args:
+        inpt_tensor: The input tensor to be cast.
+        mm_config: Configuration settings for the matrix multiplication
+        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+        activation_scale: Optional tensor specifying the scale for activation. Default is None.
+
+    Returns:
+        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+    Note:
+        If the input tensor is already in Float8 format, it is returned as is without re-casting.
+    """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = (
+        activation_scale
+        if activation_scale is not None
+        else tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    )
     return Float8Tensor.to_float8(
         inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
     )
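
Taken together, the new keyword enables a fully static inference path: the weight is quantized once via static_quantize_weight and the activation scale is fixed up front, so the activation_scale is not None branch above is taken on every forward call. A minimal usage sketch under those assumptions (layer sizes, the calibration tensor, and the CUDA device are illustrative; hardware without fp8 support would need emulate=True):

import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# A regular linear layer to be converted for fp8 inference.
linear = torch.nn.Linear(1024, 1024, bias=True).cuda()

# Static activation scale derived offline from representative activations,
# using the assumed fp8_max / amax convention.
calibration_activations = torch.randn(4096, 1024, device="cuda")
activation_scale = (
    torch.finfo(torch.float8_e4m3fn).max / calibration_activations.abs().max()
)

fp8_linear = Float8DynamicLinear.from_float(
    linear,
    static_quantize_weight=True,        # quantize the weight once, up front
    activation_scale=activation_scale,  # reuse this fixed scale in forward()
)

with torch.no_grad():
    y = fp8_linear(torch.randn(32, 1024, device="cuda"))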