
Conversation

@LoserCheems
Collaborator

Summary

  • Fixes NaN in bias.grad on SM80 when both attn_mask and attn_bias are enabled, as reproduced by the sample in case.py.

Root Cause

  • On SM80 with large head dims, a runtime branch forced a “single-split” fallback based on the device's shared-memory (SMEM) capacity. This diverged the kernel path chosen by the split/non-split selection, producing inconsistent shared-memory layouts and predicates for the mask+bias tiles in the last N-block; the corrupted attention scores then propagated to NaN in dbias. A sketch of the removed branch follows.
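A minimal sketch of the removed decision, modeled in Python for exposition; the real logic is C++ in flash_fwd_launch_template.h, and all names and thresholds below are assumptions, not the actual code:

```python
# Illustrative model of the removed branch (names/thresholds are assumed,
# not taken from flash_fwd_launch_template.h).
def select_num_splits(head_dim: int, smem_per_block: int, requested_splits: int) -> int:
    # Old behavior: probe device SMEM and force a single split for large
    # head dims on limited-SMEM parts such as SM80. This routed mask+bias
    # tiles through a different kernel path with a different shared-memory
    # layout and last-N-block predicates.
    if head_dim >= 128 and smem_per_block < 164 * 1024:
        return 1  # forced "single-split" fallback
    return requested_splits
```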

Changes

  • Drops device SMEM query and the special-case fallback to a single split on limited-SMEM GPUs.
  • Unifies the forward path onto the split-kv implementation and defers split selection to the caller, eliminating the risky runtime branch; a sketch of the unified behavior follows this list.
  • Touched code paths:
    • Forward API and launch path: FLASH_NAMESPACE::mha_fwd, run/dispatch helpers in flash_fwd_launch_template.h
    • Split-KV forward kernel path (mask+bias handling and predicates): FLASH_NAMESPACE::compute_attn_splitkv
  • Commit message:
    • Removes SMEM-based split-kv restriction
    • Drops device SMEM query and the special case that forced a single split on limited-SMEM GPUs for large head dims.
    • Simplifies the forward path and defers split selection to the caller, reducing runtime branching.
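For contrast, a sketch of the unified behavior after this change (again a Python model under the same assumptions, not the actual C++):

```python
def select_num_splits(requested_splits: int) -> int:
    # New behavior: no device-SMEM probe. The forward path always uses the
    # split-kv implementation, and the caller's choice is honored as-is;
    # num_splits == 1 is simply the degenerate case of the same kernel path.
    return max(1, requested_splits)
```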

Reproduction

  • Env: SM80 GPU (Ampere), bf16/fp16.
  • Script: case.py
  • Shell:
    • python case.py
  • Before: for the "both" case (use_bias=True, use_mask=True), grad_bias_has_nan: True.
  • After fix: out_has_nan and all grad_*_has_nan are False across bias_only, mask_only, both, and neither (see the NaN-audit sketch below).
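The observable check amounts to a NaN audit over the output and gradients; a minimal sketch of that pattern, assuming PyTorch tensors (case.py's exact code may differ):

```python
import torch

def nan_report(out: torch.Tensor, grads: dict[str, torch.Tensor]) -> dict[str, bool]:
    # Mirrors the fields quoted above (out_has_nan, grad_*_has_nan).
    report = {"out_has_nan": bool(torch.isnan(out).any())}
    for name, grad in grads.items():
        report[f"grad_{name}_has_nan"] = bool(torch.isnan(grad).any())
    return report
```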

Tests

  • Added/updated unit tests asserting no NaN in outputs and grads for the four scenarios above, with is_causal=True and deterministic=True, via _flash_dynamic_mask_attention_forward from flash_dynamic_mask_attention.py (a sketch of the test matrix follows this list).
  • Stress-tested across representative head dims and varying key_length values to cover the last N-block predicate paths.
  • Validated numerical equivalence with reference attention for mask+bias where applicable.
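A hedged sketch of the scenario matrix those tests cover; run_case is a hypothetical wrapper around _flash_dynamic_mask_attention_forward (its exact signature is not shown here), and the head_dim/key_length values are illustrative:

```python
import pytest
import torch

SCENARIOS = ["bias_only", "mask_only", "both", "neither"]

@pytest.mark.parametrize("scenario", SCENARIOS)
@pytest.mark.parametrize("head_dim", [64, 128, 256])      # representative dims (assumed)
@pytest.mark.parametrize("key_length", [127, 128, 257])   # hits last N-block predicates (assumed)
def test_no_nan(scenario, head_dim, key_length):
    use_bias = scenario in ("bias_only", "both")
    use_mask = scenario in ("mask_only", "both")
    # run_case: hypothetical helper calling _flash_dynamic_mask_attention_forward
    # with is_causal=True, deterministic=True; returns (out, {name: grad}).
    out, grads = run_case(use_bias=use_bias, use_mask=use_mask,
                          head_dim=head_dim, key_length=key_length,
                          is_causal=True, deterministic=True)
    assert not torch.isnan(out).any()
    for name, grad in grads.items():
        assert not torch.isnan(grad).any(), f"NaN in grad_{name}"
```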

Compatibility

  • No API signature changes.
  • Behavior note: split selection is now caller-driven. If code previously relied on the implicit SMEM-based fallback, provide an explicit num_splits/keep_window_size policy upstream. Default behavior remains functional with num_splits=1; performance tuning on large heads may require an explicit split (a heuristic sketch follows).
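Since split selection is now the caller's responsibility, upstream code can apply an explicit policy. A heuristic sketch, an assumption rather than the library's built-in tuner:

```python
def pick_num_splits(seqlen_k: int, block_n: int = 128,
                    num_sms: int = 108, max_splits: int = 8) -> int:
    # Heuristic only: split long KV sequences so SMs stay busy; num_splits=1
    # remains the always-valid default.
    n_blocks = (seqlen_k + block_n - 1) // block_n
    if n_blocks >= num_sms:
        return 1  # enough parallel work already; splitting adds reduction cost
    return min(max_splits, max(1, num_sms // max(1, n_blocks)))
```

For example, pick_num_splits(2048) with the default parameters yields 6, while very long sequences fall back to a single split.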

Checklist

  • [ ] Linked issue provided
  • Adds or updates tests
  • Updates docs if needed (note on split selection responsibility)
  • No perf regressions (removing branching reduces overhead; split choice remains tunable)
Contributor

Copilot AI left a comment


Pull Request Overview

This PR fixes a NaN bug in bias gradients when running on SM80 GPUs with both attention mask and bias enabled. The issue stemmed from a runtime branch that forced a "single-split" fallback based on device shared memory limitations, causing inconsistent kernel paths and corrupted attention scores.

  • Removes device SMEM query and special-case fallback logic for limited-SMEM GPUs
  • Unifies the forward path to use split-kv implementation consistently
  • Defers split selection responsibility to the caller to eliminate risky runtime branching


@LoserCheems LoserCheems merged commit af2ec35 into main Sep 22, 2025
3 of 4 checks passed
@LoserCheems LoserCheems deleted the fix-nan-in-sm80 branch November 13, 2025 04:41
