
Conversation

@nvmbreughe (Contributor) commented Nov 21, 2025

📌 Description

A unified API for the MNNVL and single-node AllReduce kernels.

The backend is chosen during workspace creation. It can either be picked explicitly, or the "auto" backend can be used to let a heuristic select the best backend.
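To make the design concrete, here is an illustrative plain-Python sketch (no FlashInfer dependency) of the dispatch scheme described above: the backend is bound at workspace-creation time, and fused ops later dispatch on the workspace's type. The class names `TrtllmWorkspace`/`MnnvlWorkspace` and the heuristic are hypothetical stand-ins, not the PR's actual implementation.

```python
from dataclasses import dataclass


@dataclass
class TrtllmWorkspace:
    # Hypothetical single-node backend workspace.
    max_token_num: int
    hidden_dim: int


@dataclass
class MnnvlWorkspace:
    # Hypothetical multi-node NVLink (MNNVL) backend workspace.
    max_token_num: int
    hidden_dim: int


def create_allreduce_fusion_workspace(backend, max_token_num, hidden_dim,
                                      topology="single_node"):
    """Bind the backend at creation time, as described in the PR.

    backend: "trtllm", "mnnvl", or "auto" (heuristic picks for you).
    """
    if backend == "auto":
        # Toy heuristic: multi-node topologies get the MNNVL kernels.
        backend = "mnnvl" if topology == "multi_node" else "trtllm"
    cls = {"trtllm": TrtllmWorkspace, "mnnvl": MnnvlWorkspace}[backend]
    return cls(max_token_num, hidden_dim)
```

The fused allreduce entry point then only needs the workspace argument; its concrete type tells the library which kernel family to launch.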

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

coderabbitai bot commented Nov 21, 2025

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


max_token_num: int,
hidden_dim: int,
dtype: torch.dtype,
topology: str,
Collaborator:
Suggested change:
-    topology: str,
+    topology: Literal["single_node", "multi_node"],
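The point of the suggested `Literal` annotation is that static checkers (mypy/pyright) can reject invalid topology strings at call sites. A minimal sketch of the pattern (the runtime `ValueError` guard here is an extra illustration, not part of the suggestion):

```python
from typing import Literal

Topology = Literal["single_node", "multi_node"]


def create_workspace(topology: Topology = "single_node") -> str:
    # At runtime any string is still accepted; the Literal only
    # constrains what type checkers allow at call sites, so a
    # runtime check is still useful for untyped callers.
    if topology not in ("single_node", "multi_node"):
        raise ValueError(f"unknown topology: {topology!r}")
    return topology
```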
max_token_num: int = None,
hidden_dim: int = None,
dtype: torch.dtype = None,
topology: str = "single_node",
I don't think it is needed longer term, since we will use the same PyTorch symmetric-memory API to allocate symmetric memory for single- and multi-node (under the covers, PyTorch/NCCL/NVSHMEM will detect the platform and decide on the right memory allocation handle).

input: torch.Tensor,
workspace: AllReduceFusionWorkspace,
pattern: int,
launch_with_pdl: bool = False,

Why is it an advantage to give PDL control to the user?

Args:
input: Input tensor [token_num, hidden_dim]
workspace: Workspace object (type determines backend)
pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5)

Are they all 2-kernel overlap, or are some of them real fusion kernels?

Contributor Author:

With one-shot MNNVL it's real fusion, and I think it's similar for the trtllm_ar kernels. It's just two-shot MNNVL that is the 2-kernel overlap.
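Whichever kernel strategy is used (one-shot fused or two-shot overlapped), the result is the same math. As an illustration of the semantics of an allreduce + residual + RMSNorm fusion pattern, here is a plain NumPy reference (not the PR's kernels; just the expected output they should agree with):

```python
import numpy as np


def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm over the hidden (last) dimension.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def allreduce_rmsnorm_reference(shards, residual, weight):
    """Reference semantics: sum per-rank shards (the allreduce),
    add the residual, then RMS-normalize. Returns (normed, hidden)
    so the pre-norm residual stream is also available, as fused
    kernels typically expose both."""
    reduced = np.sum(shards, axis=0)   # allreduce (sum over ranks)
    hidden = reduced + residual        # residual add
    return rmsnorm(hidden, weight), hidden
```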

},
heuristic_func=_workspace_creation_heuristic,
)
def create_allreduce_fusion_workspace(

Could create_allreduce_fusion_workspace take an optional workspace argument? If the workspace is big enough (or too big), this is a no-op (maybe just updating the backend selection). If it is too small, destroy the current workspace and allocate a bigger one.

When we switch to a memory pool, we should be able to call create_allreduce_fusion_workspace at each forward pass and memory will just get reused from the mempool (instead of new allocations).
CC @Amir-19

- Workspace(max_token_num=2048, hidden_dim=4096) can handle:
- (token_num=2048, hidden_dim=4096) ✓
- (token_num=1024, hidden_dim=4096) ✓
- (token_num=4096, hidden_dim=2048) ✓ (same total size)

I only see the framework adjusting the number of tokens; hidden_dim should be fixed per model.
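The docstring above states the compatibility rule in terms of total size rather than per-dimension bounds. A one-line sketch of that rule (hypothetical helper name), which also shows why the hidden_dim flexibility is harmless even if, per the comment, frameworks only vary token_num in practice:

```python
def workspace_can_handle(max_token_num, max_hidden_dim,
                         token_num, hidden_dim):
    """Compatibility rule from the docstring: a workspace sized for
    max_token_num x max_hidden_dim elements can serve any request
    with the same or smaller total element count."""
    return token_num * hidden_dim <= max_token_num * max_hidden_dim
```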

... max_token_num=2048,
... hidden_dim=4096,
... dtype=torch.bfloat16,
... topology="single_node"

Could we add a check now to detect the topology, before we switch to the mempool allocation?
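One possible shape for such a check, sketched from standard torchrun environment variables (WORLD_SIZE and LOCAL_WORLD_SIZE). This is only a heuristic illustration; a production check would more likely query the NVLink fabric / MNNVL capability directly:

```python
import os


def detect_topology(env=os.environ):
    """Guess "single_node" vs "multi_node" from torchrun-style env
    vars: WORLD_SIZE is the global rank count, LOCAL_WORLD_SIZE the
    per-node count. If the global count exceeds the per-node count,
    ranks must span multiple nodes."""
    world = int(env.get("WORLD_SIZE", "1"))
    local = int(env.get("LOCAL_WORLD_SIZE", str(world)))
    return "multi_node" if world > local else "single_node"
```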

