Nemotron-3-Super-120B: 7 tok/s vs 33 tok/s claimed — fused MoE kernel needed #52

@scouzi1966

Description

Problem

inferencerlabs/NVIDIA-Nemotron-3-Super-120B-A12B-MLX-9bit gets 7 tok/s on M3 Ultra (512GB). Inferencer app claims 33.7 tok/s on same hardware using "modified MLX". The smaller Nano-30B variant gets 62.5 tok/s — confirming our MoE implementation is correct at smaller scale.

Profiling Data (per token, steady state)

Layer type   Time     %     Count   Per-layer
MoE          93 ms    75%   40      2.3 ms
Mamba        28 ms    23%   40      0.7 ms
Attention    3.5 ms   3%    8       0.4 ms
Total        125 ms         88

97% of per-token time is GPU compute (asyncEval); only ~3.5 ms is Swift-side overhead (model() graph construction).

Root Cause

Each MoE layer does 6 sequential Metal kernel dispatches:

  1. gate matmul (4096→512 expert logits, then top-22 selection)
  2. fc1_latent_proj (4096→1024 quantized matmul)
  3. gather_qmm fc1 (1024→2688 across 22/512 experts)
  4. relu2 (element-wise)
  5. gather_qmm fc2 (2688→1024 across 22/512 experts)
  6. fc2_latent_proj + score_sum (1024→4096)

Each dispatch round-trips its output through device memory; in particular, the [22, 2688] intermediate between fc1 and fc2 is written out and read back unnecessarily.
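
For reference, the six dispatches map onto the following dense-math sketch (NumPy standing in for Metal; shapes are taken from the list above, but the weights and the softmax score normalization are invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
HIDDEN, LATENT, INTER, K = 4096, 1024, 2688, 22  # top-22 of 512 experts

x = rng.standard_normal(HIDDEN).astype(np.float32)

# 1. gate matmul + top-k: pick 22 experts and their scores
gate_w = rng.standard_normal((512, HIDDEN), dtype=np.float32) * 0.02
logits = gate_w @ x
topk = np.argsort(logits)[-K:]
scores = np.exp(logits[topk]) / np.exp(logits[topk]).sum()  # illustrative

# 2. shared latent down-projection (4096 -> 1024)
fc1_latent = rng.standard_normal((LATENT, HIDDEN), dtype=np.float32) * 0.02
lat = fc1_latent @ x

# 3-5. per-expert fc1 -> relu2 -> fc2; the [22, 2688] array `h` below is
# exactly the intermediate a fused kernel would keep on-chip instead
fc1 = rng.standard_normal((K, INTER, LATENT), dtype=np.float32) * 0.02
fc2 = rng.standard_normal((K, LATENT, INTER), dtype=np.float32) * 0.02
h = np.einsum("kil,l->ki", fc1, lat)        # 3. gather fc1
h = np.maximum(h, 0.0) ** 2                 # 4. relu2 (squared ReLU)
out_lat = np.einsum("kli,ki->kl", fc2, h)   # 5. gather fc2

# 6. score-weighted sum + latent up-projection (1024 -> 4096)
fc2_latent = rng.standard_normal((HIDDEN, LATENT), dtype=np.float32) * 0.02
y = fc2_latent @ (scores @ out_lat)
print(y.shape)  # (4096,)
```

In the unfused version each numbered step is its own dispatch, so `h` and `out_lat` hit device memory between kernels; the fused kernel collapses steps 3-6 into one launch.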

Active Parameter Analysis

  • Total params: 120.7B (136 GB at 8-bit)
  • Active per token: 12.8B (14.4 GB) — only 10.6% due to 22/512 MoE sparsity
  • M3 Ultra bandwidth: ~800 GB/s
  • Theoretical: 800/14.4 = 55.6 tok/s (sparse) vs 800/136 = 5.9 tok/s (full read)
  • Measured: 7 tok/s (reading ~114 GB — nearly full)
  • Inferencer: 33.7 tok/s (reading ~24 GB — partially sparse)
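
The bandwidth arithmetic above can be checked with a short script (numbers taken from the bullets; ~800 GB/s is the nominal M3 Ultra figure):

```python
# Roofline estimate: decode is bandwidth-bound, so
# tok/s ~= bandwidth / bytes_read_per_token.
BANDWIDTH_GBS = 800.0   # M3 Ultra nominal memory bandwidth
FULL_GB = 136.0         # all 120.7B params on disk
ACTIVE_GB = 14.4        # 12.8B active params (22/512 experts)

sparse_toks = BANDWIDTH_GBS / ACTIVE_GB   # upper bound, perfect sparsity
full_toks = BANDWIDTH_GBS / FULL_GB       # lower bound, full weight read

# Invert the measured rates to estimate actual bytes moved per token
measured_read = BANDWIDTH_GBS / 7.0       # this build
inferencer_read = BANDWIDTH_GBS / 33.7    # Inferencer's claim

print(f"sparse bound: {sparse_toks:.1f} tok/s")
print(f"full-read:    {full_toks:.1f} tok/s")
print(f"our read:     {measured_read:.0f} GB/tok")
print(f"Inferencer:   {inferencer_read:.0f} GB/tok")
```

The inversion shows this build reads ~114 GB/token (nearly the full weights), while Inferencer's 33.7 tok/s implies ~24 GB/token.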

What Was Tried

  1. Staged eval (per-layer graph breaking): No effect — data deps already serialize
  2. Fused MoE Metal kernel (fc1→relu2→fc2→score in one dispatch): Kernel compiles and runs, but an 8-bit affine dequantization bug corrupts outputs after the first token

What's Needed

Fused MoE kernel (primary — 75% of time)

  • Single Metal kernel: fc1 → relu2 → fc2 → score_weight per expert
  • Keep [2688] intermediate in registers/threadgroup shared memory
  • Correct 8-bit affine dequant matching MLX's quantize() format
  • Grid: (32 threads, K=22 experts, OUT_DIM=1024)
  • Target: 93ms → ~20ms per token
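
The dequant requirement is the part the current kernel gets wrong. As a NumPy reference (not the Metal code), MLX's affine format stores per-group scales and biases with w_hat = scales·q + biases over groups of 32 along the input dimension; the grouping and naming below follow that convention but should be checked against mlx.core.quantize:

```python
import numpy as np

GROUP, BITS = 32, 8
LEVELS = (1 << BITS) - 1  # 255

def quantize(w):
    """Affine-quantize rows in groups of GROUP along the last axis."""
    g = w.reshape(*w.shape[:-1], -1, GROUP)
    lo, hi = g.min(-1, keepdims=True), g.max(-1, keepdims=True)
    scales = (hi - lo) / LEVELS
    q = np.round((g - lo) / np.where(scales == 0, 1, scales))
    return q.astype(np.uint8).reshape(w.shape), scales.squeeze(-1), lo.squeeze(-1)

def dequantize(q, scales, biases):
    """w_hat = scales * q + biases, broadcast per group of 32.
    This is the step the fused kernel must reproduce bit-for-bit."""
    g = scales[..., None] * q.reshape(*scales.shape, GROUP) + biases[..., None]
    return g.reshape(q.shape)

w = np.random.default_rng(1).standard_normal((1024, 4096)).astype(np.float32)
q, s, b = quantize(w)
err = np.abs(dequantize(q, s, b) - w).max()
print(err)  # bounded by half a quantization step
```

A useful debugging path is to diff the kernel's dequantized weights against exactly this reference for one expert before wiring up the full fused pipeline.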

Fused Mamba kernel (secondary — 23% of time)

  • in_proj → conv1d → SSM → norm → out_proj in one dispatch
  • Target: 28ms → ~7ms per token

Combined target: ~30ms/tok = ~33 tok/s
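
Sanity check on the combined target, with attention left untouched at ~3.5 ms:

```python
MOE_FUSED, MAMBA_FUSED, ATTN = 20.0, 7.0, 3.5  # ms/token targets from above
total = MOE_FUSED + MAMBA_FUSED + ATTN
print(f"{total:.1f} ms/tok -> {1000 / total:.1f} tok/s")
```

That lands at roughly 30.5 ms/token, i.e. ~33 tok/s, in line with Inferencer's claimed number.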

Model Config

  • 88 layers: 40 Mamba (M) + 40 MoE (E) + 8 Attention (*)
  • 512 experts, top-22 routing
  • latent_size=1024, intermediate=2688, hidden=4096
  • 8-bit quantized (affine, group_size=32)
  • Layer pattern: MEMEMEM*EMEMEMEM*EMEMEMEM*E...

Files

  • Scripts/patches/NemotronH.swift — model implementation
  • Scripts/patches/SwitchLayers.swift — SwitchLinear/QuantizedSwitchLinear
  • vendor/mlx-swift-lm/Libraries/MLXLLM/Models/SSM.swift — existing Metal kernel pattern to follow
