|
90 | 90 | from vllm.scalar_type import scalar_types |
91 | 91 | from vllm.utils.deep_gemm import ( |
92 | 92 | get_col_major_tma_aligned_tensor, |
| 93 | + get_mk_alignment_for_contiguous_layout, |
93 | 94 | is_deep_gemm_e8m0_used, |
94 | 95 | ) |
| 96 | +from vllm.utils.import_utils import has_deep_gemm |
95 | 97 |
|
96 | 98 | logger = init_logger(__name__) |
97 | 99 |
|
@@ -1088,39 +1090,68 @@ def select_gemm_impl( |
1088 | 1090 |
|
1089 | 1091 | return experts |
1090 | 1092 |
|
1091 | | - # triton path |
1092 | | - from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 |
1093 | | - BatchedTritonOrDeepGemmExperts, |
| 1093 | + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( |
| 1094 | + BatchedDeepGemmExperts, |
| 1095 | + ) |
| 1096 | + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( |
| 1097 | + BatchedTritonExperts, |
1094 | 1098 | ) |
1095 | 1099 | from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( |
1096 | 1100 | TritonOrDeepGemmExperts, |
1097 | 1101 | ) |
1098 | 1102 |
|
1099 | 1103 | assert not self.rocm_aiter_moe_enabled and not self.use_marlin |
1100 | 1104 |
|
| 1105 | + use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM |
| 1106 | + |
1101 | 1107 | if ( |
1102 | 1108 | prepare_finalize.activation_format |
1103 | 1109 | == FusedMoEActivationFormat.BatchedExperts |
1104 | 1110 | ): |
1105 | 1111 | max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() |
1106 | 1112 | assert max_num_tokens_per_rank is not None |
1107 | 1113 |
|
1108 | | - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) |
1109 | | - return BatchedTritonOrDeepGemmExperts( |
1110 | | - max_num_tokens=max_num_tokens_per_rank, |
1111 | | - num_dispatchers=prepare_finalize.num_dispatchers(), |
1112 | | - quant_config=self.moe_quant_config, |
1113 | | - allow_deep_gemm=( |
1114 | | - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM |
1115 | | - ), |
| 1114 | + if use_deep_gemm and not has_deep_gemm(): |
| 1115 | + raise RuntimeError( |
| 1116 | + "DeepGEMM requested for MoE layer but not installed." |
| 1117 | + ) |
| 1118 | + |
| 1119 | + compatible_with_deep_gemm = ( |
| 1120 | + self.moe_quant_config.use_fp8_w8a8 |
| 1121 | + and self.moe_quant_config.block_shape |
| 1122 | + == get_mk_alignment_for_contiguous_layout() |
1116 | 1123 | ) |
| 1124 | + |
| 1125 | + # If DeepGEMM was requested via the env vars but this MoE layer's
| 1126 | + # quant config is not compatible with it, throw an error.
| 1127 | + if use_deep_gemm and not compatible_with_deep_gemm:
| 1128 | + raise RuntimeError( |
| 1129 | + "MoE layer incompatible with DeepGEMM: expected "
| 1130 | + f"use_fp8_w8a8=True (got {self.moe_quant_config.use_fp8_w8a8}) and "
| 1131 | + f"block_shape == {get_mk_alignment_for_contiguous_layout()} "
| 1132 | + f"(got {self.moe_quant_config.block_shape})."
| 1133 | + ) |
| 1134 | + |
| 1135 | + if use_deep_gemm:
| 1136 | + logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) |
| 1137 | + return BatchedDeepGemmExperts( |
| 1138 | + max_num_tokens=max_num_tokens_per_rank, |
| 1139 | + num_dispatchers=prepare_finalize.num_dispatchers(), |
| 1140 | + quant_config=self.moe_quant_config, |
| 1141 | + ) |
| 1142 | + else: |
| 1143 | + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) |
| 1144 | + return BatchedTritonExperts( |
| 1145 | + max_num_tokens=max_num_tokens_per_rank, |
| 1146 | + num_dispatchers=prepare_finalize.num_dispatchers(), |
| 1147 | + quant_config=self.moe_quant_config, |
| 1148 | + ) |
| 1149 | + |
1117 | 1150 | else: |
1118 | 1151 | logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) |
1119 | 1152 | return TritonOrDeepGemmExperts( |
1120 | 1153 | self.moe_quant_config, |
1121 | | - allow_deep_gemm=( |
1122 | | - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM |
1123 | | - ), |
| 1154 | + allow_deep_gemm=use_deep_gemm, |
1124 | 1155 | ) |
1125 | 1156 |
|
1126 | 1157 | def get_fused_moe_quant_config( |
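For readers skimming the hunk, the batched-experts branch boils down to the selection logic sketched below. This is a hypothetical, self-contained condensation rather than the vLLM code itself: the parameters stand in for envs.VLLM_USE_DEEP_GEMM / envs.VLLM_MOE_USE_DEEP_GEMM, has_deep_gemm(), and the fields of self.moe_quant_config, and the returned strings stand in for the BatchedDeepGemmExperts / BatchedTritonExperts instances constructed in the diff.

    def pick_batched_experts(
        use_deep_gemm: bool,        # VLLM_USE_DEEP_GEMM and VLLM_MOE_USE_DEEP_GEMM
        deep_gemm_installed: bool,  # has_deep_gemm()
        use_fp8_w8a8: bool,         # moe_quant_config.use_fp8_w8a8
        block_shape,                # moe_quant_config.block_shape
        mk_alignment,               # get_mk_alignment_for_contiguous_layout()
    ) -> str:
        # DeepGEMM was explicitly requested but the package is missing: hard error.
        if use_deep_gemm and not deep_gemm_installed:
            raise RuntimeError("DeepGEMM requested for MoE layer but not installed.")
        # The diff's compatibility check: fp8 w8a8 with the contiguous-layout block shape.
        compatible = use_fp8_w8a8 and block_shape == mk_alignment
        if use_deep_gemm and not compatible:
            raise RuntimeError("MoE layer incompatible with DeepGEMM.")
        # Otherwise pick DeepGEMM when requested, Triton batched experts if not.
        return "BatchedDeepGemmExperts" if use_deep_gemm else "BatchedTritonExperts"

Net effect: on the batched path a DeepGEMM request that cannot be honored now fails loudly instead of silently falling back to Triton, while the non-batched path keeps the TritonOrDeepGemmExperts wrapper and simply passes allow_deep_gemm=use_deep_gemm.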
|