-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gemlite integration in torchao #1034
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -225,6 +225,8 @@ def from_hp_to_intx( | |
else input_float.dtype | ||
) | ||
device = input_float.device | ||
from torchao.dtypes.uintx import TensorCoreTiledLayout | ||
|
||
data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( | ||
input_float, | ||
nbits=nbits, | ||
|
@@ -233,7 +235,15 @@ def from_hp_to_intx( | |
compute_dtype=compute_dtype, | ||
device=device, | ||
verbose=False, | ||
raw_output=False, | ||
raw_output=not isinstance( | ||
_layout, (TensorCoreTiledLayout, PlainLayout) | ||
), | ||
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint) | ||
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether | ||
# zero is preserved. | ||
# TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version | ||
# TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain | ||
# TODO change PlainLayout to use raw_output. | ||
) | ||
data = data.to(target_dtype) | ||
else: | ||
|
@@ -251,7 +261,8 @@ def from_hp_to_intx( | |
zero_point_domain, | ||
) | ||
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None | ||
if zero_point_domain is None: | ||
# TODO should probably consolidate ZeroPointDomain.NONE and None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah we should not have both, also doing some refactor in https://github.com/pytorch/ao/pull/1402/files#diff-7c9b4c8c6d4ef9c47873263304a335d5cf56c3ac9f98ba10b994cd80dc9c2709L536 |
||
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: | ||
zero_point = None | ||
data = quantize_affine( | ||
input_float, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,10 @@ | |
_linear_int8_act_int8_weight_block_sparse_check, | ||
_linear_int8_act_int8_weight_block_sparse_impl, | ||
) | ||
from torchao.dtypes.uintx.gemlite_layout import ( | ||
_linear_fp_act_int4_weight_gemlite_check, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: fp16? seems like gemlite only works with fp16 right now, but can be a follow up PR |
||
_linear_fp_act_int4_weight_gemlite_impl, | ||
) | ||
from torchao.dtypes.uintx.marlin_qqq_tensor import ( | ||
_linear_int8_act_int4_weight_marlin_qqq_check, | ||
_linear_int8_act_int4_weight_marlin_qqq_impl, | ||
|
@@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches(): | |
_linear_int8_act_int4_weight_marlin_qqq_check, | ||
_linear_int8_act_int4_weight_marlin_qqq_impl, | ||
), | ||
( | ||
_linear_fp_act_int4_weight_gemlite_check, | ||
_linear_fp_act_int4_weight_gemlite_impl, | ||
), | ||
]: | ||
register_aqt_quantized_linear_dispatch(dispatch_condition, impl) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove