## Problem
inferencerlabs/NVIDIA-Nemotron-3-Super-120B-A12B-MLX-9bit gets 7 tok/s on an M3 Ultra (512GB). The Inferencer app claims 33.7 tok/s on the same hardware using a "modified MLX". The smaller Nano-30B variant gets 62.5 tok/s, confirming our MoE implementation is correct at smaller scale.
## Profiling Data (per token, steady state)
| Layer Type | Time | % | Count | Per-layer |
|---|---|---|---|---|
| MoE | 93ms | 75% | 40 | 2.3ms |
| Mamba | 28ms | 23% | 40 | 0.7ms |
| Attention | 3.5ms | 3% | 8 | 0.4ms |
| Total | 125ms | ~100% | 88 | |
97% of the time is GPU compute (asyncEval); only ~3.5ms per token is Swift-side overhead (graph construction in model()).
## Root Cause
Each MoE layer does 6 sequential Metal kernel dispatches:
- gate matmul (4096→512 router logits, then top-k selection)
- fc1_latent_proj (4096→1024 quantized matmul)
- gather_qmm fc1 (1024→2688 across 22/512 experts)
- relu2 (element-wise)
- gather_qmm fc2 (2688→1024 across 22/512 experts)
- fc2_latent_proj + score_sum (1024→4096)
Each dispatch writes its output back to unified memory, so the [22, 2688] intermediate between fc1 and fc2 is written out and read back unnecessarily.
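The serialization cost follows directly from the figures above. A quick check (assuming the intermediate is stored in fp16, which the issue does not state):

```python
# Back-of-envelope cost of the unfused MoE pipeline.
# Assumption (not stated above): the [22, 2688] intermediate is fp16 (2 bytes).

moe_layers = 40
dispatches_per_layer = 6
print(moe_layers * dispatches_per_layer)  # 240 serialized kernel dispatches per token for MoE alone

experts_active, inter_dim, bytes_fp16 = 22, 2688, 2
intermediate_bytes = experts_active * inter_dim * bytes_fp16
print(intermediate_bytes)  # 118272 bytes (~115 KB) written then re-read per MoE layer
```

Note that the intermediate traffic itself is small (roughly 4.7 MB/token across all 40 layers), so the dominant cost is likely the 240 serialized dispatch round-trips rather than the extra bandwidth.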
## Active Parameter Analysis
- Total params: 120.7B (136 GB at 8-bit)
- Active per token: 12.8B (14.4 GB) — only 10.6% due to 22/512 MoE sparsity
- M3 Ultra bandwidth: ~800 GB/s
- Theoretical: 800/14.4 = 55.6 tok/s (sparse) vs 800/136 = 5.9 tok/s (full read)
- Measured: 7 tok/s (reading ~114 GB — nearly full)
- Inferencer: 33.7 tok/s (reading ~24 GB — partially sparse)
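These roofline numbers can be reproduced directly, assuming decode is memory-bandwidth-bound (tok/s ≈ bandwidth / bytes read per token):

```python
# Roofline check of the numbers above.
bw = 800.0  # GB/s, approximate M3 Ultra memory bandwidth

active_gb, total_gb = 14.4, 136.0
print(round(bw / active_gb, 1))  # 55.6 tok/s if only active-expert weights are read
print(round(bw / total_gb, 1))   # 5.9 tok/s if all weights are read

# Inverting measured throughput gives the implied bytes read per token:
print(round(bw / 7.0, 1))   # ~114.3 GB/token (this implementation: nearly full read)
print(round(bw / 33.7, 1))  # ~23.7 GB/token (Inferencer: mostly sparse)
```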
## What Was Tried
- Staged eval (per-layer graph breaking): No effect — data deps already serialize
- Fused MoE Metal kernel (fc1→relu2→fc2→score in one dispatch): Kernel compiles and runs but has 8-bit affine dequantization bug — outputs are wrong after 1 token
## What's Needed
### Fused MoE kernel (primary — 75% of time)
- Single Metal kernel: fc1 → relu2 → fc2 → score_weight per expert
- Keep [2688] intermediate in registers/threadgroup shared memory
- Correct 8-bit affine dequant matching MLX's `quantize()` format
- Grid: (32 threads, K=22 experts, OUT_DIM=1024)
- Target: 93ms → ~20ms per token
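Since the first fused attempt failed on the dequant path, a host-side reference for the per-expert math is useful for validation. The sketch below is NumPy, not the Metal kernel, and it simplifies the weight format: codes are unpacked int8 with per-group (scale, bias), rather than MLX's packed layout.

```python
import numpy as np

GROUP = 32  # group_size for affine quantization, per the model config

def dequant(q, scales, biases):
    """Affine dequant: w[..., g*GROUP + i] = scales[..., g] * q[..., g*GROUP + i] + biases[..., g]."""
    groups = q.shape[-1] // GROUP
    w = q.reshape(*q.shape[:-1], groups, GROUP).astype(np.float32)
    return (w * scales[..., None] + biases[..., None]).reshape(q.shape)

def fused_moe_expert_ref(x, fc1_q, fc1_s, fc1_b, fc2_q, fc2_s, fc2_b, scores):
    """Reference for fc1 -> relu2 -> fc2 -> score-weighted sum over K active experts.
    x: [latent]; fc1_*: [K, inter, latent(/GROUP)]; fc2_*: [K, latent, inter(/GROUP)];
    scores: [K] gate weights."""
    out = np.zeros(fc2_q.shape[1], dtype=np.float32)
    for k in range(scores.shape[0]):
        h = dequant(fc1_q[k], fc1_s[k], fc1_b[k]) @ x  # fc1: latent -> inter
        h = np.square(np.maximum(h, 0.0))              # relu2
        y = dequant(fc2_q[k], fc2_s[k], fc2_b[k]) @ h  # fc2: inter -> latent
        out += scores[k] * y                           # score-weighted accumulate
    return out
```

In the fused Metal kernel, `h` would stay in registers/threadgroup memory instead of round-tripping through unified memory; a reference like this mainly helps pin down the dequant behavior that was wrong after one token in the first attempt.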
### Fused Mamba kernel (secondary — 23% of time)
- in_proj → conv1d → SSM → norm → out_proj in one dispatch
- Target: 28ms → ~7ms per token
Combined target: ~30ms/tok = ~33 tok/s
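The combined target follows from the per-kernel goals above (attention left at its measured ~3.5ms):

```python
# Sanity check on the combined per-token target.
moe_ms, mamba_ms, attn_ms = 20.0, 7.0, 3.5
total_ms = moe_ms + mamba_ms + attn_ms
print(total_ms)                  # 30.5 ms/token
print(round(1000 / total_ms))    # ~33 tok/s
```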
## Model Config
- 88 layers: 40 Mamba (M) + 40 MoE (E) + 8 Attention (*)
- 512 experts, top-22, latent_size=1024, intermediate=2688, hidden=4096
- 8-bit quantized (affine, group_size=32)
- Pattern: MEMEMEM*EMEMEMEM*EMEMEMEM*E...

## Files
- `Scripts/patches/NemotronH.swift` — model implementation
- `Scripts/patches/SwitchLayers.swift` — SwitchLinear/QuantizedSwitchLinear
- `vendor/mlx-swift-lm/Libraries/MLXLLM/Models/SSM.swift` — existing Metal kernel pattern to follow