From 9869dfc96940f58b3ababa57fc8ce45dcb854b81 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 2 Jan 2025 15:14:55 -0800
Subject: [PATCH 1/4] Make int8 dynamic quant in autoquant serializable

Summary:
lambda functions are not supported for serialization, so we need to reuse
the non-lambda functions that already support serialization:
https://github.com/pytorch/ao/blob/00a8d290aab354985fce8c880e1fded22bc48e30/torchao/quantization/quant_api.py#L1263C5-L1268

Note that this PR only supports int8 dynamic quant; float8 will need to be
tested and supported separately (on H100 machines).

Test Plan:
Tested locally with transformers push_to_hub:
https://huggingface.co/jerryzh168/llama3-8b-autoquant/tree/main

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/quantization/autoquant.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index b13f1d16a5..4503c17f75 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -426,6 +426,7 @@ def from_float(cls, weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
+        from torchao.quantization.quant_api import _int8_symm_per_token_reduced_range_quant
 
         # weight settings
         mapping_type = MappingType.SYMMETRIC
 
@@ -450,16 +451,7 @@ def get_per_token_block_size(x):
         input_quant_min = -127
         input_quant_max = 127
         _layout = cls.layout
-        input_quant_func = lambda x: to_affine_quantized_intx(
-            x,
-            input_mapping_type,
-            get_per_token_block_size(x),
-            input_target_dtype,
-            eps=input_eps,
-            quant_min=input_quant_min,
-            quant_max=input_quant_max,
-            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
-        )
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_intx(
@@ -937,6 +929,7 @@ def get_per_token_block_size(x):
 
         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
         input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
@@ -980,6 +973,7 @@ def get_weight_block_size(x):
 
         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
         input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
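A minimal sketch of the constraint the summary describes (not part of the
patch series): pickle stores functions by qualified name, so a module-level
function such as _int8_symm_per_token_reduced_range_quant round-trips, while
a lambda has no importable name and fails. quant_fn below is a hypothetical
stand-in for the real quantization function.

    import pickle

    def quant_fn(x):
        # The real function quantizes its input; identity is enough to
        # demonstrate the serialization behavior.
        return x

    pickle.loads(pickle.dumps(quant_fn))  # OK: stored as "<module>.quant_fn"

    try:
        pickle.dumps(lambda x: x)
    except (pickle.PicklingError, AttributeError) as exc:
        print(f"lambda is not serializable: {exc}")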
From ca3c5400ef1aba3cbf2b6e3fc2da52b015f6b586 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 2 Jan 2025 15:19:31 -0800
Subject: [PATCH 2/4] fix

---
 torchao/quantization/autoquant.py | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 4503c17f75..34bc3de40a 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -428,6 +428,9 @@ def from_float(cls, weight):
         from torchao.dtypes import to_affine_quantized_intx
         from torchao.quantization.quant_api import _int8_symm_per_token_reduced_range_quant
 
+        # input settings
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
+
         # weight settings
         mapping_type = MappingType.SYMMETRIC
 
@@ -437,23 +440,9 @@ def get_weight_block_size(x):
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size) - 1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
         _layout = cls.layout
-        input_quant_func = _int8_symm_per_token_reduced_range_quant
-
         block_size = get_weight_block_size(weight)
+
         weight = to_affine_quantized_intx(
             weight,
             mapping_type,

From a463397d18604a58dde20d1107b26db4ed3977de Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 2 Jan 2025 16:20:04 -0800
Subject: [PATCH 3/4] fixes

---
 .github/workflows/ruff_linter.yml | 2 +-
 .pre-commit-config.yaml           | 3 ++-
 torchao/quantization/autoquant.py | 8 +++++++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ruff_linter.yml b/.github/workflows/ruff_linter.yml
index 40ac883e71..0da7c0a4b5 100644
--- a/.github/workflows/ruff_linter.yml
+++ b/.github/workflows/ruff_linter.yml
@@ -68,7 +68,7 @@ jobs:
           # --isolated is used to skip the allowlist at all so this applies to all files
           # please be careful when using this large changes means everyone needs to rebase
           # if you do be sure to update .pre-commit-config.yaml
-          ruff check --isolated --select F821,F823,W191
+          ruff check --isolated --select F821,F823,W191,E731
           ruff check
           ruff format --check || {
             echo "Ruff check failed, please try again after running 'ruff format'."
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ac354e6d7..3e34f1d465 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,4 +26,5 @@ repos:
         alias: ruff-isolated
         args:
           - --isolated
-          - select F821,F823,W191
+          - --select
+          - F821,F823,W191
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 34bc3de40a..15166aca0d 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -426,7 +426,9 @@ def from_float(cls, weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
-        from torchao.quantization.quant_api import _int8_symm_per_token_reduced_range_quant
+        from torchao.quantization.quant_api import (
+            _int8_symm_per_token_reduced_range_quant,
+        )
 
         # input settings
         input_quant_func = _int8_symm_per_token_reduced_range_quant
 
@@ -1270,3 +1272,7 @@ def finalize_autoquant():
             model(*example_input)
 
     return model
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST)

From 90256264142216fa1c33f2af2d66b4ecb4f88e45 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 2 Jan 2025 16:26:53 -0800
Subject: [PATCH 4/4] fix

---
 .github/workflows/ruff_linter.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/ruff_linter.yml b/.github/workflows/ruff_linter.yml
index 0da7c0a4b5..40ac883e71 100644
--- a/.github/workflows/ruff_linter.yml
+++ b/.github/workflows/ruff_linter.yml
@@ -68,7 +68,7 @@ jobs:
           # --isolated is used to skip the allowlist at all so this applies to all files
           # please be careful when using this large changes means everyone needs to rebase
           # if you do be sure to update .pre-commit-config.yaml
-          ruff check --isolated --select F821,F823,W191,E731
+          ruff check --isolated --select F821,F823,W191
           ruff check
           ruff format --check || {
             echo "Ruff check failed, please try again after running 'ruff format'."
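A minimal sketch of what the add_safe_globals registration at the end of
patch 3 enables (assumes PyTorch >= 2.5; AutoQuantClass is a hypothetical
stand-in for the entries of ALL_AUTOQUANT_CLASS_LIST): once a class is
allowlisted, checkpoints referencing it can be loaded with the restricted
weights_only=True mode of torch.load.

    import torch

    class AutoQuantClass(torch.Tensor):
        pass

    # Allowlist the class for the restricted weights_only unpickler.
    torch.serialization.add_safe_globals([AutoQuantClass])

    torch.save({"cls": AutoQuantClass}, "ckpt.pt")
    state = torch.load("ckpt.pt", weights_only=True)  # accepted after allowlisting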