
Commit 2902c34

bnellnm and tlrmchlsmth authored
[Kernels] Remove BatchedTritonOrDeepGemmExperts and default fallback to Triton (#29929)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent ac18865 commit 2902c34

5 files changed: +46 -217 lines


docs/design/moe_kernel_features.md

Lines changed: 1 addition & 2 deletions
@@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
 | cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
 | flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
 | gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
-| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
 | marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
 | trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
@@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
 | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
 |---------|-----------------------------------------|----------------------------------------------|
 | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
-| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
+| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
 | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
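
For reference, the batched modular-kernel family documented in the table above now looks like this. The snippet is a minimal illustrative sketch in plain Python, not a vLLM data structure; only the class names come from the table.

# Illustrative only: restates the updated "deepep_low_latency / pplx" row.
BATCHED_MODULAR_KERNEL_FAMILY = {
    "prepare_finalize": [
        "DeepEPLLPrepareAndFinalize",
        "PplxPrepareAndFinalize",
    ],
    "experts": [
        "BatchedDeepGemmExperts",
        "BatchedTritonExperts",
        # BatchedTritonOrDeepGemmExperts is removed by this commit.
        "CutlassBatchedExpertsFp8",
        "BatchedMarlinExperts",
    ],
}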

tests/kernels/moe/modular_kernel_tools/mk_objects.py

Lines changed: 0 additions & 17 deletions
@@ -13,9 +13,6 @@
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
-    BatchedTritonOrDeepGemmExperts,
-)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
@@ -286,16 +283,6 @@ def expert_info(kind) -> ExpertInfo:
     needs_matching_quant=False,
     needs_deep_gemm=True,
 )
-register_experts(
-    BatchedTritonOrDeepGemmExperts,
-    batched_format,
-    common_float_and_int_types,
-    blocked_quantization_support=True,
-    supports_chunking=False,
-    supports_expert_map=False,
-    needs_matching_quant=True,
-    needs_deep_gemm=True,
-)
 register_experts(
     TritonOrDeepGemmExperts,
     standard_format,
@@ -457,10 +444,6 @@ def make_fused_experts(
         kwargs = batch_kwargs | quant_kwargs
         print(f"Making BatchedTritonExperts {kwargs} ...")
         experts = BatchedTritonExperts(**kwargs)
-    elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
-        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
-        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
-        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
     elif fused_experts_type == DeepGemmExperts:
         print(f"Making DeepGemmExperts {quant_config} ...")
         experts = DeepGemmExperts(quant_config)

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -60,9 +60,6 @@ def get_config() -> dict[str, Any] | None:
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts,
     )
-    from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-        BatchedTritonOrDeepGemmExperts,
-    )
     from vllm.model_executor.layers.fused_moe.cutlass_moe import (
         CutlassBatchedExpertsFp8,
         CutlassExpertsFp8,
@@ -98,7 +95,6 @@ def get_config() -> dict[str, Any] | None:
         "DeepGemmExperts",
         "BatchedDeepGemmExperts",
         "TritonOrDeepGemmExperts",
-        "BatchedTritonOrDeepGemmExperts",
     ]
 else:
     # Some model classes directly use the custom ops. Add placeholders
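
With the wrapper dropped from the package exports above, downstream code that previously imported `BatchedTritonOrDeepGemmExperts` has to pick one of the remaining batched classes explicitly. A hedged sketch of the replacement imports, using the module paths that appear in the `compressed_tensors_moe.py` diff below:

# BatchedTritonOrDeepGemmExperts no longer exists after this commit; import the
# concrete implementations instead (module paths taken from this commit's diff).
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)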

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 0 additions & 180 deletions
This file was deleted.

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 45 additions & 14 deletions
@@ -90,8 +90,10 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     get_col_major_tma_aligned_tensor,
+    get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
+from vllm.utils.import_utils import has_deep_gemm
 
 logger = init_logger(__name__)
 
@@ -1088,39 +1090,68 @@ def select_gemm_impl(
 
            return experts
 
-        # triton path
-        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-            BatchedTritonOrDeepGemmExperts,
+        from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+            BatchedDeepGemmExperts,
+        )
+        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+            BatchedTritonExperts,
         )
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts,
         )
 
         assert not self.rocm_aiter_moe_enabled and not self.use_marlin
 
+        use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+
         if (
             prepare_finalize.activation_format
             == FusedMoEActivationFormat.BatchedExperts
         ):
             max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
             assert max_num_tokens_per_rank is not None
 
-            logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
-            return BatchedTritonOrDeepGemmExperts(
-                max_num_tokens=max_num_tokens_per_rank,
-                num_dispatchers=prepare_finalize.num_dispatchers(),
-                quant_config=self.moe_quant_config,
-                allow_deep_gemm=(
-                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
-                ),
+            if use_deep_gemm and not has_deep_gemm():
+                raise RuntimeError(
+                    "DeepGEMM requested for MoE layer but not installed."
+                )
+
+            compatible_with_deep_gemm = (
+                self.moe_quant_config.use_fp8_w8a8
+                and self.moe_quant_config.block_shape
+                == get_mk_alignment_for_contiguous_layout()
             )
+
+            # If this MoE layer is compatible with DeepGEMM, the proper env
+            # vars are set and DeepGEMM is not installed, throw an error.
+            if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
+                raise RuntimeError(
+                    f"MoE layer incompatible with DeepGEMM, expected "
+                    f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
+                    f"or block_shape {self.moe_quant_config.block_shape}"
+                    f"=={get_mk_alignment_for_contiguous_layout()}."
+                )
+
+            if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
+                logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
+                return BatchedDeepGemmExperts(
+                    max_num_tokens=max_num_tokens_per_rank,
+                    num_dispatchers=prepare_finalize.num_dispatchers(),
+                    quant_config=self.moe_quant_config,
+                )
+            else:
+                logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
+                return BatchedTritonExperts(
+                    max_num_tokens=max_num_tokens_per_rank,
+                    num_dispatchers=prepare_finalize.num_dispatchers(),
+                    quant_config=self.moe_quant_config,
+                )
+
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
             return TritonOrDeepGemmExperts(
                 self.moe_quant_config,
-                allow_deep_gemm=(
-                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
-                ),
+                allow_deep_gemm=use_deep_gemm,
             )
 
     def get_fused_moe_quant_config(
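
Condensed, the new batched branch of `select_gemm_impl` behaves as sketched below. The helper name and the flat boolean arguments are invented for illustration; in vLLM the decision is made inline from `envs.VLLM_USE_DEEP_GEMM`, `envs.VLLM_MOE_USE_DEEP_GEMM`, `has_deep_gemm()`, and the layer's quant config, as shown in the diff above.

# Hypothetical helper summarizing the selection logic added in this commit.
def pick_batched_experts(
    use_deep_gemm: bool,        # both DeepGEMM env vars enabled
    deep_gemm_installed: bool,  # result of has_deep_gemm()
    is_fp8_w8a8: bool,          # moe_quant_config.use_fp8_w8a8
    block_shape_aligned: bool,  # block_shape == get_mk_alignment_for_contiguous_layout()
) -> str:
    if use_deep_gemm and not deep_gemm_installed:
        raise RuntimeError("DeepGEMM requested for MoE layer but not installed.")
    compatible_with_deep_gemm = is_fp8_w8a8 and block_shape_aligned
    if use_deep_gemm and compatible_with_deep_gemm and deep_gemm_installed:
        return "BatchedDeepGemmExperts"
    # Default fallback introduced by this commit: plain batched Triton experts.
    return "BatchedTritonExperts"

In other words, when DeepGEMM is disabled, not installed, or the layer's quantization is not block-aligned fp8, the layer now falls back directly to `BatchedTritonExperts` instead of deferring the choice to the removed wrapper class.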
