Skip to content

Commit

Permalink
Unbreak build after #621
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewor14 committed Sep 6, 2024
1 parent 65d86c6 commit 36cbe19
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from typing import Tuple, Optional, Union
-import torchao.ops
from collections import defaultdict
import functools
import math
Expand Down Expand Up @@ -1425,6 +1424,8 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,

def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
from torchao.sparsity.marlin import marlin_24_workspace, const
+from torchao.ops import marlin_24_gemm

assert isinstance(weight_tensor, AffineQuantizedTensor)

sparse_w_int4 = weight_tensor.layout_tensor.int_data
Expand All @@ -1441,7 +1442,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
size_k = input_2d.shape[1]
workspace_24 = marlin_24_workspace(original_shape[1])

-out = torchao.ops.marlin_24_gemm(
+out = marlin_24_gemm(
input_2d, sparse_w_int4, meta, scale,
workspace_24, num_bits, size_m, size_n, size_k
)
Expand Down

0 comments on commit 36cbe19

Please sign in to comment.