gemlite integration in torchao #1034
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1034
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6b786d3 with merge base 039cef4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I see that you removed the pruning config; that will produce incorrect results for group sizes lower than 128.
It's there now; it needed the custom op support stuff.
Force-pushed from 923f657 to e417818
Force-pushed from d3b4ad7 to 61e2564
Force-pushed from 1383982 to 749c0d4
Summary: This PR adds support for gemlite kernels in torchao using a subclass integration with the gemlite_uintx_weight_only constructor. This works for int4 grouped and ungrouped asymmetric weight-only quantization and int8 symmetric ungrouped quantization for fp16 models. TP support through DTensor is included in this PR. In the process of integrating gemlite into AQT, I also made some fixes to a few quant primitives that are now being exercised but previously were not.

Test Plan: test_integration.py -k "test_gemlite_layout"; test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite"; see benchmarks.sh for gemlite benchmarks as well.

Squashed commit history: new gemlite integration using pip install; tests ran; fixing gemlite to do int4 matmul instead of fp16 fp16; running tests; more testing; AQT integration wip; wip; testing on gemlite a100_int8_tuning branch; gemlite subclass testing bitpacking 8 bits; bug fixing stuff; hicham fixes; new benchmarks; testing gemlite 8 bit; WIP; tp support; wip; final.
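For orientation, here is a minimal usage sketch of what the new path looks like from the user side, assuming the constructor is exposed through torchao's quantize_ API; the exact import location and argument names (group_size, bit_width) are assumptions for illustration, not verbatim from this PR.

```python
import torch
from torchao.quantization import quantize_
# assumed export location for the new constructor added in this PR
from torchao.quantization import gemlite_uintx_weight_only

# gemlite currently targets fp16 models on CUDA
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# int4, grouped, asymmetric weight-only quantization backed by gemlite kernels
quantize_(model, gemlite_uintx_weight_only(group_size=64, bit_width=4))

out = model(torch.randn(1, 4096, dtype=torch.float16, device="cuda"))
```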
Force-pushed from 749c0d4 to 6bc64aa
Force-pushed from ef225a0 to 6b786d3
@@ -1053,6 +1088,7 @@ def callback(x):
)

args = parser.parse_args()
print(args)
nit: remove
@@ -251,7 +261,8 @@ def from_hp_to_intx(
    zero_point_domain,
)
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
if zero_point_domain is None:
    # TODO should probably consolidate ZeroPointDomain.NONE and None
yeah we should not have both, also doing some refactor in https://github.com/pytorch/ao/pull/1402/files#diff-7c9b4c8c6d4ef9c47873263304a335d5cf56c3ac9f98ba10b994cd80dc9c2709L536
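A small sketch of the kind of consolidation being discussed, assuming ZeroPointDomain is importable from torchao.quantization.quant_primitives; the helper name below is hypothetical and only illustrates collapsing the two spellings into one.

```python
from torchao.quantization.quant_primitives import ZeroPointDomain

def _normalize_zero_point_domain(zero_point_domain):
    # Collapse the two "no zero point" spellings (Python None and
    # ZeroPointDomain.NONE) into a single value so downstream code such as
    # from_hp_to_intx only ever has to check one of them.
    if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
        return None
    return zero_point_domain
```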
@@ -20,6 +20,10 @@
    _linear_int8_act_int8_weight_block_sparse_check,
    _linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.gemlite_layout import (
    _linear_fp_act_int4_weight_gemlite_check,
nit: fp16? seems like gemlite only works with fp16 right now, but can be a follow up PR
), f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" | ||
group_size, bit_width = _layout.group_size, _layout.bit_width | ||
|
||
torch._dynamo.config.inline_inbuilt_nn_modules = False |
nit: do we need to restore the state afterwards
I think this is necessary for gemlite to function correctly.
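A sketch of the save-and-restore pattern the nit is asking about; whether gemlite actually tolerates flipping the flag back afterwards is left open in this thread, so treat this as illustrative only.

```python
import torch

prev = torch._dynamo.config.inline_inbuilt_nn_modules
torch._dynamo.config.inline_inbuilt_nn_modules = False
try:
    # run the gemlite quantization / benchmarking path that needs the flag off
    ...
finally:
    # restore the previous global state instead of leaving it disabled
    torch._dynamo.config.inline_inbuilt_nn_modules = prev
```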
# to use the gemlite matmul kernel, which expects the weight to be passed in as is,
# we ignore the transpose
if func is aten.detach.default or func is aten.t.default:
    return return_and_correct_aliasing(
would be good to record the transposed state probably:
not args[0].transposed,
That line does literally nothing though; the actual state isn't tracked since it's hardcoded to false always:
self.transposed = False
Long term, I think this may be better handled by AQT itself rather than the tensor impl: transposing the tensor_impl would mean unpacking and repacking the weight, whereas what we actually want is to bookkeep the representation, which makes more sense at the top level where the actual shape is changing.
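A rough sketch of the bookkeeping suggested above, written as a fragment of a __torch_dispatch__ handler; _apply_fn_to_data and the transposed attribute mirror the snippet under discussion, and this is illustrative rather than the PR's final code.

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

def _handle_transpose(func, types, args, kwargs):
    # Fragment of a __torch_dispatch__ body: instead of hard-coding
    # transposed=False, flip the flag when aten.t is dispatched so the
    # logical transpose is recorded without unpacking/repacking the weight.
    self = args[0]
    if func is aten.t.default:
        out = self._apply_fn_to_data(lambda x: x)  # keep the packed weight as-is
        out.transposed = not self.transposed       # book-keep the flipped state
        return return_and_correct_aliasing(func, args, kwargs, out)
    return NotImplemented
```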
return (
    # input is native fp16 tensor
    not is_traceable_wrapper_subclass(input_tensor)
    # and input_tensor.dtype == torch.float16
don't we need this?
It's a little redundant since we block creation of the subclass upon creation; normally when you mess this type of thing up it gives you a dtype mismatch error, which is what will happen now, whereas adding this check would likely bypass the error, which I don't know if users would actually want.
Could see it going either way though.
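For reference, a sketch of what the check would look like with the commented-out dtype guard enabled, so a non-fp16 activation falls through to other dispatch entries instead of surfacing a dtype mismatch later; the weight-side condition is an assumption, included only to keep the example self-contained.

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias):
    return (
        # input is a plain (non-subclass) tensor ...
        not is_traceable_wrapper_subclass(input_tensor)
        # ... in fp16, the only activation dtype gemlite handles today
        and input_tensor.dtype == torch.float16
        # weight is quantized with the gemlite layout (placeholder condition)
        and getattr(weight_tensor, "tensor_impl", None).__class__.__name__
        == "GemliteAQTTensorImpl"
    )
```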
Looks good, I think we should merge now to unblock benchmarking and do the follow-ups a bit later.
from torchao.dtypes.uintx.gemlite_layout import apply_gemlite_quant

use_hqq = True if bit_width == 4 else False
apply_fn = lambda weight: apply_gemlite_quant(
Should we raise if gemlite is not installed?
I think so. Please feel free to add more review comments, and we should have follow-up fixes.
Just want to merge now to unblock benchmarking in sglang.
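A sketch of the guard being suggested, raising a clear error at quantization time if the package is missing; the helper name is hypothetical.

```python
def _require_gemlite():
    # Fail early with an actionable message instead of an opaque ImportError
    # deep inside the layout code.
    try:
        import gemlite  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "gemlite_uintx_weight_only requires the `gemlite` package; "
            "install it with `pip install gemlite`"
        ) from e
```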
gemlite integration in torchao

Summary:
This PR adds support for gemlite kernels in torchao using a subclass
integration with the gemlite_uintx_weight_only constructor. This works
for int4 grouped and ungrouped asymmetric weight-only quantization and
int8 symmetric ungrouped quantization for fp16 models. TP support
through DTensor is included in this PR.

In the process of integrating gemlite into AQT, I also made some fixes to
a few quant primitives that are now being exercised but previously were not.

Test Plan:
test_integration.py -k "test_gemlite_layout"
test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite"
See benchmarks.sh for gemlite benchmarks as well.