Skip to content

Commit 7c8b314

Browse files
authored
Merge pull request #213 from flash-algo:fix-triton-decode
[BUG FIX] Correct causal mask handling for longer KV pairs
2 parents 266e4f3 + 289f2de commit 7c8b314

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flash_sparse_attn/flash_sparse_attn_triton.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,7 @@ def _fwd_kernel(
165 165
q = (q * softmax_scale).to(q.dtype)
166 166

167 167
# Loop over k, v and update accumulator
168-
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
168+
end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
169 169
for start_n in range(0, end_n, BLOCK_N):
170 170
start_n = tl.multiple_of(start_n, BLOCK_N)
171 171

@@ -231,7 +231,7 @@ def _fwd_kernel(
231 231
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
232 232
acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
233 233
if IS_CAUSAL:
234-
acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
234+
acc_s += tl.where(offs_m[:, None] + (seqlen_k - seqlen_q) >= (start_n + offs_n)[None, :], 0, float("-inf"))
235 235
if HAS_MASK:
236 236
acc_s += tl.where(mask, 0, float("-inf"))
237 237

0 commit comments

Comments
 (0)