Skip to content

Commit

Permalink
Improve FSDP support for low-bit optimizers (pytorch#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst authored Jul 26, 2024
1 parent 8442f03 commit f8472f1
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 29 deletions.
9 changes: 9 additions & 0 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

base_param = base_optim.param_groups[0]["params"][0]
base_exp_avg = base_optim.state[base_param]["exp_avg"]

fsdp_param = fsdp_optim.param_groups[0]["params"][0]
fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()

self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
Expand Down
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = torch.empty_like(p)
out._local_tensor = self._subclass_zeros(
out._local_tensor,
signed,
self.block_size,
out = DTensor.from_local(
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
)
else:
out = self._subclass_zeros(p, signed, self.block_size)
Expand Down
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
def _new_buffer(self, p: Tensor, signed: bool):
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
if isinstance(p, DTensor):
out = torch.empty_like(p)
out._local_tensor = self._subclass_zeros(
out._local_tensor,
signed,
self.block_size,
out = DTensor.from_local(
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
device_mesh=p.device_mesh,
placements=p.placements,
run_check=False,
)
else:
out = self._subclass_zeros(p, signed, self.block_size)
Expand Down
57 changes: 49 additions & 8 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@


aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional

# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
# NOTE: power-1 is linear
Expand All @@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
"""Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507
Args
codes: quantized and packed 4-bit data stored as uint8.
scale: scale data for block-wise quantization.
qmap: lookup table that maps between quantized value (code) and float value.
signed: whether the tensor is signed or unsigned.
shape: shape of original float tensor.
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
"""
assert codes.dtype is torch.uint8
assert codes.ndim == 1 # flattened buffer
assert scale.ndim == 1
self.codes = codes
self.scale = scale
self.qmap = qmap
self.signed = signed
self._shape = shape

@property
def block_size(self):
return self.codes.numel() * 2 // self.scale.numel()
self.block_size = codes.numel() * 2 // scale.numel()

def __tensor_flatten__(self):
return self.tensor_attrs, [self.signed, self._shape]
Expand Down Expand Up @@ -113,9 +126,37 @@ def _(func, *args, **kwargs):
return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, *args, **kwargs):
x, shape = args
if len(shape) > 1 or shape[0] != -1:
raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))

if tuple(x.shape) == tuple(shape):
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)

if len(shape) == 1 and shape[0] == -1:
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))

raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")


# this is needed for DTensor.full_tensor()
@OptimState4bit.implements([
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")

codes = func(x.codes, *args[1:], **kwargs)
scale = func(x.scale, *args[1:], **kwargs)

# adjust the first dim
shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]

# assume tensors from all ranks have the same signedness
return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
51 changes: 44 additions & 7 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@


aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional

QMAP_SIGNED = create_dynamic_map(signed=True)
QMAP_UNSIGNED = create_dynamic_map(signed=False)


# dynamic tree quantization
# https://arxiv.org/pdf/1511.04561
# https://arxiv.org/abs/2110.02861
class OptimState8bit(Tensor):
implements = classmethod(_implements)
tensor_attrs = ["codes", "scale", "qmap"]
Expand All @@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
"""Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861
Args
codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.
scale: scale data for block-wise quantization.
qmap: lookup table that maps between quantized value (code) and float value.
signed: whether the tensor is signed or unsigned.
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
"""
assert codes.dtype is torch.uint8
assert scale.ndim == 1
self.codes = codes
self.scale = scale
self.qmap = qmap
self.signed = signed

@property
def block_size(self):
return self.codes.numel() // self.scale.numel()
self.block_size = codes.numel() // scale.numel()

def __tensor_flatten__(self):
return self.tensor_attrs, [self.signed]
Expand Down Expand Up @@ -97,3 +106,31 @@ def _(func, *args, **kwargs):
def _(func, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, *args, **kwargs):
x, shape = args
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)


# this is needed for DTensor.full_tensor()
@OptimState8bit.implements([
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")

# assume tensors from all ranks have the same signedness
return OptimState8bit(
func(x.codes, *args[1:], **kwargs),
func(x.scale, *args[1:], **kwargs),
x.qmap.clone(),
x.signed,
)
45 changes: 41 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@


aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional

DTYPE = torch.float8_e4m3fn


Expand Down Expand Up @@ -32,13 +35,21 @@ def __new__(cls, codes: Tensor, scale: Tensor):
)

def __init__(self, codes: Tensor, scale: Tensor):
"""Create quantized FP8 optimizer state.
Args
codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor.
scale: scale data for block-wise quantization.
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
"""
assert codes.dtype is DTYPE
assert scale.ndim == 1
self.codes = codes
self.scale = scale

@property
def block_size(self):
return self.codes.numel() // self.scale.numel()
self.block_size = codes.numel() // scale.numel()

def __tensor_flatten__(self):
return self.tensor_attrs, []
Expand Down Expand Up @@ -99,3 +110,29 @@ def _(func, *args, **kwargs):
def _(func, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimStateFp8.implements(aten.view.default)
def _(func, *args, **kwargs):
x, shape = args
return OptimStateFp8(x.codes.view(shape), x.scale)


# this is needed for DTensor.full_tensor()
@OptimStateFp8.implements([
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimStateFp8):
raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")

# assume tensors from all ranks have the same signedness
return OptimStateFp8(
func(x.codes, *args[1:], **kwargs),
func(x.scale, *args[1:], **kwargs),
)

0 comments on commit f8472f1

Please sign in to comment.