
Conversation

@LoserCheems
Collaborator

Summary

  • Fixes an issue where the causal mask mishandled cases where the key-value sequence is longer than the query sequence.

Root Cause

  • The causal attention mechanism failed to ignore future tokens when the key-value length exceeded the query length, leading to invalid attention weights.

Changes

  • Adjusted the forward loop bounds to respect shorter query sequences and modified the causal mask offset to account for the gap between the KV and Q lengths.

Reproduction

  • Minimal reproduction: run attention with a key-value sequence longer than the query sequence and observe incorrect attention weights; a sketch follows below.
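
A minimal sketch of such a reproduction in plain PyTorch, not this repo's kernels (shapes and values are illustrative assumptions):

import torch

torch.manual_seed(0)
seqlen_q, seqlen_k, dim = 2, 6, 8   # KV cache longer than the new queries
q = torch.randn(seqlen_q, dim)
k = torch.randn(seqlen_k, dim)
v = torch.randn(seqlen_k, dim)

scores = (q @ k.T) / dim ** 0.5
i = torch.arange(seqlen_q)[:, None]
j = torch.arange(seqlen_k)[None, :]

# Correct alignment: query i sits at absolute position i + (seqlen_k - seqlen_q)
# on the key timeline, so it may attend keys j <= i + (seqlen_k - seqlen_q).
good = scores.masked_fill(j > i + (seqlen_k - seqlen_q), float("-inf"))
# Misaligned mask: treats query i as if it sat at key position i.
bad = scores.masked_fill(j > i, float("-inf"))

print(torch.allclose(good.softmax(-1) @ v, bad.softmax(-1) @ v))  # False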

Tests

  • Updated tests to validate the correct behavior of the causal mask with varying lengths of KV and Q.

Compatibility

  • No backward compatibility issues identified.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions
Ensures the forward loop bounds respect shorter queries so causal attention ignores future tokens when kv > q. Adjusts the causal mask offset to align with the kv-q gap and prevent invalid attention weights.
Copilot AI review requested due to automatic review settings December 2, 2025 02:13
@LoserCheems LoserCheems merged commit 7c8b314 into main Dec 2, 2025
7 of 8 checks passed
Copilot finished reviewing on behalf of LoserCheems December 2, 2025 02:16
Contributor

Copilot AI left a comment


Pull request overview

This PR attempts to fix causal mask handling when the key-value sequence length exceeds the query sequence length, a common scenario in KV caching during inference. The changes modify the forward kernel's loop bounds and causal mask offset calculation.

Key Changes:

  • Modified the end_n loop bound calculation to account for longer KV sequences
  • Added offset (seqlen_k - seqlen_q) to the causal mask condition to properly align query positions with the key timeline (worked through in the sketch below)
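
As a concrete illustration of that offset arithmetic (values are made up for the example, not taken from the repo):

seqlen_q, seqlen_k = 2, 6
offset = seqlen_k - seqlen_q        # 4: the queries are the last 2 positions
for i in range(seqlen_q):
    # Query i lives at absolute position i + offset on the key timeline,
    # so the fixed mask lets it attend keys 0 .. i + offset.
    print(f"query {i} attends keys 0..{i + offset}")
# query 0 attends keys 0..4
# query 1 attends keys 0..5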

Critical Issue Identified:
The logic in line 168 contains a bug that will break non-causal attention when seqlen_k > seqlen_q, as it incorrectly applies causal-like loop bounds in this case.




# Loop over k, v and update accumulator
-end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)

Copilot AI Dec 2, 2025


The condition not IS_CAUSAL and seqlen_k <= seqlen_q is logically incorrect. This causes causal-like behavior to be applied when IS_CAUSAL is False and seqlen_k > seqlen_q, which breaks non-causal attention with longer KV sequences.

When IS_CAUSAL is False, end_n should always be seqlen_k regardless of the relationship between seqlen_k and seqlen_q. The condition should be:

end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)

The offset (seqlen_k - seqlen_q) should only be applied to the causal case, not used to determine whether to apply causal behavior.

Suggested change
-end_n = seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
+end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)
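
A quick way to see the difference is to evaluate both expressions in plain Python with illustrative values (min stands in for tl.minimum; the numbers are assumptions, not from the repo):

def end_n_merged(IS_CAUSAL, start_m, BLOCK_M, seqlen_q, seqlen_k):
    # Expression as merged in this PR
    return seqlen_k if not IS_CAUSAL and seqlen_k <= seqlen_q \
        else min((start_m + 1) * BLOCK_M, seqlen_k)

def end_n_suggested(IS_CAUSAL, start_m, BLOCK_M, seqlen_q, seqlen_k):
    # Expression suggested above
    return seqlen_k if not IS_CAUSAL \
        else min((start_m + 1) * BLOCK_M + (seqlen_k - seqlen_q), seqlen_k)

# Non-causal attention with a longer KV sequence: every key should be visible.
print(end_n_merged(False, 0, 64, 32, 128))     # 64  -- keys 64..127 wrongly dropped
print(end_n_suggested(False, 0, 64, 32, 128))  # 128 -- all keys visible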
