[Kernel]Support W4A8 Grouped GEMM on Hopper #29691
Conversation
Code Review
This pull request adds support for W4A8 grouped GEMM on Hopper GPUs, a significant feature for running quantized Mixture-of-Experts models efficiently. The changes span C++ CUDA kernels, Python bindings, and integration into the model execution layers. The implementation looks solid and comes with new tests for the functionality. I've identified a couple of critical issues related to data-type checks that could cause runtime failures for supported configurations; addressing these will improve the robustness of the new kernel.
dsikka left a comment
The CT integration looks clean to me! Do we have a test model we can add?
@dsikka thanks! We can use https://huggingface.co/czhu-cohere/Qwen3-30B-A3B-quantized.w4a8
LucasWilkinson left a comment
LGTM! Thanks; amazing work!
```cpp
    cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
using MmaType = cutlass::float_e4m3_t;
using QuantType = cutlass::int4b_t;
```
any chance this could be extended to mxfp4 too? would be nice if we could make this compatible with gpt-oss (could be done in a future PR)
mxfp4 is an e8m0 scaling factor for every 32 elements? I think there is a group size limitation of 128 here though because the activation is 8 bits
Purpose
As the title says; the benefit of W4A8 is that it can use FP8 tensor cores while still maintaining the low memory footprint of W4A16 (with negligible quality loss). In addition, there is no Machete-like implementation in vLLM for W4A16 grouped GEMM, so the compute gains should be even larger compared to the current Marlin kernels.
The CUTLASS kernel implementation follows example 69, which uses a LUT-based method for fast INT4 -> FP8 conversion. As with dense W4A8, we also add a per-channel/per-token epilogue.
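As a rough illustration of the two ideas above, here is a minimal PyTorch reference sketch (not the CUTLASS kernel): a 16-entry LUT that maps signed int4 nibbles to their FP8-representable values, followed by a per-channel/per-token scaling epilogue. All shapes and names are illustrative.

```python
import torch

# Nibble index 0..7 -> values 0..7, index 8..15 -> values -8..-1; every entry is
# exactly representable in e4m3, so the real kernel can store raw e4m3 bytes.
INT4_LUT = torch.tensor([v if v < 8 else v - 16 for v in range(16)], dtype=torch.float32)

def dequant_int4(nibbles: torch.Tensor) -> torch.Tensor:
    """nibbles: integer tensor with values in [0, 15] (unpacked int4 weights)."""
    return INT4_LUT[nibbles.long()]

def w4a8_gemm_reference(a_fp8, w_nibbles, a_scale, w_scale):
    """a_fp8: [M, K] fp8 activations, w_nibbles: [K, N] unpacked int4 weights,
    a_scale: [M] per-token scales, w_scale: [N] per-channel scales."""
    acc = a_fp8.float() @ dequant_int4(w_nibbles)      # fp32 accumulation for reference
    return acc * a_scale[:, None] * w_scale[None, :]   # per-token x per-channel epilogue

a = torch.randn(8, 128).to(torch.float8_e4m3fn)        # simulated fp8 activations
w = torch.randint(0, 16, (128, 64))                    # simulated unpacked int4 weights
out = w4a8_gemm_reference(a, w, torch.rand(8), torch.rand(64))
```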
We have uploaded a W4A8 quantized variant of Qwen3-30B-A3B as an e2e sanity check.
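For reference, a minimal way to smoke-test the uploaded checkpoint end to end with vLLM's offline API; the quantization scheme is picked up from the checkpoint's compressed-tensors config, and the prompt/sampling settings below are only illustrative.

```python
from vllm import LLM, SamplingParams

# Hedged example: load the W4A8 checkpoint and generate a short completion.
llm = LLM(model="czhu-cohere/Qwen3-30B-A3B-quantized.w4a8")
outputs = llm.generate(["What is 2 + 2?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```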
C++ changes
- `csrc/quantization/cutlass_w4a8/w4a8_utils.cu`
- `csrc/quantization/cutlass_w4a8/get_group_starts.cuh`
- `csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu`: in the weight pre-processing op (`encode_and_reorder_int4b`) we construct the layout object and serialize it to a torch tensor so that we can pass it into the grouped GEMM at runtime. This avoids having to reconstruct the layout at runtime, which would incur significant overhead when the number of experts is large. A `static_assert` and `layout_width` should guarantee that the layout can be serialized to the expected torch tensor dtype/size. (See the sketch after this list.)
- `csrc/quantization/w8a8/cutlass/moe/moe_data.cu`: `get_cutlass_moe_mm_problem_sizes` is coupled with `SwapAB`, so I added an argument that lets the caller explicitly specify whether `SwapAB` is true or false (for the RS GEMM it is always true, since the operand to be dequantized, B, needs to be on the LHS).
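A minimal sketch of the load-time flow described above. The op name `encode_and_reorder_int4b` comes from this PR, but the `torch.ops` binding path, call signature, and return values shown here are assumptions for illustration only.

```python
import torch

def preprocess_expert_weights(packed_int4_weights: list[torch.Tensor]):
    """Run once per expert at weight-loading time, not on every forward pass."""
    reordered, layouts = [], []
    for w in packed_int4_weights:
        # Hypothetical binding: encode/reorder the packed int4 weight and get the
        # CUTLASS layout back serialized as a plain torch tensor, so the grouped
        # GEMM can consume it at runtime without rebuilding the layout object.
        w_reordered, layout_tensor = torch.ops._C.encode_and_reorder_int4b(w)
        reordered.append(w_reordered)
        layouts.append(layout_tensor)
    return reordered, layouts
```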
Python changes

- `vllm/model_executor/layers/fused_moe/config.py`
- `vllm/model_executor/layers/fused_moe/modular_kernel.py`: `w1_scale` is used for the group-wise scales.
- `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py`:
  - `FusedMoeWeightScaleSupported.GROUP.value` and `FusedMoeWeightScaleSupported.CHANNEL.value` are used to load the group and channel scales respectively.
  - The weights are reordered for cutlass with the `reorder_tensor` op explained above (in practice that means for small MoEs like Qwen 30B you may not be able to do TP2).
  - `s_strides1/2`, which store the strides for the group scales, have shape `[num_experts, 2]` and dtype `int64`, since that is what the kernel expects (see the sketch after this list).
  - `b_strides1/2` are returned by the reordering op and saved to pass in at runtime.
- `vllm/model_executor/layers/fused_moe/cutlass_moe.py`: `SwapAB` is always true.
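A small sketch of the stride bookkeeping mentioned above. The container shape/dtype (`[num_experts, 2]`, `int64`) follows the PR text; which two strides are stored per expert is an assumption made only for illustration.

```python
import torch

num_experts, n, k, group_size = 4, 512, 1024, 128

# Illustrative per-expert group-wise weight scales (one scale per 128-wide group).
w1_group_scales = torch.randn(num_experts, k // group_size, n)

# One int64 stride pair per expert, matching the layout the kernel expects.
s_strides1 = torch.empty((num_experts, 2), dtype=torch.int64)
for e in range(num_experts):
    scale = w1_group_scales[e]
    s_strides1[e, 0] = scale.stride(0)  # assumed: stride along the group dimension
    s_strides1[e, 1] = scale.stride(1)  # assumed: stride along the output-channel dim
```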
Limitations

Have not implemented/checked compatibility with EP options other than the default.
Test Plan
- kernel correctness test: `tests/kernels/quantization/test_cutlass_w4a8_moe.py`
- e2e eval: lm_eval gsm8k, comparing the Qwen3-30B-A3B W4A16 and W4A8 variants (a sketch of this eval follows the list)
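A sketch of how the gsm8k comparison could be run with lm_eval's Python API and the vLLM backend; the exact `model_args` and batching settings here are illustrative assumptions.

```python
import lm_eval

# Hedged example: evaluate the W4A8 variant on gsm8k via the vLLM backend;
# repeat with the W4A16 checkpoint to compare.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=czhu-cohere/Qwen3-30B-A3B-quantized.w4a8",
    tasks=["gsm8k"],
    batch_size="auto",
)
print(results["results"]["gsm8k"])
```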
Test Result
- `tests/kernels/quantization/test_cutlass_w4a8_moe.py`: pass
- lm_eval
- 8k prefill for Qwen3-30B-A3B comparing W4A8 and W4A16

Note that the expert sizes for Qwen3-30B-A3B are quite small and it seems hard to approach peak FP8 TFLOPs with these shapes under the current schedule; larger experts can get higher FLOPs. We leave investigation of this to future work.