The description of the ticket seems to have everything I need to get started. I'll let you know once I have something for you to take a look at.
Awesome @Diogo-V, would be great if you took on this issue :) I'll assign it to you.
Are you in #CUDA-MODE? There's a sparsity channel which would be a good resource to ask questions / get unblocked.
If you need help getting started, don't hesitate to reach out.
Neuralmagic / IST-DASLab has written a fast INT4A16 kernel with support for 2:4 sparsity (Sparse-Marlin): https://github.com/IST-DASLab/Sparse-Marlin
We'd like to integrate this kernel into torchao and test it for ViT acceleration as a datapoint for our PTC poster.
Implementation Details
To add a custom quant + sparse layout into torchao, we need to do three things:
1) Add and bind the CUDA kernel.
Sparse-Marlin is implemented as a custom CUDA extension for PyTorch, which should be easy to port over. Most of the logic is contained in https://github.com/IST-DASLab/Sparse-Marlin/blob/main/marlin/marlin_cuda_kernel_nm.cu
You can follow the tutorial at https://github.com/pytorch/ao/blob/main/torchao/csrc/README.md, which provides details on how to add a custom CUDA extension to torchao.
After this, you should have registered the marlin 2:4 mm op as torchao.ops.marlin_24_mm. We would also want to benchmark the op at this time and make sure we get the same speedups reported by Neuralmagic.
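As a rough sanity check once the extension builds, something like the snippet below can time the new op against a dense fp16 matmul using torch.utils.benchmark. The exact argument list of torchao.ops.marlin_24_mm (packed weights, 2:4 metadata, scales, workspace) is only a guess here and should be adjusted to whatever signature you end up binding.

```python
import torch
from torch.utils.benchmark import Timer

m, k, n = 16, 4096, 4096
x = torch.randn(m, k, dtype=torch.float16, device="cuda")
w = torch.randn(k, n, dtype=torch.float16, device="cuda")

# Baseline: dense fp16 matmul
dense = Timer(stmt="x @ w", globals={"x": x, "w": w}).blocked_autorange()
print(f"dense fp16 mm: {dense.median * 1e6:.1f} us")

# Sparse-Marlin op (hypothetical signature -- packed_w, meta, scales and
# workspace come from the packing utilities ported over from Sparse-Marlin):
# sparse = Timer(
#     stmt="torch.ops.torchao.marlin_24_mm(x, packed_w, meta, scales, workspace)",
#     globals={"torch": torch, "x": x, "packed_w": packed_w, "meta": meta,
#              "scales": scales, "workspace": workspace},
# ).blocked_autorange()
# print(f"marlin 2:4 mm: {sparse.median * 1e6:.1f} us")
```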
2) Register a custom sparse layout and quantized dispatch
Now that we have our kernel connected, we can hook it into our quantization API by writing a new sparse layout for AffineQuantizedTensor, MarlinSparseLayout. You can use our semi-structured sparse layout implementation as a reference:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L36-L45
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L471-L511
You'll want to replace the line int_data_compressed = torch._cslt_compress(int_data) with the pack function from Sparse-Marlin found here: https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L331

While the semi-structured sparse layout extends PlainLayoutType, the marlin packed layout should extend AQTLayout, as the marlin packed format packs both the scales and weights together.
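A very rough sketch of what the new layout could look like is below. The base classes, import paths, and the marlin_24_pack helper are assumptions for illustration only; mirror whatever the semi-structured sparse layout in affine_quantized_tensor.py actually does.

```python
from dataclasses import dataclass

# Import paths below are assumptions -- use wherever these live in torchao.
from torchao.dtypes.affine_quantized_tensor import AQTLayout
from torchao.dtypes.utils import LayoutType


@dataclass(frozen=True)
class MarlinSparseLayoutType(LayoutType):
    pass


class MarlinSparseAQTLayout(AQTLayout):
    """Stores int4 weights, 2:4 metadata and scales packed together in marlin format."""

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        # The semi-structured layout calls torch._cslt_compress(int_data) here;
        # for marlin we instead call the pack() routine ported from Sparse-Marlin,
        # which also folds the scales into the packed representation.
        packed_weight, meta, packed_scale = marlin_24_pack(int_data, scale)  # ported helper (assumed name)
        return cls(packed_weight, meta, packed_scale, layout_type)
```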
Finally, once your Layout is registered, you'll want to define the quantized_linear_op dispatch. This will call into your earlier registered torchao.ops.marlin_24_mm op instead of the normal dense mm:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L708-L732
The conditional would look something like this, after line 780, as we want to overload the int4-weight-only dispatch path with the sparse marlin kernels:

    if (
        weight_is_uint4 and
        weight_qtensor.dtype == torch.float16 and
        len(weight_qtensor.shape) == 2 and
        weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
        isinstance(weight_qtensor.layout_type, MarlinSparseLayoutType)
    ):
        # call torchao.ops.marlin_24_mm
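Inside that branch, the dispatch body would roughly unpack the marlin layout and call the new op. The field names on the layout (packed_weight, meta, scale, workspace) and the op's argument order are assumptions here; use whatever your marlin layout class actually stores.

```python
import torchao.ops


def _linear_fp16_act_int4_weight_sparse_marlin_impl(input_tensor, weight_qtensor, bias):
    # Pull the marlin-packed buffers out of the layout (field names are assumptions).
    layout = weight_qtensor.layout_tensor
    out = torchao.ops.marlin_24_mm(
        input_tensor,
        layout.packed_weight,
        layout.meta,
        layout.scale,
        layout.workspace,
    )
    if bias is not None:
        out = out + bias
    return out
```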
3) Add a layout option to int4_weight_only()
Finally, we need to add an entrypoint to our SparseLayout from the quantize_ API, like we do in https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L462 but for int4_weight_only quantization instead. You'll then be able to call into your marlin kernels to test end-to-end with quantize_.
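End-to-end usage could then look something like this; the layout_type keyword on int4_weight_only is an assumption about how the new option gets exposed.

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()

# Hypothetical: route int4 weight-only quantization through the marlin 2:4 sparse
# layout (MarlinSparseLayoutType is the layout type added in step 2).
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
out = model(x)  # the linear should now dispatch to torchao.ops.marlin_24_mm
```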
Validation
In order to test our kernel in an e2e setting, we can extend our SAM benchmarks to add a new compression option:
https://github.com/pytorch/ao/blob/main/scripts/sam/eval_combo.py#L296
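The new branch could mirror the existing compression options, roughly like the sketch below; the option string and predictor.model.image_encoder as the module to quantize are assumptions based on how the other modes are wired up in that script.

```python
# Sketch of a new compression option for scripts/sam/eval_combo.py
# (option name and the quantized submodule are assumptions).
if compress == "int4_weight_only_sparse_marlin":
    from torchao.quantization import quantize_, int4_weight_only

    quantize_(
        predictor.model.image_encoder,
        int4_weight_only(layout_type=MarlinSparseLayoutType()),
    )
```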