Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_sparse_attn/flash_sparse_attn_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _fwd_kernel(
q = (q * softmax_scale).to(q.dtype)

# Loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition not IS_CAUSAL and seqlen_k <= seqlen_q is logically incorrect. This causes causal-like behavior to be applied when IS_CAUSAL is False and seqlen_k > seqlen_q, which breaks non-causal attention with longer KV sequences.

When IS_CAUSAL is False, end_n should always be seqlen_k regardless of the relationship between seqlen_k and seqlen_q. The condition should be:

end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)

The offset (seqlen_k - seqlen_q) should only be applied to the causal case, not used to determine whether to apply causal behavior.

Suggested change
end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)
Copilot uses AI. Check for mistakes.
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)

Expand Down Expand Up @@ -231,7 +231,7 @@ def _fwd_kernel(
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
acc_s += tl.where(offs_m[:, None] + (seqlen_k - seqlen_q) >= (start_n + offs_n)[None, :], 0, float("-inf"))
if HAS_MASK:
acc_s += tl.where(mask, 0, float("-inf"))

Expand Down
Loading