
SpargeAttention

The official implementation of SpargeAttn, a universal, training-free sparse attention method that accelerates language, image, and video models.

SpargeAttention: Accurate and Training-free Sparse Attention
Accelerating Any Model Inference

Paper: arXiv:2502.18137 · HuggingFace Daily Papers

[Figure: speed comparison]

[Figure: overview]

Project Updates

  • Please use the spas_sage2_attn_meansim_topk_cuda and block_sparse_sage2_attn_cuda APIs.
  • [2025-07]: Released a Triton kernel example.
  • [2025-06]: SpargeAttn based on SageAttention2++ is released.
  • [2025-05]: Added a very simple usage that requires no tuning or calibration: o = spas_sage2_attn_meansim_topk_cuda(q, k, v).
  • [2025-05]: 🎉SpargeAttn and SageAttention2 are accepted by ICML 2025!
  • [2025-03]: Support for high acceleration on more GPUs, e.g., H100.

Installation

Base environment

  • python>=3.9, torch>=2.3.0
  • CUDA:
    • >=12.8 for Blackwell, >=12.4 for fp8 support on Ada, >=12.3 for fp8 support on Hopper, >=12.0 for Ampere

Install Package

pip install ninja       # for parallel compilation
python setup.py install # or: pip install -e .
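Before building, it can help to confirm the environment meets the requirements above. The following is a small sanity check of our own, assuming PyTorch is already installed (printed values are machine-dependent):

```python
import torch

# Parse e.g. "2.4.1+cu124" -> (2, 4) and compare against the torch>=2.3.0 requirement.
torch_ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
if torch_ver < (2, 3):
    print(f"warning: torch>=2.3.0 required, found {torch.__version__}")

print("torch:", torch.__version__)
print("CUDA runtime:", torch.version.cuda)  # None for CPU-only builds
if torch.cuda.is_available():
    # e.g. (9, 0) for Hopper, (8, 9) for Ada, (8, 0) for Ampere
    print("compute capability:", torch.cuda.get_device_capability())
```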

Available API

  • spas_sage2_attn_meansim_topk_cuda: SpargeAttn based on SageAttention2 (recommended).

  • spas_sage2_attn_meansim_cuda: SpargeAttn based on SageAttention2 (not recommended).

  • spas_sage_attn_meansim_topk_cuda: SpargeAttn based on SageAttention (recommended).

  • spas_sage_attn_meansim_cuda: SpargeAttn based on SageAttention (not recommended).

Usage

Plug-and-Play Usage

Simply replace the torch.nn.functional.scaled_dot_product_attention call with spas_sage2_attn_meansim_topk_cuda:

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
- attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False) # is_causal can be True
+ attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False) # is_causal can be True

Plug-and-Play API

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
attn_output = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)

You can adjust topk to balance between attention accuracy (higher topk is more accurate) and sparsity (lower topk is more sparse).
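In codebases where SpargeAttn may not always be installed, the replacement can be wrapped with a dense fallback. This is a minimal sketch of our own, not part of the official API; the `attention` wrapper name and the tensor shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

try:
    from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda as sparse_attn
    HAVE_SPARGE = True
except ImportError:  # fall back to dense attention if SpargeAttn isn't installed
    HAVE_SPARGE = False

def attention(q, k, v, is_causal=False, topk=0.5):
    """Drop-in wrapper: SpargeAttn on CUDA when available, dense attention otherwise."""
    if HAVE_SPARGE and q.is_cuda:
        return sparse_attn(q, k, v, topk=topk, is_causal=is_causal)
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

# Example shapes: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 256, 64)
out = attention(q, k, v, is_causal=True)
```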

Customize your Block-Sparse Mask API

from spas_sage_attn import block_sparse_sage2_attn_cuda
attn_output = block_sparse_sage2_attn_cuda(q, k, v, mask_id=None)

This API supports computing attention with an arbitrary block-sparse mask per attention head. Specifically, the per-head attention mask mask_id has shape (batch_size, num_heads, ⌈seq_len / 128⌉, ⌈seq_len / 64⌉) and consists of 0s and 1s. Currently, the block size is 128×64.
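To illustrate the required mask layout, the sketch below builds a mask_id tensor in plain PyTorch. The dimensions and the causal-style band pattern are hypothetical, chosen only to show the ⌈seq_len / 128⌉ × ⌈seq_len / 64⌉ shape; any 0/1 pattern of that shape would work:

```python
import math
import torch

# Hypothetical dimensions for illustration.
batch_size, num_heads, seq_len = 2, 8, 1024

# One mask entry per (128-row, 64-column) tile of the seq_len x seq_len
# attention matrix; 1 keeps a block, 0 skips its computation entirely.
rows = math.ceil(seq_len / 128)  # query blocks of 128
cols = math.ceil(seq_len / 64)   # key blocks of 64
mask_id = torch.zeros(batch_size, num_heads, rows, cols, dtype=torch.int32)

# Example pattern: keep only blocks whose key block starts at or below the
# last query position of the row block (a causal-like band).
for i in range(rows):
    for j in range(cols):
        if j * 64 <= i * 128 + 127:
            mask_id[:, :, i, j] = 1
```

The resulting tensor can then be passed as the mask_id argument of block_sparse_sage2_attn_cuda.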

Citation

@inproceedings{zhang2025spargeattn,
  title={Spargeattn: Accurate sparse attention accelerating any model inference},
  author={Zhang, Jintao and Xiang, Chendong and Huang, Haofeng and Wei, Jia and Xi, Haocheng and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}

@article{zhang2026spargeattention2,
  title={SpargeAttention2: Trainable Sparse Attention via Hybrid Top-k+ Top-p Masking and Distillation Fine-Tuning},
  author={Zhang, Jintao and Jiang, Kai and Xiang, Chendong and Feng, Weiqi and Hu, Yuezhou and Xi, Haocheng and Chen, Jianfei and Zhu, Jun},
  journal={arXiv preprint arXiv:2602.13515},
  year={2026}
}

@inproceedings{zhang2025sageattention,
  title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
  author={Zhang, Jintao and Wei, Jia and Zhang, Pengle and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

@inproceedings{zhang2024sageattention2,
  title={Sageattention2: Efficient attention with thorough outlier smoothing and per-thread int4 quantization},
  author={Zhang, Jintao and Huang, Haofeng and Zhang, Pengle and Wei, Jia and Zhu, Jun and Chen, Jianfei},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}
