From 29ad4e378e5956b9109e1196c765460db26c9137 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 25 Jul 2024 01:57:02 +0000 Subject: [PATCH 1/6] use DTensor.from_local(run_check=False) --- torchao/prototype/low_bit_optim/adam.py | 10 +++++----- torchao/prototype/low_bit_optim/adamw.py | 10 +++++----- torchao/prototype/low_bit_optim/subclass_4bit.py | 12 +++++++++--- torchao/prototype/low_bit_optim/subclass_8bit.py | 7 +++++++ torchao/prototype/low_bit_optim/subclass_fp8.py | 7 +++++++ 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index b3b7eeb6f3..47a99c06dc 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): - out = torch.empty_like(p) - out._local_tensor = self._subclass_zeros( - out._local_tensor, - signed, - self.block_size, + out = DTensor.from_local( + local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size), + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, ) else: out = self._subclass_zeros(p, signed, self.block_size) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index ad60caa435..dbde91fdd2 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): - out = torch.empty_like(p) - out._local_tensor = self._subclass_zeros( - out._local_tensor, - signed, - self.block_size, + out = DTensor.from_local( + local_tensor=self._subclass_zeros(p.to_local(), signed, 
self.block_size), + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, ) else: out = self._subclass_zeros(p, signed, self.block_size) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 9550b3d51c..6fdb640edd 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -113,9 +113,15 @@ def _(func, *args, **kwargs): return func(*args, **kwargs) +# this is needed for DTensor.from_local() and for flattening tensor @OptimState4bit.implements(aten.view.default) def _(func, *args, **kwargs): x, shape = args - if len(shape) > 1 or shape[0] != -1: - raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]") - return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) + + if tuple(x.shape) == tuple(shape): + return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape) + + if len(shape) == 1 and shape[0] == -1: + return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) + + raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 5b16f6363f..3b291ab61e 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -97,3 +97,10 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) + + +# this is needed for DTensor.from_local() +@OptimState8bit.implements(aten.view.default) +def _(func, *args, **kwargs): + x, shape = args + return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index e3116e20f8..59bdd53a9d 100644 --- 
a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -99,3 +99,10 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] return func(*args, **kwargs) + + +# this is needed for DTensor.from_local() +@OptimStateFp8.implements(aten.view.default) +def _(func, *args, **kwargs): + x, shape = args + return OptimStateFp8(x.codes.view(shape), x.scale) From b701e2863d236bbfe6f15299da98d33013d70fe6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 25 Jul 2024 02:10:20 +0000 Subject: [PATCH 2/6] cache block_size as an attribute --- torchao/prototype/low_bit_optim/subclass_4bit.py | 5 +---- torchao/prototype/low_bit_optim/subclass_8bit.py | 5 +---- torchao/prototype/low_bit_optim/subclass_fp8.py | 5 +---- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 6fdb640edd..2242fc1bad 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -38,10 +38,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, sha self.qmap = qmap self.signed = signed self._shape = shape - - @property - def block_size(self): - return self.codes.numel() * 2 // self.scale.numel() + self.block_size = codes.numel() * 2 // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [self.signed, self._shape] diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 3b291ab61e..63b412fc3e 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -33,10 +33,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): self.scale = scale self.qmap = qmap self.signed = signed - - @property - def block_size(self): - 
return self.codes.numel() // self.scale.numel() + self.block_size = codes.numel() // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [self.signed] diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 59bdd53a9d..aa2d3efb46 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -35,10 +35,7 @@ def __init__(self, codes: Tensor, scale: Tensor): assert codes.dtype is DTYPE self.codes = codes self.scale = scale - - @property - def block_size(self): - return self.codes.numel() // self.scale.numel() + self.block_size = codes.numel() // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [] From 210cf8c639d39af34da080d7142ac166d120cf9a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 25 Jul 2024 02:44:56 +0000 Subject: [PATCH 3/6] support DTensor.full_tensor() --- test/prototype/test_low_bit_optim.py | 9 +++++++ .../prototype/low_bit_optim/subclass_4bit.py | 25 ++++++++++++++++++- .../prototype/low_bit_optim/subclass_8bit.py | 23 +++++++++++++++++ .../prototype/low_bit_optim/subclass_fp8.py | 22 ++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 94cfe34096..5eb0a54b62 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls): base_optim.step() self.assertEqual(fsdp_loss, base_loss) + base_param = base_optim.param_groups[0]["params"][0] + base_exp_avg = base_optim.state[base_param]["exp_avg"] + + fsdp_param = fsdp_optim.param_groups[0]["params"][0] + fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"] + full_fsdp_exp_avg = fsdp_exp_avg.full_tensor() + + self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize()) + instantiate_parametrized_tests(TestQuantize) 
instantiate_parametrized_tests(TestOptim) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 2242fc1bad..c297571d50 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -8,7 +8,8 @@ aten = torch.ops.aten - +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml # NOTE: power-1 is linear @@ -122,3 +123,25 @@ def _(func, *args, **kwargs): return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + + +# this is needed for DTensor.full_tensor() +@OptimState4bit.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimState4bit): + raise ValueError(f"expecting a OptimState4bit but found {type(x)}") + + codes = func(x.codes, *args[1:], **kwargs) + scale = func(x.scale, *args[1:], **kwargs) + + # adjust the first dim + shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:] + + # assume tensors from all ranks have the same signedness + return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 63b412fc3e..a523805fd9 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -6,6 +6,8 @@ aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional QMAP_SIGNED = create_dynamic_map(signed=True) 
QMAP_UNSIGNED = create_dynamic_map(signed=False) @@ -101,3 +103,24 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): x, shape = args return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed) + + +# this is needed for DTensor.full_tensor() +@OptimState8bit.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimState8bit): + raise ValueError(f"expecting a OptimState8bit but found {type(x)}") + + # assume tensors from all ranks have the same signedness + return OptimState8bit( + func(x.codes, *args[1:], **kwargs), + func(x.scale, *args[1:], **kwargs), + x.qmap.clone(), + x.signed, + ) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index aa2d3efb46..4936a65628 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -4,6 +4,9 @@ aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional + DTYPE = torch.float8_e4m3fn @@ -103,3 +106,22 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): x, shape = args return OptimStateFp8(x.codes.view(shape), x.scale) + + +# this is needed for DTensor.full_tensor() +@OptimStateFp8.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimStateFp8): + raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}") + + # FP8 state has no qmap or signedness; only codes and scale are passed through func + return OptimStateFp8( + func(x.codes, *args[1:], **kwargs), + func(x.scale, *args[1:], **kwargs), + ) From 
279741c963106f3ec30480fd848da5d1590d5149 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 25 Jul 2024 07:31:28 +0000 Subject: [PATCH 4/6] add docs --- torchao/prototype/low_bit_optim/subclass_4bit.py | 10 ++++++++++ torchao/prototype/low_bit_optim/subclass_8bit.py | 9 +++++++++ torchao/prototype/low_bit_optim/subclass_fp8.py | 7 +++++++ 3 files changed, 26 insertions(+) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index c297571d50..4a2d8f1057 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -32,8 +32,18 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): + """Create quantized 4-bit optimizer state. + + Args + codes: quantized 4-bit data stored as uint8. + scale: scale data for block-wise quantization. + qmap: lookup table that maps between quantized value (code) and float value. + signed: whether the tensor is signed or unsigned. + shape: shape of original float tensor. + """ assert codes.dtype is torch.uint8 assert codes.ndim == 1 # flattened buffer + assert scale.ndim == 1 self.codes = codes self.scale = scale self.qmap = qmap diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index a523805fd9..2c7332ab9f 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -30,7 +30,16 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): + """Create quantized 8-bit optimizer state. + + Args + codes: quantized 8-bit data stored as uint8. has the same shape as the original float tensor. + scale: scale data for block-wise quantization. 
+ qmap: lookup table that maps between quantized value (code) and float value. + signed: whether the tensor is signed or unsigned. + """ assert codes.dtype is torch.uint8 + assert scale.ndim == 1 self.codes = codes self.scale = scale self.qmap = qmap diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index 4936a65628..fa5f17c23b 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -35,7 +35,14 @@ def __new__(cls, codes: Tensor, scale: Tensor): ) def __init__(self, codes: Tensor, scale: Tensor): + """Create quantized FP8 optimizer state. + + Args + codes: quantized FP8 E4M3FN data. has the same shape as the original float tensor. + scale: scale data for block-wise quantization. + """ assert codes.dtype is DTYPE + assert scale.ndim == 1 self.codes = codes self.scale = scale self.block_size = codes.numel() // scale.numel() From bce599f2a49c6b635124aaf28d1a9ec33c6926e7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 26 Jul 2024 01:02:43 +0000 Subject: [PATCH 5/6] update docs --- torchao/prototype/low_bit_optim/subclass_4bit.py | 9 +++++++-- torchao/prototype/low_bit_optim/subclass_8bit.py | 11 ++++++----- torchao/prototype/low_bit_optim/subclass_fp8.py | 6 +++++- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 4a2d8f1057..51a80cb5b9 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -32,14 +32,19 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): - """Create quantized 4-bit optimizer state. + """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507. Args - codes: quantized 4-bit data stored as uint8. 
+ codes: quantized and packed 4-bit data stored as uint8. scale: scale data for block-wise quantization. qmap: lookup table that maps between quantized value (code) and float value. signed: whether the tensor is signed or unsigned. shape: shape of original float tensor. + + NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`. + The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage. """ assert codes.dtype is torch.uint8 assert codes.ndim == 1 # flattened buffer diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 2c7332ab9f..6734f99ecc 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -13,9 +13,6 @@ QMAP_UNSIGNED = create_dynamic_map(signed=False) -# dynamic tree quantization -# https://arxiv.org/pdf/1511.04561 -# https://arxiv.org/abs/2110.02861 class OptimState8bit(Tensor): implements = classmethod(_implements) tensor_attrs = ["codes", "scale", "qmap"] @@ -30,13 +27,17 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): - """Create quantized 8-bit optimizer state. + """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861. Args - codes: quantized 8-bit data stored as uint8. has the same shape as the original float tensor. + codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor. scale: scale data for block-wise quantization. qmap: lookup table that maps between quantized value (code) and float value. signed: whether the tensor is signed or unsigned. 
+ + NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`. """ assert codes.dtype is torch.uint8 assert scale.ndim == 1 diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index fa5f17c23b..b78638cd01 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -38,8 +38,12 @@ def __init__(self, codes: Tensor, scale: Tensor): """Create quantized FP8 optimizer state. Args - codes: quantized FP8 E4M3FN data. has the same shape as the original float tensor. + codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor. scale: scale data for block-wise quantization. + + NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`. 
""" assert codes.dtype is DTYPE assert scale.ndim == 1 From 6cec214c5684b710b96e5408d3b6d24a435b4e57 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 26 Jul 2024 01:22:33 +0000 Subject: [PATCH 6/6] remove full stop --- torchao/prototype/low_bit_optim/subclass_4bit.py | 2 +- torchao/prototype/low_bit_optim/subclass_8bit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 51a80cb5b9..a24cf8b1d5 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -32,7 +32,7 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): - """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507. + """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507 Args codes: quantized and packed 4-bit data stored as uint8. diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 6734f99ecc..1e2067963a 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -27,7 +27,7 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): - """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861. + """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861 Args codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.