
Conversation

@hypdeb (Contributor) commented Dec 2, 2025

Purpose

Improve performance and support for Mistral Large 3 on Blackwell.

Details

  • Added per-tensor scaled Triton configs for MoE (for the Eagle draft model); see the example config layout after this list
  • (WIP) Added per-block scaled Triton configs for MoE (for the target model)
  • Added support for Flashinfer TRTLLM per-tensor scaled FP8 MoE kernels (for the Eagle draft model)
  • (WIP) Added support for Flashinfer TRTLLM per-block scaled FP8 MoE kernels (for the target model)
  • Fixed Llama4 routing for FP4 MoE
  • Added support for the Mistral config format in benchmarks/kernels/benchmark_moe.py
  • Added support for the Mistral tokenizer in vllm/benchmarks/throughput.py
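
For reference, the tuned Triton MoE configs in the first two bullets follow vLLM's usual fused-MoE JSON layout: the file is keyed by token count (M), and each entry carries a Triton tile/launch configuration. A minimal sketch with illustrative values (not the tuned numbers from this PR):

{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 4
    }
}

Files like this are typically produced by the tuning mode of benchmarks/kernels/benchmark_moe.py, which is why Mistral config-format support was added there.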

Best Performance Usage

FP8 Checkpoint on DGX B200 (8 devices)

The FP8 model will fit on a single node.
At low concurrencies, deploy with TP8:

VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
VLLM_USE_FLASHINFER_MOE_FP8=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512-NVFP4 \
  -tp 8 --kv-cache-dtype fp8 --no-enable-prefix-caching \
  --config-format mistral --load-format mistral --tokenizer-mode mistral \
  --max_model_len 65536 --max_num_seqs 512 --limit-mm-per-prompt '{"image":10}' \
  --tool-call-parser mistral --enable-auto-tool-choice

At higher concurrencies (128 concurrent requests and above), deploy with DP8 and expert parallelism:

VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
VLLM_USE_FLASHINFER_MOE_FP8=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512-NVFP4 \
  --data-parallel-size 8 --enable-expert-parallel --kv-cache-dtype fp8 --no-enable-prefix-caching \
  --config-format mistral --load-format mistral --tokenizer-mode mistral \
  --max_model_len 65536 --max_num_seqs 512 --limit-mm-per-prompt '{"image":10}' \
  --tool-call-parser mistral --enable-auto-tool-choice
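
Both commands expose the standard OpenAI-compatible server, so a deployment can be sanity-checked with a chat completion request. Host, port, and model name below are only the defaults and the served path; adjust them to your setup:

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "/models/Mistral-Large-3-675B-Instruct-2512-NVFP4",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64
      }'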

NVFP4

For NVFP4 checkpoints, add the following environment variables to leverage the optimized Flashinfer kernels:

VLLM_NVFP4_GEMM_BACKEND=cutlass VLLM_USE_FLASHINFER_MOE_FP4=1 

With a Flashinfer version newer than 0.5.3:

VLLM_NVFP4_GEMM_BACKEND=flashinfer-cudnn VLLM_USE_FLASHINFER_MOE_FP4=1 

A recently fixed bug in the auto-tuner (flashinfer-ai/flashinfer#2140) is what makes the flashinfer-cudnn backend usable.
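
Putting the pieces together, these variables are prepended to the serve command in the same way as in the FP8 examples above. An illustrative sketch only; reuse whichever parallelism and serving flags from the commands above fit your setup:

VLLM_NVFP4_GEMM_BACKEND=cutlass \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512-NVFP4 \
  -tp 8 --kv-cache-dtype fp8 \
  --config-format mistral --load-format mistral --tokenizer-mode mistral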

GB200 P/D Disaggregated Dynamo Deployment

There are two options to set up a Dynamo P/D disaggregated deployment of this model. The first one is available immediately and relies on the processing pipeline of Dynamo. The second is pending a PR on Dynamo to enable delegating pre-processing to the vLLM backend.

For compatibility with top-of-tree (ToT) vLLM, you might need to include some changes that are not yet in upstream Dynamo.

With Dynamo request processing

Start by copying config.json from Ministral to your model directory.
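
For example (the paths are hypothetical; point the source at a local Ministral checkpoint and the destination at the Mistral Large 3 model directory):

cp /models/Ministral-8B-Instruct-2410/config.json \
   /models/Mistral-Large-3-675B-Instruct-2512-NVFP4/config.json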

With delegated request processing

Once some pending changes (TODO LINK) land in Dynamo, you will be able to skip the file-copying step above.

Next steps

We have identified further optimizations that will be part of follow-up PRs:

  • FP8 context attention
  • Routing GEMM optimization

Contributors

@dbari, @DanBlanaru, @evezhier, @hypdeb

hypdeb and others added 3 commits December 2, 2025 04:13
Signed-off-by: jdebache <jdebache@nvidia.com>
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@mergify bot added the ci/build, performance (Performance-related issues), nvidia, and v1 labels on Dec 2, 2025
Julien Debache and others added 3 commits December 2, 2025 05:46
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Dan Blanaru <48605845+DanBlanaru@users.noreply.github.com>
@mgoin (Member) left a comment

It seems there are a few things mixed into this one. Do you think we could prioritize the critical perf features like kernel support and tuned configs?

torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.3
nvtx==0.2.13
Member:

I think this isn't a big deal but do we need this dep?

Contributor Author:

We've added some NVTX ranges to gpu_model_runner.py to make profiling easier; that is what pulls in the nvtx dependency. This can be split into another PR if desirable.
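
For context, the ranges use the nvtx Python package's annotate API. A minimal sketch of the pattern (not the actual diff; the function and range names here are made up):

import nvtx

def run_model():
    # Placeholder for the actual forward pass in gpu_model_runner.py.
    pass

def execute_model_step():
    # The range shows up as a named span in Nsight Systems timelines.
    with nvtx.annotate("gpu_model_runner.execute_model", color="green"):
        run_model()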

Comment on lines 2987 to 3008

# Count context tokens per request
context_requests = 0
decode_requests = 0

for req in scheduler_output.scheduled_new_reqs:
    context_len = len(req.prompt_token_ids) if req.prompt_token_ids else 0
    num_computed = req.num_computed_tokens
    if num_computed < context_len:
        context_requests += 1
    else:
        decode_requests += 1

# For cached requests
for i, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids):
    context_len = self.requests[req_id].num_prompt_tokens
    num_computed = scheduler_output.scheduled_cached_reqs.num_computed_tokens[i]

    if num_computed < context_len:
        context_requests += 1
    else:
        decode_requests += 1

Member:

Can we put this in another PR? We don't want to eat this cost when profiling isn't enabled.

Contributor Author:

Will do.

@hypdeb (Contributor Author) commented Dec 2, 2025

> It seems there are a few things mixed into this one. Do you think we could prioritize the critical perf features like kernel support and tuned configs?

I'll start by splitting the NVTX stuff out and see how it looks after that.

Julien Debache added 2 commits December 2, 2025 09:33
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
