
Commit ba25cff

undo debug stuff
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
Parent: 5b793e3

5 files changed: 8 additions & 72 deletions

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ def get_config() -> dict[str, Any] | None:
     "cutlass_moe_fp4",
     "CutlassExpertsFp8",
     "CutlassBatchedExpertsFp8",
+    "CutlassExpertsW4A8Fp8",
     "TritonExperts",
     "BatchedTritonExperts",
     "DeepGemmExperts",

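The hunk above registers the new CutlassExpertsW4A8Fp8 class in the package export list (the surrounding entries suggest this is the module's __all__, though the hunk shows only list items). Assuming that, a minimal usage sketch:

    # Hypothetical usage, assuming the list above is the package's __all__;
    # CutlassExpertsW4A8Fp8 is the class exported on this commit's branch.
    from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
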
vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 1 addition & 47 deletions
@@ -27,19 +27,6 @@
 logger = init_logger(__name__)
 
 
-# print utilities
-def print_args_info(*args, **kwargs):
-    print("=== positional args ===")
-    for i, a in enumerate(args):
-        print(f"\narg[{i}]:")
-        # print tensor info only
-        if isinstance(a, torch.Tensor):
-            print(f" shape : {tuple(a.shape)}")
-            print(f" stride: {tuple(a.stride())}")
-            print(f" dtype : {a.dtype}")
-            print(f" device: {a.device}")
-
-
 def run_cutlass_moe_fp8(
     output: torch.Tensor,
     hidden_states: torch.Tensor,

@@ -207,24 +194,6 @@ def run_cutlass_moe_fp8(
         # this rank handles only partial tokens, or when it is batched.
         mm1_out.fill_(0)
 
-    # print(f'Printing information for first moe call')
-    # print_args_info(
-    #     mm1_out,
-    #     a1q,
-    #     w1,
-    #     a1q_scale,
-    #     w1_scale,
-    #     expert_offsets,
-    #     problem_sizes1,
-    #     ab_strides1,
-    #     ab_strides1,
-    #     c_strides1,
-    #     per_act_token,
-    #     per_out_ch,
-    # )
-    # print problem shapes and stuff
-    # print(f'{problem_sizes1=}')
-    # print(f'{expert_offsets=}')
     ops.cutlass_moe_mm(
         mm1_out,
         a1q,

@@ -248,22 +217,7 @@ def run_cutlass_moe_fp8(
 
     if expert_map is not None:
         mm2_out.fill_(0)
-    # print('=========================')
-    # print(f'Printing information for second moe call...')
-    # print_args_info(
-    #     mm2_out,
-    #     a2q,
-    #     w2,
-    #     a2q_scale,
-    #     w2_scale,
-    #     expert_offsets,
-    #     problem_sizes2,
-    #     ab_strides2,
-    #     ab_strides2,
-    #     c_strides2,
-    #     per_act_token,
-    #     per_out_ch,
-    # )
+
     ops.cutlass_moe_mm(
         mm2_out,
         a2q,
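
This file (and fused_marlin_moe.py below) drops the ad-hoc print_args_info debug helper. If similar tensor introspection is wanted later, a minimal sketch of a logger-gated variant, assuming the module-level logger = init_logger(__name__) shown in the first hunk:

    # Sketch only: logs the same tensor metadata as the removed helper, but
    # routed through the module logger so output is controlled by log level
    # rather than by editing hot paths.
    import torch

    def log_args_info(logger, *args) -> None:
        for i, a in enumerate(args):
            if isinstance(a, torch.Tensor):
                logger.debug(
                    "arg[%d]: shape=%s stride=%s dtype=%s device=%s",
                    i, tuple(a.shape), tuple(a.stride()), a.dtype, a.device,
                )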

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 1 addition & 20 deletions
@@ -29,19 +29,6 @@
 from vllm.scalar_type import ScalarType, scalar_types
 
 
-# print utilities
-def print_args_info(*args, **kwargs):
-    print("=== positional args ===")
-    for i, a in enumerate(args):
-        print(f"\narg[{i}]:")
-        # print tensor info only
-        if isinstance(a, torch.Tensor):
-            print(f" shape : {tuple(a.shape)}")
-            print(f" stride: {tuple(a.stride())}")
-            print(f" dtype : {a.dtype}")
-            print(f" device: {a.device}")
-
-
 def default_activation_func(
     activation: str, output: torch.Tensor, input: torch.Tensor
 ) -> None:

@@ -124,13 +111,7 @@ def _fused_marlin_moe(
         hidden_states.dtype == torch.half
         or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
     )
-    # print('printing stuff before moe 1')
-    # print_args_info(
-    #     hidden_states,
-    #     w1,
-    #     w1_scale,
-    #     intermediate_cache1
-    # )
+
     intermediate_cache1 = ops.moe_wna16_marlin_gemm(
         hidden_states,
         intermediate_cache1,

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 4 deletions
@@ -2396,8 +2396,6 @@ def create_weights(
         layer.register_parameter("w2_weight_packed", w2_weight_packed)
         set_weight_attrs(w2_weight_packed, extra_weight_attrs)
 
-        # TODO(czhu): fix TP > 1 case, probably this and other stuff
-        # needs change
         # weight_scale refers to the group-wise scales
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(

@@ -2528,8 +2526,11 @@ def process_weights_after_loading(self, layer):
         )
         replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled)
 
-    def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
-        return super().maybe_make_prepare_finalize()
+    def maybe_make_prepare_finalize(
+        self,
+        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    ) -> mk.FusedMoEPrepareAndFinalize | None:
+        return super().maybe_make_prepare_finalize(routing_tables)
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
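
The second hunk keeps the override's signature in sync with a base-class change that threads a new optional routing_tables argument through maybe_make_prepare_finalize. A minimal sketch of the pattern with stand-in class names (the real base class and return type live in vLLM's modular-kernel code, abbreviated mk above):

    import torch

    class BaseMoEMethod:
        def maybe_make_prepare_finalize(
            self,
            routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            | None = None,
        ) -> object | None:
            return None  # base default: no prepare/finalize object

    class W4A8MoEMethod(BaseMoEMethod):
        def maybe_make_prepare_finalize(
            self,
            routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            | None = None,
        ) -> object | None:
            # Mirror the base signature and forward the new argument so
            # callers that pass routing_tables keep working.
            return super().maybe_make_prepare_finalize(routing_tables)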

vllm/model_executor/models/registry.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
     "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
-    "Cohere2MoeForCausalLM": ("commandr", "Cohere2MoeForCausalLM"),
     "CwmForCausalLM": ("llama", "LlamaForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),

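The registry above maps an architecture string to a (module, class-name) pair that is resolved lazily at load time; this commit drops the Cohere2MoeForCausalLM entry. A minimal sketch of that lookup pattern, with a hypothetical resolve helper (vLLM's actual resolution in registry.py is more involved):

    import importlib

    # Hypothetical stand-in for the registry dict shown in the diff.
    _MODELS = {
        "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
        "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
    }

    def resolve_model_cls(architecture: str) -> type:
        # Import the named module under the models package and fetch the
        # class named in the registry entry.
        module_name, cls_name = _MODELS[architecture]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}"
        )
        return getattr(module, cls_name)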