
gemlite integration in torchao #1034

Merged: 2 commits merged into main on Dec 16, 2024
Conversation

HDCharles (Contributor) commented Oct 8, 2024

gemlite integration in torchao

Summary:

This PR adds support for gemlite kernels in torchao using a subclass
integration with the gemlite_uintx_weight_only constructor. This works
for int4 grouped and ungrouped asymmetric weight-only quantization and
int8 symmetric ungrouped quantization for fp16 models. TP support
through DTensor is included in this PR.

In the process of integrating gemlite into AQT, I also made fixes to a
few quant primitives that this integration now exercises but which previously were not used.

Test Plan:

test_integration.py -k "test_gemlite_layout"
test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite"

see benchmarks.sh for gemlite benchmarks as well.
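For context, a minimal usage sketch (not part of this PR's diff): it assumes gemlite is pip-installed and that gemlite_uintx_weight_only is exposed next to torchao's quantize_ API; the keyword names group_size and bit_width are assumptions based on the constructor name and the summary above.

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only

# fp16 model on GPU, since the summary says gemlite targets fp16 models
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# int4, grouped, asymmetric weight-only quantization backed by gemlite kernels
quantize_(model, gemlite_uintx_weight_only(bit_width=4, group_size=64))

# int8 symmetric ungrouped quantization would use bit_width=8 with no grouping
# (the exact spelling of the "ungrouped" argument is an assumption)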

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1034

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6b786d3 with merge base 039cef4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 8, 2024
@HDCharles HDCharles changed the title (wip) gemlite integration and llama batchsize>1 gemlite integration in torchao Oct 8, 2024
mobicham (Collaborator) commented Oct 9, 2024

I see that you removed the pruning config; that will produce incorrect results for group sizes lower than 128.

HDCharles (Contributor, Author) commented

It's there now; it needed the custom op support stuff.

@HDCharles HDCharles force-pushed the 050_gemlite_integration branch from 923f657 to e417818 on October 29, 2024 09:03
@HDCharles HDCharles force-pushed the 050_gemlite_integration branch 2 times, most recently from d3b4ad7 to 61e2564 on November 22, 2024 10:40
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
@HDCharles HDCharles force-pushed the 050_gemlite_integration branch 2 times, most recently from 1383982 to 749c0d4 on December 16, 2024 10:38
@HDCharles HDCharles added the topic: new feature Use this tag if this PR adds a new feature label Dec 16, 2024
@HDCharles HDCharles requested a review from jerryzh168 December 16, 2024 10:39
Commit: gemlite integration in torchao

The commit body repeats the Summary and Test Plan above, followed by the squashed work-in-progress history: new gemlite integration using pip install; tests ran; fixing gemlite to do int4 matmul instead of fp16 fp16; running tests; more testing; AQT integration wip; Wip; testing on gemlite a100_int8_tuning branch; gemlite subclass testing bitpacking 8 bits; bug fixing stuff; hicham fixes; new benchmarks; testing gemlite 8 bit; WIP; tp support; wip; final.
@HDCharles HDCharles force-pushed the 050_gemlite_integration branch from 749c0d4 to 6bc64aa on December 16, 2024 10:51
@HDCharles HDCharles force-pushed the 050_gemlite_integration branch from ef225a0 to 6b786d3 on December 16, 2024 13:54
@@ -1053,6 +1088,7 @@ def callback(x):
)

args = parser.parse_args()
print(args)
Contributor:

nit: remove

@@ -251,7 +261,8 @@ def from_hp_to_intx(
zero_point_domain,
)
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
if zero_point_domain is None:
# TODO should probably consolidate ZeroPointDomain.NONE and None

@@ -20,6 +20,10 @@
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.gemlite_layout import (
_linear_fp_act_int4_weight_gemlite_check,
Contributor:

nit: fp16? It seems like gemlite only works with fp16 right now, but this can be a follow-up PR.

), f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}"
group_size, bit_width = _layout.group_size, _layout.bit_width

torch._dynamo.config.inline_inbuilt_nn_modules = False
Contributor:

nit: do we need to restore the state afterwards?

Contributor (Author):

I think this is necessary for gemlite to function correctly.
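For reference, one way to address the restore-the-state nit (a sketch, not what the PR does): save the previous flag value and restore it after the gemlite packing step, so the global dynamo config is not left permanently mutated.

import torch

# save/restore the dynamo flag around the packing step; the packing call
# itself is a placeholder here, not the PR's code
prev = torch._dynamo.config.inline_inbuilt_nn_modules
torch._dynamo.config.inline_inbuilt_nn_modules = False
try:
    ...  # pack the gemlite weight here
finally:
    torch._dynamo.config.inline_inbuilt_nn_modules = prev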

# to use the gemlite matmul kernel, which expects the weight to be passed in as is,
# we ignore the transpose
if func is aten.detach.default or func is aten.t.default:
return return_and_correct_aliasing(
Contributor:

It would probably be good to record the transposed state:

not args[0].transposed,

in case it's used later.

HDCharles (Contributor, Author) commented Dec 17, 2024:

That line does literally nothing, though; the actual state isn't tracked since it's hardcoded to always be false:

self.transposed = False

Long term, I think this may be better handled by AQT itself rather than by the tensor impl: transposing the tensor_impl would mean unpacking and repacking the weight, whereas what we actually want is to bookkeep the representation, which makes more sense at the top level where the actual shape is changing.
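As a toy illustration of the bookkeeping being debated (hypothetical, not GemliteAQTTensorImpl): a transpose on the packed weight could flip a flag while leaving the packed bytes untouched, so a later consumer can still query the logical orientation.

class PackedWeight:
    # toy stand-in for a packed-weight impl; not the PR's tensor subclass
    def __init__(self, packed_data, transposed=False):
        self.packed_data = packed_data
        self.transposed = transposed

    def t(self):
        # the gemlite kernel expects the packed bytes as-is, so a transpose is
        # pure bookkeeping: flip the flag, leave the data packed
        return PackedWeight(self.packed_data, not self.transposed)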

return (
# input is native fp16 tensor
not is_traceable_wrapper_subclass(input_tensor)
# and input_tensor.dtype == torch.float16
Contributor:

don't we need this?

Contributor (Author):

It's a little redundant since we block creation of the subclass upon creation. Normally when you mess this type of thing up you get a dtype mismatch error, which is what will happen now, whereas adding this check would likely bypass that error, and I don't know if users would actually want that.

I could see it going either way though.

jerryzh168 (Contributor) left a comment:

Looks good. I think we should merge now to unblock benchmarking and handle the follow-ups a bit later.

@jerryzh168 jerryzh168 merged commit 603d908 into main Dec 16, 2024
18 checks passed
from torchao.dtypes.uintx.gemlite_layout import apply_gemlite_quant

use_hqq = True if bit_width == 4 else False
apply_fn = lambda weight: apply_gemlite_quant(
Contributor:

Should we raise if gemlite is not installed?

jerryzh168 (Contributor) commented Dec 16, 2024:

I think so. Please feel free to add more review comments, and we should have follow-up fixes.

I just want to merge now to unblock benchmarking in sglang.
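A possible shape for that follow-up (a sketch, not the merged code): guard the gemlite import so a missing dependency fails with an actionable error instead of an obscure one deep in the quantization path. The pip install gemlite hint assumes that is the package name behind the "pip install" mentioned in the commit history.

def _require_gemlite():
    # raise a clear error when the optional gemlite dependency is missing
    try:
        import gemlite  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "gemlite_uintx_weight_only requires the gemlite package; "
            "install it with `pip install gemlite`"
        ) from err
    return gemlite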

amdfaa pushed a commit that referenced this pull request Jan 10, 2025
* gemlite integration in torchao

* fixing regressions