feat: tests pass & can execute llama2
Diogo-V committed Aug 20, 2024
1 parent 26cfc08 commit fecb1f8
Showing 4 changed files with 52 additions and 8 deletions.
test/sparsity/test_marlin.py: 2 additions & 0 deletions
@@ -69,6 +69,8 @@ def test_pack_unpack_equivalence(self):
w_24, scales, zeros, n_bit=4, groupsize=group_size
)

scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
w_q_24, scales, num_bits, group_size
torchao/dtypes/affine_quantized_tensor.py: 25 additions & 7 deletions
@@ -618,7 +618,9 @@ def from_plain(

# Linear layers are (in_features, out_features) but the int_data that is reaching this point
# is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
# NOTE(reviewers): Please check if this is what I should do.
q_w_24 = int_data.t()
scale = scale.reshape(-1, q_w_24.shape[1])

if q_w_24.dtype != torch.int32:
raise ValueError("Only `torch.int32` weights are supported.")
@@ -631,15 +633,14 @@

# NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
# will require a bit more work to get our current quantization flow to work with it.
# Check the below link for a reference:
# https://github.com/neuralmagic/nm-vllm/tree/main
# Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
num_bits = 4 if torch.max(q_w_24) < 16 else -1
if num_bits not in [4]:
raise ValueError(
f"Only {const.SUPPORTED_NUM_BITS} bits are supported, got {num_bits}."
f"Only {[4]} bits are supported, got {num_bits}."
)

group_size = in_features // scale.shape[-1]
group_size = in_features // scale.shape[0]
if group_size == 0:
group_size = in_features
assert group_size <= in_features, "Group size must be less than or equal to in_features."
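A small illustration, with made-up inputs, of how the two checks above behave: the bit-width sniff only accepts int4-range values, and group_size is recovered from the number of scale rows (a single row means one group spanning all of in_features):

import torch

q_w_24 = torch.randint(0, 16, (128, 256), dtype=torch.int32)  # int4-range values, so max < 16
num_bits = 4 if torch.max(q_w_24) < 16 else -1
assert num_bits == 4

in_features = 128
scale = torch.rand(1, 256, dtype=torch.float16)  # one scale row per output channel
group_size = in_features // scale.shape[0]       # 128 == in_features: the whole input dim is one group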
@@ -1043,27 +1044,44 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
)


def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
from torchao.sparsity.marlin import marlin_24_workspace
from torchao.sparsity.marlin import marlin_24_workspace, const

sparse_w_int4 = weight_tensor.layout_tensor.int_data
scale = weight_tensor.layout_tensor.scale
meta = weight_tensor.layout_tensor.meta
original_shape = weight_tensor.layout_tensor.original_shape
num_bits = weight_tensor.layout_tensor.num_bits

# Saves batch size for reshaping back to original shape after the matmul
# Reshapes tensor to (m, k) where m is in_features * batch and k is out_features
# NOTE(reviewers): Please check if I am handling the batch size correctly
batch_size = -1
if input_tensor.dim() == 3:
batch_size = input_tensor.size(0)
input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()

size_m = input_tensor.shape[0]
size_n = original_shape[0]
size_n = original_shape[1]
size_k = input_tensor.shape[1]
workspace_24 = marlin_24_workspace(original_shape[1])

# Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
if size_k % const.TILE != 0:
pad_size = find_multiple(size_k, const.TILE)
input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
size_k = pad_size

out = torchao.ops.marlin_24_gemm(
input_tensor, sparse_w_int4, meta, scale,
workspace_24, num_bits, size_m, size_n, size_k
)
torch.cuda.synchronize()

# Reshape back to original shape
if batch_size != -1:
out = out.reshape(batch_size, -1, out.shape[-1])

if bias is not None:
out += bias.to(out.dtype)
return out
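A self-contained sketch of the shape handling added above (batch flattening plus tile padding), with a plain matmul standing in for torchao.ops.marlin_24_gemm and a hypothetical tile size of 16 in place of const.TILE:

import torch

TILE = 16                  # stands in for torchao.sparsity.marlin.const.TILE
x = torch.rand(2, 7, 100)  # hypothetical (batch, seq, in_features) activation

# Flatten (batch, seq, k) -> (m, k), remembering the batch size for later
batch_size = -1
if x.dim() == 3:
    batch_size = x.size(0)
    x = x.reshape(-1, x.shape[-1]).contiguous()

# Pad k up to the next multiple of the tile size (the rounding-up that the find_multiple call provides)
size_k = x.shape[1]
if size_k % TILE != 0:
    pad_size = ((size_k + TILE - 1) // TILE) * TILE
    x = torch.nn.functional.pad(x, (0, pad_size - size_k))
    size_k = pad_size

out = x @ torch.rand(size_k, 64)  # placeholder for the marlin_24_gemm call

# Restore the original (batch, seq, n) shape
if batch_size != -1:
    out = out.reshape(batch_size, -1, out.shape[-1])
print(out.shape)  # torch.Size([2, 7, 64])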
torchao/sparsity/marlin/__init__.py: 1 addition & 1 deletion
@@ -333,7 +333,7 @@ def _from_marlin_scale(
if group_size < size_k and group_size != -1:
reverse_perms = reverse_marlin_24_scale_perm[num_bits]
scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms]
return scales.reshape((size_n, -1))
return scales.reshape((size_k // group_size, size_n))
else:
reverse_perms = reverse_marlin_24_scale_perm_single[num_bits]
scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms]
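A shape-only sanity check for the reshape change above, with hypothetical Llama-like sizes (the real tensor also goes through the marlin scale permutation first): the unpacked scales hold one value per (group, output channel), so they should come back as (size_k // group_size, size_n) rather than (size_n, -1).

import torch

size_k, size_n, group_size = 4096, 11008, 128  # hypothetical sizes
flat = torch.rand(size_k // group_size * size_n, dtype=torch.float16)

print(flat.reshape(size_k // group_size, size_n).shape)  # torch.Size([32, 11008]): one row per group
print(flat.reshape(size_n, -1).shape)                    # torch.Size([11008, 32]): the old shape, rows no longer lined up with groups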
wip_test_llama2.py: 24 additions & 0 deletions (new file)
@@ -0,0 +1,24 @@
import torch
from torchao import quantize_
from torchao.quantization import int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType
from transformers import AutoTokenizer, LlamaForCausalLM

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
name = "meta-llama/Llama-2-7b-hf"
token = "your token"

model = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16, token=token).to(device)
tokenizer = AutoTokenizer.from_pretrained(name, token=token)

prompt = "Hey, are you conscious? Can you talk to me? I'm"
inputs = tokenizer(prompt, return_tensors="pt")

# Quantize
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

# Generate
ids = inputs.input_ids.to(device)
generate_ids = model.generate(ids, max_length=30)
out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(out)
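If useful while testing, a small optional addition to the script above for a rough GPU-memory readout after quantization (a sketch; numbers vary by GPU and dtype):

if device.type == "cuda":
    torch.cuda.synchronize()
    print(f"allocated after quantization: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")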
