Add 2:4 sparse marlin kernels to torchao #549

Closed
jcaip opened this issue Jul 29, 2024 · 3 comments · Fixed by #621

jcaip commented Jul 29, 2024

Neural Magic / IST-DASLab have written a fast W4A16 (INT4 weight, FP16 activation) kernel with support for 2:4 sparsity (Sparse-Marlin): https://github.com/IST-DASLab/Sparse-Marlin


We'd like to integrate this kernel into torchao and test it for ViT acceleration as a datapoint for our PTC poster.

Implementation Details

To add a custom quant + sparse layout into torchao, we need to do three things:

1) Add and bind the CUDA kernel.

Sparse-Marlin is implemented as a custom CUDA extension for PyTorch, which should be easy to port over. Most of the logic is contained in https://github.com/IST-DASLab/Sparse-Marlin/blob/main/marlin/marlin_cuda_kernel_nm.cu

You can follow the tutorial at https://github.com/pytorch/ao/blob/main/torchao/csrc/README.md, which provides details on how to add a custom CUDA extension to torchao.

After this, you should have the marlin 2:4 mm op registered as torchao.ops.marlin_24_mm.
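
For reference, the Python-side wrapper in torchao/ops.py could look something like the sketch below. The argument list (packed weights, 2:4 metadata, scales, workspace) is an assumption based on Sparse-Marlin's mul; the real signature is dictated by the CUDA binding.

    import torch
    from torch import Tensor

    # Hypothetical wrapper in torchao/ops.py; the argument list is an
    # assumption, the real one comes from the binding for
    # marlin_cuda_kernel_nm.cu.
    def marlin_24_mm(a: Tensor, b_packed: Tensor, meta: Tensor,
                     scales: Tensor, workspace: Tensor) -> Tensor:
        # Dispatches to the C++/CUDA op registered under the torchao namespace.
        return torch.ops.torchao.marlin_24_mm.default(
            a, b_packed, meta, scales, workspace
        )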

We would also want to benchmark the op at this point and make sure we see the same speedups reported by Neural Magic.
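
A minimal benchmarking sketch with torch.utils.benchmark, assuming fp16 inputs: the dense baseline is concrete, while the marlin call is left as a placeholder since its inputs depend on the packing step.

    import torch
    import torch.nn.functional as F
    import torch.utils.benchmark as benchmark

    def bench_us(stmt, **globals_):
        # Median wall-clock time in microseconds over a blocked autorange.
        t = benchmark.Timer(stmt=stmt, globals=globals_)
        return t.blocked_autorange().median * 1e6

    m, k, n = 16, 4096, 4096
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    w = torch.randn(n, k, dtype=torch.float16, device="cuda")

    dense_us = bench_us("F.linear(a, w)", F=F, a=a, w=w)
    # After packing w into marlin's 2:4 format, benchmark the custom op the
    # same way, e.g. bench_us("torchao.ops.marlin_24_mm(...)", ...).
    print(f"dense fp16 linear: {dense_us:.1f} us")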

2) Register a custom sparse layout and quantized dispatch

Now that the kernel is bound, we can connect it to our quantization API by writing a new sparse layout for AffineQuantizedTensor, MarlinSparseLayout.

You can use our semi-structured sparse layout implementation as a reference:

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L36-L45

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L471-L511

You'll want to replace the line
int_data_compressed = torch._cslt_compress(int_data)
with the pack function from Sparse-Marlin, found here: https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L331

While the semi-structured sparse layout extends PlainLayoutType, the marlin layout should extend AQTLayout, since the marlin format packs both the scales and the weights together.
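
Putting the above together, a rough sketch of the layout class is below. It omits the tensor-subclass boilerplate; the register_layout_cls decorator mirrors the reference file, while the import paths, the marlin_24_pack helper, and the stored fields are assumptions.

    from dataclasses import dataclass

    from torchao.dtypes.affine_quantized_tensor import (  # paths are assumptions
        AQTLayout,
        register_layout_cls,
    )
    from torchao.dtypes.utils import LayoutType

    @dataclass(frozen=True)
    class MarlinSparseLayoutType(LayoutType):
        pass

    @register_layout_cls(MarlinSparseLayoutType)
    class MarlinSparseAQTLayout(AQTLayout):
        # __new__ / __init__ / __tensor_flatten__ omitted for brevity.

        @classmethod
        def from_plain(cls, int_data, scale, zero_point, layout_type):
            # Replaces torch._cslt_compress: pack the int4 weights, the 2:4
            # metadata, and the scales into marlin's tiled format.
            # marlin_24_pack is a hypothetical stand-in for Sparse-Marlin's
            # pack() function linked above.
            packed, meta, packed_scale = marlin_24_pack(int_data, scale)
            return cls(packed, meta, packed_scale, zero_point, layout_type)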

Finally, once your layout is registered, you'll want to define the quantized_linear_op dispatch. This will call into the torchao.ops.marlin_24_mm op you registered earlier, instead of the normal dense mm.

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L708-L732

The conditional would look something like this (inserted after line 780), since we want to overload the int4 weight-only dispatch path with the sparse marlin kernel:

        if (
            weight_is_uint4 and
            weight_qtensor.dtype == torch.float16 and
            len(weight_qtensor.shape) == 2 and
            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
            isinstance(weight_qtensor.layout_type, MarlinSparseLayoutType)
        ):
            # call torchao.ops.marlin_24_mm, as in the sketch below
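
Inside that branch, the call itself might look like the following sketch; the layout-tensor field names, the workspace buffer, and the op signature are the same assumptions as in the earlier sketches.

    # Sketch of the branch body; field names and signature are assumptions.
    sparse_w = weight_qtensor.layout_tensor
    # workspace: a zeroed scratch buffer marlin requires (assumption,
    # allocated elsewhere in the dispatch).
    # Flatten leading dims for the 2D mm, then restore them afterwards.
    x = input_tensor.view(-1, input_tensor.shape[-1])
    out = torchao.ops.marlin_24_mm(
        x, sparse_w.packed, sparse_w.meta, sparse_w.scale, workspace,
    )
    out = out.view(*input_tensor.shape[:-1], -1)
    if bias is not None:
        out += bias
    return out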

3) Add a layout option to int4_weight_only()

Finally, we need to add an entrypoint to our sparse layout from the quantize_ API, like we do for semi-structured sparsity in https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L462, but for int4_weight_only quantization instead.
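
Concretely, that could mean threading a layout_type argument through int4_weight_only. A minimal sketch, assuming the LayoutType-based plumbing from the reference above; the import paths and quantization parameters mirror the existing int4 path and should be double-checked against quant_api.py.

    import torch
    from torchao.dtypes import to_affine_quantized, TensorCoreTiledLayoutType
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    # Sketch: add a layout_type argument so callers can opt into the
    # marlin sparse layout; the default preserves the existing behavior.
    def int4_weight_only(group_size=128, inner_k_tiles=8, layout_type=None):
        def apply_int4wo_quant(weight):
            lt = layout_type or TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
            return to_affine_quantized(
                weight,
                MappingType.ASYMMETRIC,
                (1, group_size),   # block_size
                torch.int32,       # target_dtype
                quant_min=0,
                quant_max=15,
                zero_point_domain=ZeroPointDomain.FLOAT,
                layout_type=lt,
            )
        return apply_int4wo_quant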

You'll then be able to call into your marlin kernel and test end-to-end with:

quantize_(m, int4_weight_only(layout_type=MarlinSparseLayoutType()))

Validation

To test the kernel in an e2e setting, we can extend our SAM benchmarks with a new compression option:

https://github.com/pytorch/ao/blob/main/scripts/sam/eval_combo.py#L296
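
That could be as small as one new branch next to the existing compress options; a sketch, where the option name "int4_weight_only_sparse_marlin" and the image_encoder target are assumptions modeled on the surrounding script:

    # Hypothetical new branch in scripts/sam/eval_combo.py; the option name
    # is made up for illustration.
    elif compress == "int4_weight_only_sparse_marlin":
        quantize_(
            predictor.model.image_encoder,
            int4_weight_only(layout_type=MarlinSparseLayoutType()),
        )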

jerryzh168 added the good first issue label Jul 29, 2024

Diogo-V commented Aug 3, 2024

@jcaip

I would love to work on this issue, if possible.

The description of the ticket seems to have everything I need to get started. I'll let you know once I have something for you to take a look at.


jcaip commented Aug 5, 2024

Awesome @Diogo-V, it would be great if you took on this issue :) I'll assign it to you.

Are you in #CUDA-MODE? There's a sparsity channel, which would be a good resource to ask questions / get unblocked.
If you need help getting started, don't be shy to reach out.


Diogo-V commented Aug 5, 2024

Thank you @jcaip!

I am already in #CUDA-MODE. I posted a message in the torchao channel to ask about good first issues, and that's how I landed on this one :)

I will post a message in the sparsity channel if I hit any blockers.
