Skip to content

Commit

Permalink
fixing regressions
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles committed Dec 16, 2024
1 parent 6bc64aa commit ef225a0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,8 @@ def _test_lin_weight_subclass_api_impl(
api(mod)

# test get_plain()
mod[0].weight.tensor_impl.get_plain()
if hasattr(mod[0].weight, "tensor_impl"):
mod[0].weight.tensor_impl.get_plain()

test = mod(x)
self.assertGreater(
Expand Down
3 changes: 2 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,13 @@ def from_hp_to_intx(
compute_dtype=compute_dtype,
device=device,
verbose=False,
raw_output=not isinstance(_layout, TensorCoreTiledLayout),
raw_output=not isinstance(_layout, (TensorCoreTiledLayout, PlainLayout)),
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
# zero is preserved.
# TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version
# TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
# TODO change PlainLayout to use raw_output.
)
data = data.to(target_dtype)
else:
Expand Down

0 comments on commit ef225a0

Please sign in to comment.