Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's sparse computation capabilities for processing extremely long sequences in transformer models.
- Sparse Attention Computation: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot w)$, where $w \ll N$ (see the sketch after this list).
- Memory Efficiency: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
- CUDA-Accelerated: Deep integration at the CUDA kernel level with custom sparse GEMM operations for maximum performance.
- Long Sequence Support: Efficiently handles sequences of 128K+ tokens through dynamic masking when the sequence length exceeds `keep_window_size`.
- Advanced Integration: Complete integration from the Python frontend to the CUDA backend with optimized memory layouts and sparse computation strategies.
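The first point is the core of the approach: for each query, only the $w$ highest-scoring keys survive. A minimal PyTorch sketch of that selection step (illustrative tensor names and shapes only, independent of the actual CUDA kernels) looks like this:

```python
import torch

# Illustrative shapes: one head, N keys, w kept keys per query
N, w, head_dim = 4096, 2048, 128
query = torch.randn(N, head_dim)
key = torch.randn(N, head_dim)

# Importance scores for each (query, key) pair; any relevance measure works here
scores = query @ key.transpose(-1, -2)          # [N, N]

# Keep only the top-w keys per query; everything else is masked out, so the
# subsequent attention computation touches O(N * w) entries instead of O(N^2)
topk = torch.topk(scores, k=w, dim=-1)
active_mask = torch.zeros_like(scores)
active_mask.scatter_(-1, topk.indices, 1.0)
```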
- Python: 3.8 or later
- PyTorch: 2.0.0 or later
- CUDA: 11.8 or later
- NVIDIA GPU: Compute Capability 8.0 or higher
- C++ Compiler: GCC 7+
Ensure your CUDA environment is properly configured:
```bash
# Check CUDA installation
nvcc --version

# Set CUDA_HOME if needed
export CUDA_HOME=/usr/local/cuda

# Install from source
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
git submodule update --init --recursive
pip install .
```

```python
import torch
import torch.nn.functional as F
import math
import flash_dma_cuda

# Setup
batch_size, seq_len, num_heads, head_dim = 2, 4096, 12, 128
keep_window_size = 2048  # number of keys kept per query (example value)
device = torch.device('cuda')
dtype = torch.bfloat16

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)
active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len, device=device, dtype=dtype)

# Run Flash-DMA
output = flash_dma_cuda.fwd(
    q=query,
    k=key,
    v=value,
    zoh=zoh_states,
    active_mask=active_mask,
    softmax_scale=1.0 / math.sqrt(head_dim),
    keep_window_size=keep_window_size,
    is_causal=True
)[0]

print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
```

Flash-DMA combines two complementary techniques:
- Dynamic Mask Attention: Computes relevance scores for keys and selects only the most important ones for attention computation
- Flash Attention: Processes attention in blocks to reduce memory usage and HBM access
The integration happens at the CUDA kernel level with several key components:
- ZOH States: Pre-computed importance scores for key selection
- Active Masks: Binary masks indicating which keys should be considered for each query
- Sparse Matrix Multiplication: Custom CUDA kernels for efficient sparse attention computation
- Block-Based Processing: Maintains Flash Attention's block-based approach for memory efficiency
This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
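For intuition, a dense pure-PyTorch reference of this hybrid computation (illustrative only; tensor shapes follow the quick-start example above, and the real kernels never materialize these dense score matrices) might look like:

```python
import torch

def dma_attention_reference(query, key, value, zoh_states, active_mask, softmax_scale):
    """Dense reference: add the ZOH importance bias, drop inactive keys, then softmax."""
    # query/key/value: [batch, seq_len, num_heads, head_dim]
    # zoh_states/active_mask: [batch, num_heads, seq_len, seq_len]
    scores = torch.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    scores = scores + zoh_states                                   # importance bias
    scores = scores.masked_fill(active_mask == 0, float("-inf"))   # inactive keys ignored
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, value)
```

The CUDA path computes the same result block by block in the Flash Attention style, which is how it avoids the dense memory cost sketched here.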
Complete documentation is available in the docs directory:
- API Reference - Complete function documentation and usage examples
- Integration Guide - Detailed technical documentation of the Flash Attention integration
```bash
# Clone with submodules
git clone --recursive https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn

# Build in development mode
pip install -e .

# Run tests to verify installation
python -c "import flash_dma_cuda; print('Flash DMA CUDA extension imported successfully')"
```

- CUDA Toolkit 11.8+
- CUTLASS library
- PyTorch with CUDA support
- SM 8.0
- SM 9.0
- SM 10.0
- SM 12.0
Note: Flash Dynamic Mask Attention requires CUDA compute capability 8.0 or higher; earlier architectures are not supported.
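To confirm a device meets this requirement, you can query its compute capability from PyTorch:

```python
import torch

# Compute capability must be (8, 0) or higher (Ampere and newer)
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: {major}.{minor}")
assert (major, minor) >= (8, 0), "Flash Dynamic Mask Attention requires SM 8.0+"
```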
Flash-DMA provides comprehensive benchmarking tools to evaluate performance across different configurations:
```bash
python benchmarks/benchmark_forward_equivalence.py
```

Validates numerical consistency between the Python reference and the CUDA implementation.

```bash
python benchmarks/benchmark_forward_performance.py
```

Compares Flash-DMA against standard Flash Attention across various sequence lengths and batch sizes.

```bash
python benchmarks/benchmark_grad.py
```

Tests the backward pass implementation and gradient equivalence.

```bash
python benchmarks/benchmark_mqar.py
```

Evaluates performance on long-range reasoning tasks.
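For a quick ad-hoc comparison outside these scripts, a simplified timing loop against PyTorch's built-in `scaled_dot_product_attention` (a sketch only: CUDA-event timing, tensors reused from the quick-start example, no statistical rigor) could look like:

```python
import torch

def time_cuda(fn, iters=20):
    """Rough GPU timing with CUDA events; a few warm-up calls, then an average."""
    for _ in range(3):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# Example usage (tensors as defined in the quick-start snippet above):
# ms_dma = time_cuda(lambda: flash_dma_cuda.fwd(
#     q=query, k=key, v=value, zoh=zoh_states, active_mask=active_mask,
#     softmax_scale=head_dim ** -0.5, keep_window_size=keep_window_size, is_causal=True))
# ms_sdpa = time_cuda(lambda: torch.nn.functional.scaled_dot_product_attention(
#     query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), is_causal=True))
```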
Compilation Errors
```bash
# Ensure CUDA_HOME is set correctly
echo $CUDA_HOME        # Linux/Mac
echo $env:CUDA_HOME    # Windows PowerShell

# Check CUDA toolkit version
nvcc --version

# Verify PyTorch CUDA support
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```

Import Errors
```python
# Test basic import
try:
    import flash_dma_cuda
    print("Flash DMA CUDA extension imported successfully")
except ImportError as e:
    print(f"Import failed: {e}")
    print("Please ensure the package is properly installed with: pip install -e .")
```

Performance Issues
- Ensure the GPU has compute capability 8.0+ for optimal performance
- Use `torch.bfloat16` for better numerical stability
- Adjust `keep_window_size` based on available GPU memory
- Verify that the CUDA kernels are being used (see the profiling sketch below)
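One way to verify that the CUDA kernels are actually being launched (rather than a slower fallback path) is to profile a single forward call with `torch.profiler`; `run_forward` below is a stand-in for your own `flash_dma_cuda.fwd(...)` invocation:

```python
from torch.profiler import profile, ProfilerActivity

# Profile one forward call and list the CUDA kernels that were launched.
# `run_forward` is a placeholder for your flash_dma_cuda.fwd(...) call.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_forward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```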
Memory Issues
```python
# Monitor GPU memory usage
torch.cuda.memory_summary()
torch.cuda.max_memory_allocated()

# Clear cache if needed
torch.cuda.empty_cache()
```

Numerical Issues
- Use `torch.bfloat16` instead of `torch.float16` for better stability
- Check input tensors for NaN or infinite values (see the sketch below)
- Validate that ZOH states and active mask values are within the expected ranges
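A quick sanity check for the inputs (a sketch reusing the tensor names from the quick-start example) is:

```python
import torch

def check_finite(name, tensor):
    """Report NaN/Inf entries and the value range of an input tensor."""
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    print(f"{name}: NaN={has_nan}, Inf={has_inf}, "
          f"min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")

for name, t in [("query", query), ("key", key), ("value", value),
                ("zoh_states", zoh_states), ("active_mask", active_mask)]:
    check_finite(name, t)
```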
This project is licensed under the BSD 3-Clause License. See LICENSE for details.
If you use Flash-DMA in your research, please cite:
```bibtex
@misc{flash_dma_2025,
  title={Trainable Dynamic Mask Sparse Attention},
  author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Yuyu Luo},
  year={2025},
  url={https://github.com/SmallDoges/flash-dmattn}
}
```

This project builds upon and integrates several excellent works:
- Flash-Attention - Memory-efficient attention computation
- NVIDIA CUTLASS - High-performance matrix operations library
We thank the open-source community for their contributions to efficient transformer implementations.
