Improve FSDP support for low-bit optimizers #538
Changes from 6 commits: 29ad4e3, b701e28, 210cf8c, 279741c, bce599f, c98509a, 6cec214
The first file in the diff defines the 4-bit optimizer state subclass, OptimState4bit.

@@ -8,7 +8,8 @@

aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional

# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
# NOTE: power-1 is linear

@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape

        )

    def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
        """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507.

        Args
            codes: quantized and packed 4-bit data stored as uint8.
            scale: scale data for block-wise quantization.
            qmap: lookup table that maps between quantized value (code) and float value.
            signed: whether the tensor is signed or unsigned.
            shape: shape of original float tensor.

        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
        The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
        """
        assert codes.dtype is torch.uint8
        assert codes.ndim == 1  # flattened buffer
        assert scale.ndim == 1
        self.codes = codes
        self.scale = scale
        self.qmap = qmap
        self.signed = signed
        self._shape = shape

    @property
    def block_size(self):
        return self.codes.numel() * 2 // self.scale.numel()
-       self.block_size = codes.numel() * 2 // scale.numel()
Review comment: curious q: is there some description of the codes/scale tensors and their relation to each other? I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) the block_size * scale numels.
Reply: I will add some description. Basically for 8-bit and FP8,
Reply: @drisspg Added some docs. Lmk if it is still unclear.
    def __tensor_flatten__(self):
        return self.tensor_attrs, [self.signed, self._shape]
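To make the packing arithmetic in the new docstring concrete, here is a small self-contained sketch (not part of the PR; the tensor shape and block size are made-up illustrative values) that checks the codes/scale relation for the 4-bit case:

    import torch

    # hypothetical original float tensor and block size (illustrative only)
    params = torch.randn(32, 256)                 # numel = 8192
    block_size = 128

    # block-wise quantization: one scale per block of the flattened tensor
    scale = torch.ones(params.numel() // block_size)             # 64 scales

    # 4-bit codes: two quantized values packed into each uint8 element
    codes = torch.zeros(params.numel() // 2, dtype=torch.uint8)  # 4096 bytes

    # block_size is recoverable from codes and scale alone, exactly as the
    # OptimState4bit.block_size property computes it
    assert codes.numel() * 2 // scale.numel() == block_size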
@@ -113,9 +126,37 @@ def _(func, *args, **kwargs):

    return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, *args, **kwargs):
    x, shape = args
-   if len(shape) > 1 or shape[0] != -1:
-       raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
-   return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))

    if tuple(x.shape) == tuple(shape):
        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)

    if len(shape) == 1 and shape[0] == -1:
        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))

    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")


# this is needed for DTensor.full_tensor()
@OptimState4bit.implements([
    c10d_functional.all_gather_into_tensor.default,
    _c10d_functional.all_gather_into_tensor.default,
    c10d_functional.wait_tensor.default,
    _c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
    x = args[0]
    if not isinstance(x, OptimState4bit):
        raise ValueError(f"expecting a OptimState4bit but found {type(x)}")

    codes = func(x.codes, *args[1:], **kwargs)
    scale = func(x.scale, *args[1:], **kwargs)

    # adjust the first dim
    shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]

    # assume tensors from all ranks have the same signedness
    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
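The "adjust the first dim" step can be illustrated with some hypothetical numbers (the world size, shape, and numels below are made up, not taken from the PR): after all-gather, codes grows by the world size, and the logical first dimension grows by the same factor:

    # hypothetical sharded 4-bit optimizer state (illustrative numbers only)
    world_size = 4
    local_shape = (512, 4)                   # logical shape of the local shard
    local_codes_numel = 512 * 4 // 2         # 4-bit values packed two per uint8
    gathered_codes_numel = local_codes_numel * world_size

    # same arithmetic as the all_gather handler: scale up the first dim only
    full_dim0 = local_shape[0] * gathered_codes_numel // local_codes_numel
    full_shape = (full_dim0,) + local_shape[1:]
    assert full_shape == (2048, 4)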
The second file in the diff defines the 8-bit optimizer state subclass, OptimState8bit.
@@ -6,14 +6,13 @@

aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional

QMAP_SIGNED = create_dynamic_map(signed=True)
QMAP_UNSIGNED = create_dynamic_map(signed=False)


# dynamic tree quantization
# https://arxiv.org/pdf/1511.04561
# https://arxiv.org/abs/2110.02861
class OptimState8bit(Tensor):
    implements = classmethod(_implements)
    tensor_attrs = ["codes", "scale", "qmap"]

@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):

        )

    def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
        """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861.
Review comment: same for this one.
        Args
            codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.
            scale: scale data for block-wise quantization.
            qmap: lookup table that maps between quantized value (code) and float value.
            signed: whether the tensor is signed or unsigned.

        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
        """
        assert codes.dtype is torch.uint8
        assert scale.ndim == 1
        self.codes = codes
        self.scale = scale
        self.qmap = qmap
        self.signed = signed

    @property
    def block_size(self):
        return self.codes.numel() // self.scale.numel()
-       self.block_size = codes.numel() // scale.numel()

    def __tensor_flatten__(self):
        return self.tensor_attrs, [self.signed]
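As a rough illustration of how codes, scale, and qmap fit together for the 8-bit state, here is a conceptual sketch with made-up sizes and a stand-in lookup table (this is not the actual dequantize implementation in the PR):

    import torch

    # hypothetical sizes (illustrative only)
    numel, block_size = 1024, 256
    codes = torch.randint(0, 256, (numel,), dtype=torch.uint8)
    scale = torch.rand(numel // block_size)      # one scale per block
    qmap = torch.linspace(-1, 1, 256)            # stand-in for the dynamic-tree lookup table

    # conceptual block-wise dequantization: look up each code in qmap,
    # then rescale each block by its own scale factor
    dequant = qmap[codes.long()].view(-1, block_size) * scale.view(-1, 1)
    assert dequant.shape == (numel // block_size, block_size)
    assert codes.numel() // scale.numel() == block_size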
@@ -97,3 +106,31 @@ def _(func, *args, **kwargs):

def _(func, *args, **kwargs):
    args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
    return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, *args, **kwargs):
    x, shape = args
    return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)


# this is needed for DTensor.full_tensor()
@OptimState8bit.implements([
    c10d_functional.all_gather_into_tensor.default,
    _c10d_functional.all_gather_into_tensor.default,
    c10d_functional.wait_tensor.default,
    _c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
    x = args[0]
    if not isinstance(x, OptimState8bit):
        raise ValueError(f"expecting a OptimState8bit but found {type(x)}")

    # assume tensors from all ranks have the same signedness
    return OptimState8bit(
        func(x.codes, *args[1:], **kwargs),
        func(x.scale, *args[1:], **kwargs),
        x.qmap.clone(),
        x.signed,
    )
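For context, here is a hedged sketch of the code path these handlers enable. It assumes an initialized process group (e.g. launched via torchrun) and uses a generic local_state placeholder rather than the PR's actual FSDP integration:

    from torch.distributed._tensor import DTensor, Shard
    from torch.distributed.device_mesh import init_device_mesh

    def gather_optim_state(local_state, world_size):
        # 1D device mesh over all ranks (hypothetical setup; requires an
        # initialized distributed environment, e.g. under torchrun)
        mesh = init_device_mesh("cuda", (world_size,))
        # wrapping the local shard dispatches aten.view.default on the subclass
        dt = DTensor.from_local(local_state, mesh, [Shard(0)])
        # materializing the full state dispatches all_gather_into_tensor
        # and wait_tensor on the subclass
        return dt.full_tensor()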
Review comment: btw the link is not valid, can you remove the "." at the end?
Reply: Done