@@ -205,14 +205,14 @@ def get_moe_method(
205205 return CompressedTensorsW8A8Int8MoEMethod (
206206 weight_quant , input_quant , layer .moe_config
207207 )
208- elif quant_config ._is_dynamic_token_w4a8_int (weight_quant , input_quant ):
209- return CompressedTensorsW4A8Int8MoEMethod (
210- weight_quant , input_quant , layer .moe_config
211- )
212208 elif quant_config ._is_fp8_w4a8_sm90 (weight_quant , input_quant ):
213209 logger .info_once ("Using CompressedTensorsW4A8Fp8MoEMethod" )
214210 return CompressedTensorsW4A8Fp8MoEMethod (
215- quant_config , layer .moe_config , layer_name
211+ weight_quant , input_quant , layer .moe_config
212+ )
213+ elif quant_config ._is_dynamic_token_w4a8_int (weight_quant , input_quant ):
214+ return CompressedTensorsW4A8Int8MoEMethod (
215+ weight_quant , input_quant , layer .moe_config
216216 )
217217 else :
218218 raise RuntimeError (
@@ -2436,16 +2436,15 @@ def _act_kind(s: str) -> int:
24362436class CompressedTensorsW4A8Fp8MoEMethod (CompressedTensorsMoEMethod ):
24372437 def __init__ (
24382438 self ,
2439- quant_config : "CompressedTensorsConfig" , # type: ignore # noqa E501
2439+ weight_quant : QuantizationArgs ,
2440+ input_quant : QuantizationArgs ,
24402441 moe : FusedMoEConfig ,
24412442 layer_name : str | None = None ,
24422443 ):
24432444 super ().__init__ (moe )
2444- self .quant_config = quant_config
2445- self .weight_quant = self .quant_config .target_scheme_map ["Linear" ].get ("weights" )
2446- self .input_quant = self .quant_config .target_scheme_map ["Linear" ].get (
2447- "input_activations"
2448- )
2445+ self .weight_quant = weight_quant
2446+ self .input_quant = input_quant
2447+
24492448 self .group_size = self .weight_quant .group_size
24502449 self .num_bits = self .weight_quant .num_bits
24512450 self .packed_factor = 32 // self .num_bits
0 commit comments