
Conversation

Copilot AI commented Sep 10, 2025

This PR implements optional attn_mask and attn_bias inputs with adaptive computation skipping to improve performance and reduce unnecessary memory operations in Flash Dynamic Mask Attention.

Problem

The current implementation always assumes both attn_mask and attn_bias are active, causing:

  • Unnecessary global memory loads when only one tensor is needed
  • Needless dbias computation in the backward pass when no bias is used
  • Inefficient workarounds such as allocating fake all-ones masks or zero bias tensors

Solution

Added support for four explicit modes with conditional processing (a pure-PyTorch reference sketch follows the table):

Case  attn_mask  attn_bias  Behavior
A     None       None       Dense path: no block skip, no bias load/add (fastest)
B     Tensor     None       Block skip using the mask; no bias add or dbias
C     None       Tensor     No block skip (all blocks active); add bias and compute dbias
D     Tensor     Tensor     Current behavior (mask skip + bias add + dbias)
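
For reference, the four cases reduce to the following pure-PyTorch semantics. This is a minimal sketch of what the fused kernel computes, not the kernel itself; it assumes a boolean mask where True marks active positions and an additive floating-point bias.

import torch

def reference_attention(q, k, v, attn_mask=None, attn_bias=None):
    # Scaled dot-product scores.
    scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    if attn_bias is not None:    # cases C, D: add bias before softmax
        scores = scores + attn_bias
    if attn_mask is not None:    # cases B, D: inactive positions get -inf
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v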

Key Changes

Python Interface

  • Both attn_mask and attn_bias parameters now accept Optional[Tensor] = None
  • Added use_mask and use_bias flags passed to CUDA kernels
  • Conditional gradient computation: dbias is returned only when a bias is provided (see the autograd sketch after this list)
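
A minimal autograd demonstration of the conditional dbias behavior, using the pure-PyTorch reference from above (repeated so the snippet is self-contained; shapes are illustrative):

import torch

def attn(q, k, v, attn_mask=None, attn_bias=None):
    scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    if attn_bias is not None:
        scores = scores + attn_bias
    if attn_mask is not None:
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(1, 1, 4, 8, requires_grad=True) for _ in range(3))
bias = torch.zeros(1, 1, 4, 4, requires_grad=True)

# Case C: the bias participates in the graph, so a dbias gradient exists.
attn(q, k, v, attn_bias=bias).sum().backward()
assert bias.grad is not None

# Case A: no bias argument, so no dbias is computed at all.
attn(q, k, v).sum().backward()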

CUDA Kernels

  • Modified the mask application logic to conditionally process mask and bias based on the runtime flags (a tile-skip sketch follows this list)
  • Updated backward pass to skip dbias computation when use_bias=False
  • All changes preserve performance when both tensors are provided
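
A rough sketch, in Python, of the tile-level skip decision the kernel makes when use_mask is set: a key tile is processed only if its OR-reduction finds at least one active position. The tile size and function name here are illustrative, not the kernel's actual constants.

import torch

def active_key_tiles(attn_mask, tile=64):
    # attn_mask: boolean tensor of shape (..., seqlen_q, seqlen_k).
    n_tiles = (attn_mask.shape[-1] + tile - 1) // tile
    keep = []
    for t in range(n_tiles):
        block = attn_mask[..., t * tile:(t + 1) * tile]
        if block.any():    # OR-reduction: skip tiles with no active positions
            keep.append(t)
    return keep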

Usage Example

from flash_dmattn import flash_dmattn_func_auto

flash_attn = flash_dmattn_func_auto()

# Case A: Dense attention (fastest for dense workloads)
out = flash_attn(q, k, v, attn_mask=None, attn_bias=None)

# Case B: Sparse attention with mask only
out = flash_attn(q, k, v, attn_mask=sparse_mask, attn_bias=None)

# Case C: Dense attention with bias (e.g., ALiBi, relative position)
out = flash_attn(q, k, v, attn_mask=None, attn_bias=position_bias)

# Case D: Full functionality (unchanged)
out = flash_attn(q, k, v, attn_mask=sparse_mask, attn_bias=position_bias)

Performance Benefits

  • Case A: Eliminates both mask and bias memory streams plus the skip-logic overhead
  • Case B: Removes the bias path (saves reads/writes and math)
  • Case C: Removes mask loads and OR reductions (simpler control flow)
  • Case D: Baseline performance (unchanged)

Backward Compatibility

The implementation is fully backward compatible: existing code continues to work unchanged, and the default parameter values preserve current behavior when the new arguments are not specified.

Fixes #161.



Copilot AI and others added 3 commits September 10, 2025 11:40

  • …al mask/bias (Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>)
  • …essing (Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>)
  • …on skipping (Co-authored-by: LoserCheems <124847097+LoserCheems@users.noreply.github.com>)
Copilot AI changed the title [WIP] [FEATURE REQUEST] Optional mask & bias inputs with adaptive computation skipping Add optional mask & bias inputs with adaptive computation skipping Sep 10, 2025
Copilot finished work on behalf of LoserCheems September 10, 2025 11:52
Copilot AI requested a review from LoserCheems September 10, 2025 11:52