Skip to content

Commit 811563e

Browse files
committed
reorder config
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent e8dbdc9 commit 811563e

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ def get_moe_method(
205205
return CompressedTensorsW8A8Int8MoEMethod(
206206
weight_quant, input_quant, layer.moe_config
207207
)
208-
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
209-
return CompressedTensorsW4A8Int8MoEMethod(
210-
weight_quant, input_quant, layer.moe_config
211-
)
212208
elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant):
213209
logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod")
214210
return CompressedTensorsW4A8Fp8MoEMethod(
215-
quant_config, layer.moe_config, layer_name
211+
weight_quant, input_quant, layer.moe_config
212+
)
213+
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
214+
return CompressedTensorsW4A8Int8MoEMethod(
215+
weight_quant, input_quant, layer.moe_config
216216
)
217217
else:
218218
raise RuntimeError(
@@ -2436,16 +2436,15 @@ def _act_kind(s: str) -> int:
24362436
class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
24372437
def __init__(
24382438
self,
2439-
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
2439+
weight_quant: QuantizationArgs,
2440+
input_quant: QuantizationArgs,
24402441
moe: FusedMoEConfig,
24412442
layer_name: str | None = None,
24422443
):
24432444
super().__init__(moe)
2444-
self.quant_config = quant_config
2445-
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
2446-
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
2447-
"input_activations"
2448-
)
2445+
self.weight_quant = weight_quant
2446+
self.input_quant = input_quant
2447+
24492448
self.group_size = self.weight_quant.group_size
24502449
self.num_bits = self.weight_quant.num_bits
24512450
self.packed_factor = 32 // self.num_bits

0 commit comments

Comments
 (0)