Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
Summary:

run `ruff format` to fix lint on main branch

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo committed Dec 4, 2024
1 parent 04d611a commit a298c5a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchao/dtypes/uintx/semi_sparse_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
# must pad
row, col = tmp.shape
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
w_vals_int8,
tmp_padded.t(),
alpha=w_scales.to(torch.float32),
out_dtype=torch.bfloat16,
).t()[:row, :]
).t()[:row, :]
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
)
Expand Down

0 comments on commit a298c5a

Please sign in to comment.