A unified API for the MNNVL and single-node AllReduce kernels. #2130
base: main
Conversation
CodeRabbit: Review skipped (draft PR detected). This status message can be disabled in the CodeRabbit UI settings.
```python
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    topology: str,
```
Suggested change:
```diff
-    topology: str,
+    topology: Literal["single_node", "multi_node"],
```
```python
    max_token_num: int = None,
    hidden_dim: int = None,
    dtype: torch.dtype = None,
    topology: str = "single_node",
```
I don't think it is needed longer term, since we will use the same PyTorch symmetric-memory API to allocate symmetric memory for single-node and multi-node (under the covers, PyTorch/NCCL/NVSHMEM will detect the platform and decide on the right memory allocation handle).
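For context, here is a minimal sketch of what allocation through PyTorch's experimental symmetric-memory API could look like; the module path `torch.distributed._symmetric_memory` and the `empty`/`rendezvous` calls are assumptions based on recent PyTorch releases, not this PR's code:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # experimental API, subject to change

# Assumes the process group was already initialized (e.g. torchrun + init_process_group("nccl")).
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Allocate a buffer from the symmetric-memory allocator; the platform-appropriate
# handle (NCCL / NVSHMEM / CUDA VMM) is chosen under the covers.
buf = symm_mem.empty(2048 * 4096, dtype=torch.bfloat16, device="cuda")

# Rendezvous exchanges handles so every rank can address its peers' buffers.
# The exact rendezvous signature varies across PyTorch versions.
handle = symm_mem.rendezvous(buf, group=dist.group.WORLD)
```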
```python
    input: torch.Tensor,
    workspace: AllReduceFusionWorkspace,
    pattern: int,
    launch_with_pdl: bool = False,
```
What is the advantage of giving PDL (programmatic dependent launch) control to the user?
```python
    Args:
        input: Input tensor [token_num, hidden_dim]
        workspace: Workspace object (type determines backend)
        pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5)
```
Are they all 2-kernel overlap, or are some of them real fusion kernels?
With one-shot MNNVL it's real fusion, and I think the same holds for the trtllm_ar kernels. It's just two-shot MNNVL that is the 2-kernel overlap.
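For illustration, a hedged sketch of how the pattern argument would be passed at the call site; the entry-point name, import path, and the specific pattern constant are assumptions, and only the parameter names come from the diff above:

```python
import torch
# Hypothetical import path; the PR defines the actual module layout.
from flashinfer.comm import allreduce_fusion, AllReduceFusionPattern

# `x` is this rank's local result, `ws` is the workspace created earlier with
# create_allreduce_fusion_workspace(...).
out = allreduce_fusion(
    input=x,                                   # [token_num, hidden_dim]
    workspace=ws,                              # workspace type determines the backend
    pattern=AllReduceFusionPattern.kARResidualRMSNorm,  # assumed constant; the docstring says 0-5 are valid
    launch_with_pdl=False,                     # programmatic dependent launch, off by default
)
```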
```python
    },
    heuristic_func=_workspace_creation_heuristic,
)
def create_allreduce_fusion_workspace(
```
Could create_allreduce_fusion_workspace take an optional workspace argument? If the existing workspace is big enough (or too big), this would be a no-op (maybe just updating the backend selection). If it is too small, destroy the current workspace and allocate a bigger one.
When we switch to a mem pool, we should be able to call create_allreduce_fusion_workspace at each forward pass and memory will just get reused from the mempool (instead of new allocations).
CC @Amir-19
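A rough sketch of the reuse semantics proposed above; everything except create_allreduce_fusion_workspace (in particular the attribute names and the destroy() hook) is hypothetical and only meant to pin down the intended behavior:

```python
def get_or_grow_workspace(existing, max_token_num, hidden_dim, dtype, topology):
    """Return a workspace able to serve the requested shape, reusing `existing` when possible."""
    required_elems = max_token_num * hidden_dim
    if existing is not None and existing.max_token_num * existing.hidden_dim >= required_elems:
        # Big enough (or too big): no-op, apart from possibly re-running backend selection.
        return existing
    if existing is not None:
        existing.destroy()  # hypothetical teardown; frees the old symmetric buffers
    # Too small or absent: allocate a fresh, larger workspace.
    return create_allreduce_fusion_workspace(
        max_token_num=max_token_num,
        hidden_dim=hidden_dim,
        dtype=dtype,
        topology=topology,
    )
```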
```python
    - Workspace(max_token_num=2048, hidden_dim=4096) can handle:
      - (token_num=2048, hidden_dim=4096) ✓
      - (token_num=1024, hidden_dim=4096) ✓
      - (token_num=4096, hidden_dim=2048) ✓ (same total size)
```
I only see the framework adjusting the number of tokens, but hidden_dim should be fixed per model.
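The compatibility rule above reduces to a total-element check; a minimal sketch, assuming the workspace exposes max_token_num and hidden_dim attributes:

```python
def workspace_can_handle(workspace, token_num: int, hidden_dim: int) -> bool:
    # The buffer holds max_token_num * hidden_dim elements, so any request with the
    # same or smaller total element count fits, regardless of how the elements are
    # split between the token and hidden dimensions.
    return token_num * hidden_dim <= workspace.max_token_num * workspace.hidden_dim
```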
```python
    ...     max_token_num=2048,
    ...     hidden_dim=4096,
    ...     dtype=torch.bfloat16,
    ...     topology="single_node"
```
Could we add a check now to detect the topology, before we switch to the mempool allocation?
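One way such a check could look, assuming torch.distributed is already initialized; comparing the world size to the local GPU count is an assumed heuristic and would not distinguish an MNNVL domain from an ordinary multi-node job:

```python
import torch
import torch.distributed as dist

def detect_topology() -> str:
    """Best-effort guess: 'single_node' if every rank fits on this node's GPUs, else 'multi_node'."""
    world_size = dist.get_world_size()
    local_gpu_count = torch.cuda.device_count()
    return "single_node" if world_size <= local_gpu_count else "multi_node"
```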
📌 Description
A unified API for the MNNVL and single-node AllReduce kernels.
The backend will be chosen during workspace creation. We can either pick it explicitly, or use the "auto" backend to have a heuristic pick the best backend.
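A hedged usage sketch of the two modes described above; the import path and the backend keyword spellings are assumptions based on this description, and only create_allreduce_fusion_workspace and the shape/dtype/topology parameters appear in the diff:

```python
import torch
from flashinfer.comm import create_allreduce_fusion_workspace  # hypothetical import path

# Pick the backend explicitly.
ws_explicit = create_allreduce_fusion_workspace(
    max_token_num=2048,
    hidden_dim=4096,
    dtype=torch.bfloat16,
    topology="single_node",
    backend="trtllm",   # assumed spelling of an explicit backend choice
)

# Let the heuristic pick the best backend for the topology, sizes, and dtype.
ws_auto = create_allreduce_fusion_workspace(
    max_token_num=2048,
    hidden_dim=4096,
    dtype=torch.bfloat16,
    topology="single_node",
    backend="auto",
)
```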
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and they pass locally (unittest, etc.).

Reviewer Notes