diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index 94cfe34096..5eb0a54b62 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
 
+        base_param = base_optim.param_groups[0]["params"][0]
+        base_exp_avg = base_optim.state[base_param]["exp_avg"]
+
+        fsdp_param = fsdp_optim.param_groups[0]["params"][0]
+        fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
+        full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()
+
+        self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index b3b7eeb6f3..47a99c06dc 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
     def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                 )
             else:
                 out = self._subclass_zeros(p, signed, self.block_size)
diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py
index ad60caa435..dbde91fdd2 100644
--- a/torchao/prototype/low_bit_optim/adamw.py
+++ b/torchao/prototype/low_bit_optim/adamw.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
     def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                 )
             else:
                 out = self._subclass_zeros(p, signed, self.block_size)
diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py
index 9550b3d51c..a24cf8b1d5 100644
--- a/torchao/prototype/low_bit_optim/subclass_4bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -8,7 +8,8 @@
 
 
 aten = torch.ops.aten
-
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 
 # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
 # NOTE: power-1 is linear
@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
+        """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507
+
+        Args
+            codes: quantized and packed 4-bit data stored as uint8.
+            scale: scale data for block-wise quantization.
+            qmap: lookup table that maps between quantized value (code) and float value.
+            signed: whether the tensor is signed or unsigned.
+            shape: shape of the original float tensor.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
+        The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
+        """
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
+        assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
         self.signed = signed
         self._shape = shape
-
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed, self._shape]
@@ -113,9 +126,37 @@ def _(func, *args, **kwargs):
     return func(*args, **kwargs)
 
 
+# this is needed for DTensor.from_local() and for flattening tensor
 @OptimState4bit.implements(aten.view.default)
 def _(func, *args, **kwargs):
     x, shape = args
-    if len(shape) > 1 or shape[0] != -1:
-        raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
-    return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    if tuple(x.shape) == tuple(shape):
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState4bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState4bit):
+        raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
+
+    codes = func(x.codes, *args[1:], **kwargs)
+    scale = func(x.scale, *args[1:], **kwargs)
+
+    # adjust the first dim
+    shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index 5b16f6363f..1e2067963a 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -6,14 +6,13 @@
 
 
 aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 
 QMAP_SIGNED = create_dynamic_map(signed=True)
 QMAP_UNSIGNED = create_dynamic_map(signed=False)
 
 
-# dynamic tree quantization
-# https://arxiv.org/pdf/1511.04561
-# https://arxiv.org/abs/2110.02861
 class OptimState8bit(Tensor):
     implements = classmethod(_implements)
     tensor_attrs = ["codes", "scale", "qmap"]
@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
+        """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861
+
+        Args
+            codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.
+            scale: scale data for block-wise quantization.
+            qmap: lookup table that maps between quantized value (code) and float value.
+            signed: whether the tensor is signed or unsigned.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
+        """
         assert codes.dtype is torch.uint8
+        assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
         self.signed = signed
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed]
@@ -97,3 +106,31 @@ def _(func, *args, **kwargs):
 def _(func, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
     return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimState8bit.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState8bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState8bit):
+        raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState8bit(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+        x.qmap.clone(),
+        x.signed,
+    )
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index e3116e20f8..b78638cd01 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -4,6 +4,9 @@
 
 aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
+
 
 DTYPE = torch.float8_e4m3fn
 
 
@@ -32,13 +35,21 @@ def __new__(cls, codes: Tensor, scale: Tensor):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor):
+        """Create quantized FP8 optimizer state.
+
+        Args
+            codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor.
+            scale: scale data for block-wise quantization.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
+        """
         assert codes.dtype is DTYPE
+        assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, []
@@ -99,3 +110,29 @@ def _(func, *args, **kwargs):
 def _(func, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
     return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimStateFp8.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimStateFp8(x.codes.view(shape), x.scale)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimStateFp8.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimStateFp8):
+        raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
+
+    # assume tensors from all ranks have the same signedness
+    return OptimStateFp8(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+    )