
Conversation

@bkryu (Collaborator) commented Dec 3, 2025

📌 Description

trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla currently reside in decode.py. This PR moves them to mla.py and surfaces them in the documentation by adding them to attention.rst.

Note that documenting these functions in the correct place requires this refactor, since the docs generator indexes symbols module by module.
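
For reference, the backward-compatible aliasing means the old and the new import paths resolve to the same functions. A minimal sketch (illustrative only; the alias name legacy_trtllm_mla_decode is hypothetical):

# New canonical location after this PR
from flashinfer.mla import (
    trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla,
)

# The old import path keeps working through the re-export aliases left in decode.py
from flashinfer.decode import (
    trtllm_batch_decode_with_kv_cache_mla as legacy_trtllm_mla_decode,
)

# Both names refer to the same function object
assert trtllm_batch_decode_with_kv_cache_mla is legacy_trtllm_mla_decode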

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added new batch decode entry points for MLA with KV cache, with additional configuration parameters (head dimensions, block tables, sequence lengths, and fused scales).
    • XQA MLA functionality now officially documented and available.
  • Documentation

    • Updated API reference with expanded MLA module documentation.


@bkryu (Collaborator, Author) commented Dec 3, 2025

/bot run

@coderabbitai (Contributor, bot) commented Dec 3, 2025

Walkthrough

MLA decode functions are refactored from flashinfer/decode.py to a dedicated flashinfer/mla.py module. Two new batch decode entry points—trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla—are added with extended parameters. Backward compatibility is maintained through import aliases in decode.py. Documentation and benchmarks are updated accordingly.

Changes

  • Module Reorganization (flashinfer/decode.py): Removes the MLA function definitions; imports trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla from mla.py as backward-compatible aliases.
  • New MLA Decode Entrypoints (flashinfer/mla.py): Introduces trtllm_batch_decode_with_kv_cache_mla() and xqa_batch_decode_with_kv_cache_mla() with extended parameters (qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim, block_tables, seq_lens, max_seq_len, bmm1_scale, bmm2_scale). Adds the internal validation helper _check_trtllm_gen_mla_shape() and the cached module loader get_trtllm_gen_fmha_module().
  • API Documentation (docs/api/attention.rst): Adds autosummary entries for xqa_mla (XQA section) and trtllm_batch_decode_with_kv_cache_mla (MLA page attention documentation).
  • Benchmark Update (benchmarks/routines/attention.py): Updates BatchMLAPagedAttentionWrapper to call flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla instead of the decode.py variant, passing additional parameters and applying .squeeze(1) to the result (see the call sketch below).
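
As a reading aid for the benchmark row above, a hypothetical wrapper sketching the updated call (keyword names are taken from the docstrings quoted later in this conversation; the exact benchmark code and tensor setup are not reproduced here):

import flashinfer

def run_trtllm_mla_decode(query, kv_cache, workspace_buffer, block_tables,
                          seq_lens, max_seq_len, bmm1_scale, bmm2_scale):
    # query: [batch_size, q_len_per_request, num_heads, head_dim_qk]
    # kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe]
    # workspace_buffer must be zero-initialized before its first use
    out = flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=128,          # fixed values per the docstring
        kv_lora_rank=512,
        qk_rope_head_dim=64,
        block_tables=block_tables,     # [batch_size, num_pages]
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        bmm1_scale=bmm1_scale,
        bmm2_scale=bmm2_scale,
    )
    return out.squeeze(1)  # drop the q_len dimension, matching the benchmark update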

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • flashinfer/mla.py: Contains new functions with validation logic, backend routing, and device capability checks requiring careful review of parameter handling and shape constraints.
  • flashinfer/decode.py: Straightforward import changes but requires verification that backward compatibility is properly maintained.
  • benchmarks/routines/attention.py: Verify that the API migration is complete and parameters are correctly mapped to the new function signature.
  • docs/api/attention.rst: Simple documentation additions; verify entries are consistent with exported symbols.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • yzh119
  • nvmbreughe
  • wenscarl

Poem

🐰 Hop, hop—the MLA functions now reside,
In mla.py where they'll truly thrive with pride!
Backward aliases keep the old paths alive,
While new parameters help the decode to survive.
Refactored and documented, the code's all set—
Binky approves, no regression yet! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 54.55%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title accurately reflects the main change: moving MLA code from decode.py to mla.py and updating documentation.
  • Description check ✅ Passed: The description covers the main objective (moving functions to mla.py and updating docs) and notes the reasoning (docs generator module indexing). Pre-commit and test sections are properly completed.


@gemini-code-assist (Contributor, bot)

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on improving the code organization by refactoring MLA-related decoding functionalities into their own module. This change not only enhances modularity but also resolves an issue with documentation generation, ensuring that these important functions are correctly exposed in the API reference.

Highlights

  • Code Relocation: The Multi-head Latent Attention (MLA) decoding functions, specifically trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla, have been moved from flashinfer/decode.py to the dedicated module flashinfer/mla.py.
  • Documentation Update: The documentation (docs/api/attention.rst) has been updated to include the newly relocated MLA functions, ensuring they are properly indexed and appear in the generated API documentation.
  • Backward Compatibility: Aliases for the moved MLA functions have been added to flashinfer/decode.py to maintain backward compatibility for existing code that might still reference them from their original location.
  • Benchmark Adaptation: The benchmark routines (benchmarks/routines/attention.py) have been updated to reflect the new module path for the trtllm_batch_decode_with_kv_cache_mla function.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (Contributor, bot)

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the project's structure by migrating Multi-head Latent Attention (MLA) specific functions from the general decode.py module to the mla.py module. This change aims to improve code organization and modularity, which is also a prerequisite for accurately generating API documentation for these functions under the attention section. Backward compatibility is preserved through aliasing in the original decode.py file.

Highlights

  • Code Relocation: Moved Multi-head Latent Attention (MLA) related functions (trtllm_batch_decode_with_kv_cache_mla, xqa_batch_decode_with_kv_cache_mla) from decode.py to the mla.py module.
  • Documentation Update: Ensured the moved MLA functions are correctly included in the API documentation under attention.rst.
  • Backward Compatibility: Maintained existing API calls by adding aliases for the moved functions in decode.py.

@gemini-code-assist (Contributor, bot)

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request undertakes a significant refactoring effort by relocating Multi-head Latent Attention (MLA) functions into their own dedicated module. This change aims to improve the logical organization of the codebase and is essential for correctly integrating these functions into the project's documentation. The refactoring ensures that the project's structure is more modular and easier to navigate, while also preserving existing functionality through compatibility aliases.

Highlights

  • Code Relocation: Moved MLA (Multi-head Latent Attention) related functions, specifically trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla, from flashinfer/decode.py to the flashinfer/mla.py module.
  • Documentation Update: Updated docs/api/attention.rst to include the newly moved MLA functions, ensuring they are properly documented.
  • Backward Compatibility: Maintained backward compatibility by adding import aliases for the moved MLA functions in flashinfer/decode.py.
  • Benchmark Adaptation: Modified benchmarks/routines/attention.py to reflect the new module path for the trtllm_batch_decode_with_kv_cache_mla function.

@flashinfer-bot (Collaborator)

GitLab MR !173 has been created, and the CI pipeline #39501974 is currently running. I'll report back once the pipeline job completes.

@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request refactors the codebase by moving MLA-related functions (trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla) from decode.py to a new mla.py module. This is a good change for modularity and code organization. The PR also updates the documentation to include these functions, which is a necessary follow-up. The refactoring appears to be done correctly, including maintaining backward compatibility. I have a couple of minor suggestions to improve code style and consistency.

Comment on lines +711 to +740
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.
Note:
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
medium

For consistency with trtllm_batch_decode_with_kv_cache_mla and other functions in this file, please update the docstring for xqa_batch_decode_with_kv_cache_mla to use the NumPy docstring format. This improves readability and maintainability, following the spirit of PEP 257 for docstring conventions.

 """  Parameters  ----------  query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.  kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache  workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.  qk_nope_head_dim: qk_nope_head_dim, must be 128  kv_lora_rank: kv_lora_rank, must be 512  qk_rope_head_dim: qk_rope_head_dim, must be 64  block_tables: page_table of kv cache, [batch_size, num_pages]  seq_lens: query_len  max_seq_len: max sequence length for kv_cache  out: output tensor, if not provided, will be allocated internally  bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.  bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.  sinks: additional value per head in the denominator of the softmax.   Note  ----  In MLA, the actual BMM1 and BMM2 scales applied would be fused as:  bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)  bmm2_scale = v_scale * o_scale   The two scale factors should be static constant for cuda graph capture.  Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.   For static constant scale factors, the scale factors should be provided as float.  - (bmm1_scale, bmm2_scale)  For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.  - (bmm1_scale_log2_tensor, bmm2_scale_tensor)  - Currently, only fp8 tensor core operation supports this mode.  When both are provided, the dynamic scale factor tensors will be used.  """
@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request is a nice refactoring that moves MLA-related code from decode.py to a new mla.py file, improving code organization. The documentation has also been updated accordingly. I have a few minor suggestions to improve code style and documentation consistency.

Comment on lines +27 to +31
from .mla import (
trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
medium

The as clauses in these imports are redundant. You can simplify them for better readability.

Suggested change
from .mla import (
trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
from .mla import (
trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
or,
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5))
medium

There seems to be a typo in the docstring. A closing parenthesis ) is missing in the torch.Tensor example.

Suggested change
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5))
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
Comment on lines +710 to +740
"""
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.
Note:
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
medium

The docstring for this function doesn't follow the same numpy-style format as trtllm_batch_decode_with_kv_cache_mla in this file. For consistency, it would be great to update it.

 """  Parameters  ----------  query: torch.Tensor  [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.  kv_cache: torch.Tensor  [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache  workspace_buffer: torch.Tensor  Must be initialized to 0 for its first use.  qk_nope_head_dim: int  qk_nope_head_dim, must be 128  kv_lora_rank: int  kv_lora_rank, must be 512  qk_rope_head_dim: int  qk_rope_head_dim, must be 64  block_tables: torch.Tensor  page_table of kv cache, [batch_size, num_pages]  seq_lens: torch.Tensor  query_len  max_seq_len: int  max sequence length for kv_cache  out: Optional[torch.Tensor]  output tensor, if not provided, will be allocated internally  bmm1_scale: Union[float, torch.Tensor]  fused scale for mla bmm1 input. Can be a float or a torch.Tensor.  bmm2_scale: Union[float, torch.Tensor]  fused scale for mla bmm2 input. Can be a float or a torch.Tensor.  sinks: Optional[List[torch.Tensor]]  additional value per head in the denominator of the softmax.   Note  ----  In MLA, the actual BMM1 and BMM2 scales applied would be fused as:  bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)  bmm2_scale = v_scale * o_scale   The two scale factors should be static constant for cuda graph capture.  Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.   For static constant scale factors, the scale factors should be provided as float.  - (bmm1_scale, bmm2_scale)  For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.  - (bmm1_scale_log2_tensor, bmm2_scale_tensor)  - Currently, only fp8 tensor core operation supports this mode.  When both are provided, the dynamic scale factor tensors will be used.  """
@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request is a good refactoring that moves MLA-related code into its own mla.py module and updates the documentation accordingly. This improves code organization. I've identified a potential issue where a parameter could be silently ignored, and have also suggested some improvements for documentation and code style consistency.

bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
if backend == "xqa":
high

When backend is 'xqa', the function calls xqa_batch_decode_with_kv_cache_mla, which does not support sparse attention (sparse_mla_top_k > 0). However, there is no check to prevent this. If a user calls this function with backend='xqa' and sparse_mla_top_k > 0, the sparse parameter will be silently ignored. You should add a check to raise an error in this case.

 if backend == "xqa": if sparse_mla_top_k > 0: raise ValueError("XQA backend does not support sparse MLA attention.")
Comment on lines +27 to +31
from .mla import (
trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
medium

The aliases for the imported functions are redundant. You can simplify these imports for better readability.

Suggested change
from .mla import (
trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
from .mla import (
trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla
Comment on lines +544 to +587
"""
Parameters
----------
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale: fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks: additional value per head in the denominator of the softmax.
backend : str = "auto"
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
or,
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5))
bmm2_scale = torch.Tensor([v_scale * o_scale])
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
medium

The docstring for trtllm_batch_decode_with_kv_cache_mla is missing parameter types, and the enable_pdl parameter is not documented. Adding these would improve clarity and consistency with the function's type hints.

 """  Parameters  ----------  query: torch.Tensor  [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.  kv_cache: torch.Tensor  [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache  workspace_buffer: torch.Tensor  [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.  qk_nope_head_dim: int  qk_nope_head_dim, must be 128  kv_lora_rank: int  kv_lora_rank, must be 512  qk_rope_head_dim: int  qk_rope_head_dim, must be 64  sparse_mla_top_k: int  sparse MLA top k, must be 0 for non-sparse MLA.  block_tables: torch.Tensor  page_table of kv cache, [batch_size, num_pages]  seq_lens: torch.Tensor  query_len  max_seq_len: int  max sequence length for kv_cache  out: Optional[torch.Tensor]  output tensor, if not provided, will be allocated internally  bmm1_scale: Union[float, torch.Tensor]  fused scale for mla bmm1 input.  when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.  bmm2_scale: Union[float, torch.Tensor]  fused scale for mla bmm2 input.  when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.  sinks: Optional[List[torch.Tensor]]  additional value per head in the denominator of the softmax.  enable_pdl: Optional[bool]  Whether to enable Programmatic Dependent Launch (PDL).  backend : str  The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.  When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.  For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.  For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.   Note  ----  In MLA, the actual BMM1 and BMM2 scales applied would be fused as:  bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)  bmm2_scale = v_scale * o_scale  or,  bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5))  bmm2_scale = torch.Tensor([v_scale * o_scale])   The two scale factors should be static constant for cuda graph capture.  Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.   For static constant scale factors, the scale factors should be provided as float.  - (bmm1_scale, bmm2_scale)  For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.  - (bmm1_scale_log2_tensor, bmm2_scale_tensor)  - Currently, only fp8 tensor core operation supports this mode.  When both are provided, the dynamic scale factor tensors will be used.  """
Comment on lines +710 to +740
"""
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.
Note:
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
medium

The docstring for xqa_batch_decode_with_kv_cache_mla is inconsistent with the numpy docstring format used in trtllm_batch_decode_with_kv_cache_mla. It's also missing parameter types and the enable_pdl parameter. For consistency and clarity, it should be updated.

 """  Parameters  ----------  query: torch.Tensor  [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.  kv_cache: torch.Tensor  [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache  workspace_buffer: torch.Tensor  torch.Tensor. Must be initialized to 0 for its first use.  qk_nope_head_dim: int  qk_nope_head_dim, must be 128  kv_lora_rank: int  kv_lora_rank, must be 512  qk_rope_head_dim: int  qk_rope_head_dim, must be 64  block_tables: torch.Tensor  page_table of kv cache, [batch_size, num_pages]  seq_lens: torch.Tensor  query_len  max_seq_len: int  max sequence length for kv_cache  out: Optional[torch.Tensor]  output tensor, if not provided, will be allocated internally  bmm1_scale: Union[float, torch.Tensor]  fused scale for mla bmm1 input. Can be a float or a torch.Tensor.  bmm2_scale: Union[float, torch.Tensor]  fused scale for mla bmm2 input. Can be a float or a torch.Tensor.  sinks: Optional[List[torch.Tensor]]  additional value per head in the denominator of the softmax.  enable_pdl: Optional[bool]  Whether to enable Programmatic Dependent Launch (PDL).   Note  ----  In MLA, the actual BMM1 and BMM2 scales applied would be fused as:  bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)  bmm2_scale = v_scale * o_scale   The two scale factors should be static constant for cuda graph capture.  Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.   For static constant scale factors, the scale factors should be provided as float.  - (bmm1_scale, bmm2_scale)  For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.  - (bmm1_scale_log2_tensor, bmm2_scale_tensor)  - Currently, only fp8 tensor core operation supports this mode.  When both are provided, the dynamic scale factor tensors will be used.  """
@coderabbitai (Contributor, bot) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (4)
flashinfer/mla.py (4)

87-91: Consider using _ for unpacked but unused variable H.

The variable H at line 87 is unpacked but never used. Per static analysis hint, prefix it with an underscore to indicate it's intentionally unused:

- B_q, Q_len, H, D_q = query.shape
+ B_q, Q_len, _H, D_q = query.shape

The commented-out num_heads check (lines 89-91) with the TODO suggests this might be DeepSeek-specific. Consider either removing the dead code or documenting the decision.


541-541: Use explicit Optional[bool] instead of implicit None default.

Per PEP 484, the type annotation should explicitly indicate optionality:

- enable_pdl: bool = None,
+ enable_pdl: Optional[bool] = None,

703-703: Unused parameter max_seq_len should be documented or removed.

The max_seq_len parameter is declared but never used in the function body. If this is intentional (e.g., for API consistency with trtllm_batch_decode_with_kv_cache_mla), consider adding a comment. Otherwise, it should be removed to avoid confusion.


708-708: Use explicit Optional[bool] type annotation.

- enable_pdl: bool = None,
+ enable_pdl: Optional[bool] = None,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4efb7bb and 8e9250d.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py (1 hunks)
  • docs/api/attention.rst (2 hunks)
  • flashinfer/decode.py (1 hunks)
  • flashinfer/mla.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (2)
flashinfer/mla.py (2)
  • trtllm_batch_decode_with_kv_cache_mla (526-690)
  • xqa_batch_decode_with_kv_cache_mla (694-804)
flashinfer/xqa.py (4)
  • xqa (65-112)
  • xqa (148-333)
  • xqa_mla (358-391)
  • xqa_mla (420-530)
benchmarks/routines/attention.py (1)
flashinfer/mla.py (1)
  • trtllm_batch_decode_with_kv_cache_mla (526-690)
🪛 Ruff (0.14.7)
flashinfer/mla.py

77-77: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)


81-81: Avoid specifying long messages outside the exception class

(TRY003)


83-83: Avoid specifying long messages outside the exception class

(TRY003)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


87-87: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


93-95: Avoid specifying long messages outside the exception class

(TRY003)


100-102: Avoid specifying long messages outside the exception class

(TRY003)


107-109: Avoid specifying long messages outside the exception class

(TRY003)


111-113: Avoid specifying long messages outside the exception class

(TRY003)


541-541: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


603-605: Avoid specifying long messages outside the exception class

(TRY003)


607-607: Avoid specifying long messages outside the exception class

(TRY003)


609-611: Avoid specifying long messages outside the exception class

(TRY003)


639-639: Avoid specifying long messages outside the exception class

(TRY003)


653-653: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)


690-690: Avoid specifying long messages outside the exception class

(TRY003)


703-703: Unused function argument: max_seq_len

(ARG001)


708-708: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


747-749: Avoid specifying long messages outside the exception class

(TRY003)


751-753: Avoid specifying long messages outside the exception class

(TRY003)


755-755: Avoid specifying long messages outside the exception class

(TRY003)


769-769: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
docs/api/attention.rst (1)

50-50: LGTM! Documentation entries align with the refactored module structure.

The new symbols xqa_mla under flashinfer.xqa and trtllm_batch_decode_with_kv_cache_mla under flashinfer.mla are correctly documented and match the module locations in the code changes.

Also applies to: 102-105

benchmarks/routines/attention.py (1)

1839-1852: LGTM! The module path update correctly reflects the refactored MLA function location.

The call to flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla is consistent with the PR's refactoring of MLA functions into the dedicated mla.py module. The parameter mappings and shape handling (unsqueeze/squeeze) are appropriate.

Consider addressing the TODO comment at line 1844 regarding the hardcoded qk_nope_head_dim=128 to improve code clarity for future maintainers.

flashinfer/decode.py (1)

25-31: LGTM! Backward-compatible re-exports maintain API stability.

The explicit aliasing pattern (from .mla import trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla) correctly re-exports the moved functions, ensuring existing code that imports from flashinfer.decode continues to work without modification. The comment at line 26 effectively documents the rationale.

Comment on lines +652 to +663
if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)

⚠️ Potential issue | 🟠 Major

Shape mismatch between allocated output tensor and validation for provided output.

When out is None, the allocated shape is 4D: query.shape[:-1] + (kv_lora_rank,) = (batch_size, q_len, num_heads, kv_lora_rank).

However, when out is provided, it's validated against a 3D shape: [batch_size, num_q_heads, kv_lora_rank].

This inconsistency could cause issues when users provide their own output tensor.

  if out is None:
-     out_shape = query.shape[:-1] + (kv_lora_rank,)
+     out_shape = (*query.shape[:-1], kv_lora_rank)
      out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
  else:
      batch_size, _, num_q_heads, _ = query.shape
      check_shape_dtype_device(
          out,
-         [batch_size, num_q_heads, kv_lora_rank],
+         [batch_size, 1, num_q_heads, kv_lora_rank],
          torch.bfloat16,
          query.device,
          "out",
      )

Alternatively, verify if the 3D validation shape is intentional and the output is squeezed elsewhere, but this should be documented clearly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)
if out is None:
out_shape = (*query.shape[:-1], kv_lora_rank)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, q_len, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, q_len, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)
🧰 Tools
🪛 Ruff (0.14.7)

653-653: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)

🤖 Prompt for AI Agents
In flashinfer/mla.py around lines 652 to 663, the code allocates out as a 4D tensor (query.shape[:-1] + (kv_lora_rank,) => [batch_size, q_len, num_q_heads, kv_lora_rank]) but validates a provided out against a 3D shape [batch_size, num_q_heads, kv_lora_rank], causing a shape mismatch; update the validation to expect the same 4D shape (use [batch_size, query.shape[1], num_q_heads, kv_lora_rank] or equivalent) when out is provided, and keep dtype/device checks the same, or alternatively change the allocation to 3D if the function truly expects a 3D output—ensure both allocation and validation use the identical shape and add a brief comment clarifying the expected out dimensionality. 
Comment on lines +768 to +779
if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)

⚠️ Potential issue | 🟠 Major

Same shape mismatch issue as in trtllm_batch_decode_with_kv_cache_mla.

The allocated output shape is 4D but the validation for provided output checks against 3D. Apply a consistent fix:

  if out is None:
-     out_shape = query.shape[:-1] + (kv_lora_rank,)
+     out_shape = (*query.shape[:-1], kv_lora_rank)
      out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
  else:
      batch_size, _, num_q_heads, _ = query.shape
      check_shape_dtype_device(
          out,
-         [batch_size, num_q_heads, kv_lora_rank],
+         [batch_size, 1, num_q_heads, kv_lora_rank],
          torch.bfloat16,
          query.device,
          "out",
      )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)
if out is None:
out_shape = (*query.shape[:-1], kv_lora_rank)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, 1, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)
🧰 Tools
🪛 Ruff (0.14.7)

769-769: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)

🤖 Prompt for AI Agents
In flashinfer/mla.py around lines 768 to 779, the code allocates out as a 4D tensor but the provided-output validation expects a 3D tensor, causing a shape mismatch; fix by allocating out with the same 3D shape the validator expects: extract batch_size and num_q_heads from query (batch_size, _, num_q_heads, _) and set out_shape = (batch_size, num_q_heads, kv_lora_rank) when out is None, preserving dtype and device so the allocated tensor matches the check_shape_dtype_device call. 
@yzh119 (Collaborator) commented Dec 3, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !173 has been created, and the CI pipeline #39508973 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #39508973: 4/20 passed

