Skip to content

Dao-AILab/quack

Repository files navigation

🦆 QuACK: A Quirky Assortment of CuTe Kernels 🦆

Kernels are written in the CuTe-DSL.

Installation

# For CUDA 12.9:
pip install quack-kernels

# For CUDA 13.x:
pip install 'quack-kernels[cu13]' --extra-index-url https://download.pytorch.org/whl/cu130

# Do not use uv for CUDA 13.x installs yet: it can race/install
# nvidia-cutlass-dsl[cu13] in the wrong order (NVIDIA/cutlass#3259):
# https://github.com/NVIDIA/cutlass/issues/3259

# Optional: install NVIDIA matmul heuristics for better untuned GEMM configs
pip install 'quack-kernels[heuristics]'

# Optional: JAX bindings (pulls in jax and jax-tvm-ffi)
pip install 'quack-kernels[jax]'

Requirements

  • H100, B200/B300, or RTX 50 GPU
  • CUDA toolkit 12.9+
  • Python 3.12

Kernels 🐥

  • 🦆 RMSNorm forward + backward
  • 🦆 Softmax forward + backward
  • 🦆 Cross entropy forward + backward
  • 🦆 Layernorm forward
  • 🦆 Hopper gemm + epilogue
  • 🦆 Blackwell gemm + epilogue
  • 🦆 Blackwell GeForce gemm + epilogue

Usage

from quack import rmsnorm, softmax, cross_entropy

JAX bindings are also available for some kernels (see docs/jax.md):

from quack.softmax_jax import softmax

Documentations

  • JAX interface — optional jax + jax-tvm-ffi bindings, see quack/softmax_jax.py for an example.

[2025-07-10] We have a comprehensive blogpost on how to get memory-bound kernels to speed-of-light, right in the comfort of Python thanks to the CuTe-DSL.

Performance

Image

See our blogpost for the details.

Development

To set up the development environment:

pip install -e '.[dev]'
pre-commit install

# For CUDA 13.x:
pip install 'quack-kernels[dev,cu13]' --extra-index-url https://download.pytorch.org/whl/cu130

# Do not use uv for CUDA 13.x installs yet; use pip instead.
# See https://github.com/NVIDIA/cutlass/issues/3259

About

A Quirky Assortment of CuTe Kernels

Resources

License

Stars

Watchers

Forks

Packages

 
 
 

Contributors