Description
Neural Magic / IST-DASLab has written a fast INT4A16 (4-bit weight, 16-bit activation) kernel with support for 2:4 sparsity (Sparse-Marlin): https://github.com/IST-DASLab/Sparse-Marlin
We'd like to integrate this kernel into torchao and test it for ViT acceleration as a datapoint for our PTC poster.
Implementation Details
To add a custom quant + sparse layout into torchao, we need to do three things:
1) Add and bind the CUDA kernel.
Sparse-Marlin is implemented as a custom CUDA extension for PyTorch, which should be easy to port over. Most of the logic is contained in https://github.com/IST-DASLab/Sparse-Marlin/blob/main/marlin/marlin_cuda_kernel_nm.cu
You can follow the tutorial at https://github.com/pytorch/ao/blob/main/torchao/csrc/README.md, which provides details on how to add a custom CUDA extension to torchao.
After this, you should have the Marlin 2:4 mm op registered as torchao.ops.marlin_24_mm.
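As a rough sketch, the Python-side wrapper in torchao/ops.py might look like this; the argument list for marlin_24_mm is an assumption here and should match whatever the ported C++ binding actually exposes:

```python
# torchao/ops.py (illustrative sketch)
import torch
from torch import Tensor


def marlin_24_mm(
    x: Tensor,              # fp16 activations, shape (M, K)
    weight_packed: Tensor,  # Marlin-packed 2:4 sparse INT4 weights
    meta: Tensor,           # 2:4 sparsity metadata
    scale: Tensor,          # fp16 group scales
    workspace: Tensor,      # scratch buffer required by the kernel
) -> Tensor:
    """W4A16 matmul with 2:4 sparse weights via the Sparse-Marlin kernel."""
    return torch.ops.torchao.marlin_24_mm.default(
        x, weight_packed, meta, scale, workspace
    )
```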
We'd also want to benchmark the op at this point and make sure we see the same speedups reported by Neural Magic.
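A quick microbenchmark along these lines can serve as the dense baseline (the shapes are placeholders; time torchao.ops.marlin_24_mm the same way once packed inputs are available):

```python
import torch
import torch.utils.benchmark as benchmark

# Problem sizes roughly matching a ViT-style linear layer; adjust as needed.
M, K, N = 256, 1024, 4096
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w = torch.randn(K, N, dtype=torch.float16, device="cuda")

# Dense fp16 baseline.
dense_t = benchmark.Timer(
    stmt="torch.mm(x, w)",
    globals={"torch": torch, "x": x, "w": w},
).blocked_autorange()
print(dense_t)

# Once the op is bound, time torchao.ops.marlin_24_mm the same way, feeding it
# the packed weight / metadata / scales produced by Sparse-Marlin's pack helper,
# and compare against the speedups reported in the Sparse-Marlin repo.
```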
2) Register a custom sparse layout and quantized dispatch
Now that we have our kernel connected, we can hook it up to our quantization API by writing a new sparse layout for AffineQuantizedTensor, MarlinSparseLayoutType.
You can use our semi-structured sparse layout implementation as a reference:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L36-L45
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L471-L511
You'll want to replace the line
int_data_compressed = torch._cslt_compress(int_data)
with the pack function from sparse-marlin found here: https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L331
While the semi-structured sparse layout extends PlainLayoutType, the marlin packed layout should extend AQTLayout, as the marlin packed format packs both the scales and weights together.
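Concretely, the new layout could look roughly like the sketch below. The class names, the register_layout_cls decorator, and the from_plain hook mirror the semi-structured example linked above; the import paths are approximate and marlin_24_pack is a placeholder name for Sparse-Marlin's pack function.

```python
from dataclasses import dataclass

from torchao.dtypes.utils import LayoutType  # import paths are approximate
from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls


def marlin_24_pack(int_data, scale):
    """Placeholder for Sparse-Marlin's pack function (linked above), which packs
    the int4 weights, 2:4 metadata, and scales into the Marlin format."""
    raise NotImplementedError


@dataclass(frozen=True)
class MarlinSparseLayoutType(LayoutType):
    pass


@register_layout_cls(MarlinSparseLayoutType)
class MarlinSparseAQTLayout(AQTLayout):
    # __new__ / __init__ / __tensor_flatten__ are omitted here; they mirror the
    # semi-structured layout but store the packed weight, the 2:4 metadata, and
    # the packed scales as separate fields.

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        # Instead of torch._cslt_compress, pack into the Marlin 2:4 format.
        weight_packed, meta, scale_packed = marlin_24_pack(int_data, scale)
        return cls(weight_packed, meta, scale_packed, layout_type)
```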
Finally, once your Layout is registered, you'll want to define the quantized_linear_op dispatch. This will call into your earlier registered torchao.ops.marlin_24_mm op, instead of the normal dense mm.
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L708-L732
The conditional would look something like this, after line 780, as we want to overload the int4-weight-only dispatch path with the sparse marlin kernels:
```python
if (
    weight_is_uint4 and
    weight_qtensor.dtype == torch.float16 and
    len(weight_qtensor.shape) == 2 and
    weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
    isinstance(weight_qtensor.layout_type, MarlinSparseLayoutType)
):
    # call torchao.ops.marlin_24_mm
    ...
```
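The body of that branch would then unpack the Marlin layout and call the custom op. A hypothetical sketch, where the field names on the layout tensor and the marlin_24_mm argument list follow the sketches above rather than any existing API:

```python
# inside _quantized_linear_op, in the branch above
layout_tensor = weight_qtensor.layout_tensor
out = torchao.ops.marlin_24_mm(
    input_tensor.view(-1, input_tensor.shape[-1]),  # fold batch dims into M
    layout_tensor.weight_packed,
    layout_tensor.meta,
    layout_tensor.scale,
    layout_tensor.workspace,  # or allocate the Marlin workspace here
)
out = out.view(*input_tensor.shape[:-1], out.shape[-1])
return out + bias if bias is not None else out
```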
3) Add a layout option to int4_weight_only()
Finally, we need to add an entrypoint to our sparse layout from the quantize_ API, like we do in https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L462
but for int4_weight_only quantization instead.
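A minimal sketch of that change, assuming int4_weight_only keeps forwarding to to_affine_quantized as it does today (import paths, default values, and kwarg names may have drifted, so treat this as illustrative):

```python
import torch
from torchao.dtypes import to_affine_quantized, TensorCoreTiledLayoutType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain


def int4_weight_only(group_size=128, inner_k_tiles=8, layout_type=None):
    def apply_int4_weight_only_quant(weight):
        # Same qparams as the existing int4_weight_only path.
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int32
        quant_min, quant_max, eps = 0, 15, 1e-6
        zero_point_dtype = torch.bfloat16
        # Default to the dense tensor-core-tiled layout, but let callers opt
        # into the Marlin sparse packing instead.
        lt = layout_type or TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
        return to_affine_quantized(
            weight, mapping_type, block_size, target_dtype,
            quant_min, quant_max, eps,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=False,
            zero_point_domain=ZeroPointDomain.FLOAT,
            layout_type=lt,
        )

    return apply_int4_weight_only_quant
```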
You'll then be able to call into your marlin kernels to test end-to-end with
quantize_(m, int4_weight_only(layout_type=MarlinSparseLayoutType()))
Validation
In order to test our kernel in an e2e setting, we can extend our SAM benchmarks to add a new compression option:
https://github.com/pytorch/ao/blob/main/scripts/sam/eval_combo.py#L296
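For example, a new branch next to the existing compress options could look like this; the option name, and predictor.model.image_encoder as the module being quantized, follow the existing SAM script conventions but should be double-checked against eval_combo.py:

```python
# scripts/sam/eval_combo.py (sketch): new compression option alongside the
# existing int8 / sparse options.  "int4_weight_only_sparse_marlin" is illustrative.
elif compress == "int4_weight_only_sparse_marlin":
    from torchao.dtypes import MarlinSparseLayoutType
    from torchao.quantization import int4_weight_only, quantize_

    # Quantize the image encoder's linears so they dispatch to
    # torchao.ops.marlin_24_mm via the Marlin sparse layout.
    quantize_(
        predictor.model.image_encoder,
        int4_weight_only(layout_type=MarlinSparseLayoutType()),
    )
```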
