Skip to content

[HiCache] feat: add draft KV cache backing for L2/L3#21125

Open
alphabetc1 wants to merge 4 commits intosgl-project:mainfrom
alphabetc1:feat/hicache_spec_2
Open

[HiCache] feat: add draft KV cache backing for L2/L3#21125
alphabetc1 wants to merge 4 commits intosgl-project:mainfrom
alphabetc1:feat/hicache_spec_2

Conversation

@alphabetc1
Copy link
Collaborator

@alphabetc1 alphabetc1 commented Mar 22, 2026

Motivation

see #16964

How wo reproduce:

1. launch with spec+hicache SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR=/root/code/tmp/sglang_hicache_file_test \ python3 -m sglang.launch_server \ --model /models/ZhipuAI/GLM-4.7-FP8/ \ --enable-hierarchical-cache \ --hicache-io-backend direct \ --hicache-mem-layout page_first_direct \ --hicache-write-policy write_through --host 0.0.0.0 \ --mem-fraction-static 0.55 --page-size 64 \ --port 7000 \ --reasoning-parser glm45 \ --served-model-name GLM-4.7 \ --speculative-algorithm EAGLE \ --speculative-eagle-topk 1 \ --speculative-num-draft-tokens 8 \ --speculative-num-steps 7 \ --tool-call-parser glm47 \ --tp 8 \ --trust-remote-code \ --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 64}' \ --hicache-size 60 \ --hicache-storage-prefetch-policy wait_complete \ --hicache-storage-backend file 2. benchmark for i in 1 2; do python3 -m sglang.bench_serving \ --backend sglang-oai \ --dataset-name random \ --num-prompts 4 \ --model /models/ZhipuAI/GLM-4.7-FP8/ \ --dataset-path /root/code/datasets/ShareGPT_V3_unfiltered_cleaned_split.json \ --random-input-len 150000 \ --random-output-len 1024 \ --random-range-ratio 0.3 \ --max-concurrency 4 \ --warmup-requests 0 \ --seed 77 \ --host 0.0.0.0 --port 7000 done 

benchmark result before this feature(draft L2/L3kvcache):

============ Serving Benchmark Result ============ Backend: sglang-oai Traffic request rate: inf Max request concurrency: 4 Successful requests: 4 Benchmark duration (s): 21.80 Total input tokens: 422526 Total input text tokens: 422526 Total generated tokens: 2861 Total generated tokens (retokenized): 2861 Request throughput (req/s): 0.18 Input token throughput (tok/s): 19378.87 Output token throughput (tok/s): 131.22 Peak output token throughput (tok/s): 40.00 Peak concurrent requests: 4 Total token throughput (tok/s): 19510.08 Concurrency: 2.90 Accept length: 4.77 ----------------End-to-End Latency---------------- Mean E2E Latency (ms): 15826.53 Median E2E Latency (ms): 15680.88 P90 E2E Latency (ms): 21601.95 P99 E2E Latency (ms): 21773.48 ---------------Time to First Token---------------- Mean TTFT (ms): 7729.05 Median TTFT (ms): 8167.43 P99 TTFT (ms): 13053.19 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 11.89 Median TPOT (ms): 11.58 P99 TPOT (ms): 15.91 ---------------Inter-Token Latency---------------- Mean ITL (ms): 11.34 Median ITL (ms): 10.24 P95 ITL (ms): 17.06 P99 ITL (ms): 25.61 Max ITL (ms): 603.61 

benchmark result after this feature(draft L2/L3kvcache):
accelt length 4.77->6.90

============ Serving Benchmark Result ============ Backend: sglang-oai Traffic request rate: inf Max request concurrency: 4 Successful requests: 4 Benchmark duration (s): 19.14 Total input tokens: 422526 Total input text tokens: 422526 Total generated tokens: 2861 Total generated tokens (retokenized): 2861 Request throughput (req/s): 0.21 Input token throughput (tok/s): 22080.00 Output token throughput (tok/s): 149.51 Peak output token throughput (tok/s): 40.00 Peak concurrent requests: 4 Total token throughput (tok/s): 22229.51 Concurrency: 2.91 Accept length: 6.90 ----------------End-to-End Latency---------------- Mean E2E Latency (ms): 13933.53 Median E2E Latency (ms): 15631.02 P90 E2E Latency (ms): 18434.37 P99 E2E Latency (ms): 19059.05 ---------------Time to First Token---------------- Mean TTFT (ms): 8588.41 Median TTFT (ms): 8269.18 P99 TTFT (ms): 15607.73 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 7.20 Median TPOT (ms): 7.18 P99 TPOT (ms): 8.79 ---------------Inter-Token Latency---------------- Mean ITL (ms): 7.48 Median ITL (ms): 6.43 P95 ITL (ms): 13.25 P99 ITL (ms): 18.70 Max ITL (ms): 232.90 ================================================== 

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a draft KV cache backing for the HiCache system, specifically targeting L2/L3 caches. This enhancement addresses a regression issue in speculative decoding where the draft model could see stale or incorrect slot indices after a load_back operation. The changes ensure that the draft KV pool is properly synchronized with the main KV cache, improving the accuracy and stability of speculative decoding.

Highlights

  • HiCache Draft KV Pool: Introduces a draft KV cache backing for L2/L3 HiCache to improve speculative decoding accuracy by ensuring draft KV is restored to correct indices after load_back operations.
  • Integration with HiRadixCache and HiMambaRadixCache: The draft KV pool functionality is integrated into both HiRadixCache and HiMambaRadixCache via a mixin class, providing a unified approach for managing draft KV in hierarchical caches.
  • Hooks for Write, Load, and Evict Operations: The changes include hooks into the HiCache write, load, and evict paths to manage the draft KV pool, ensuring consistency between the main and draft KV caches.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for backing up the draft model's KV cache to L2 (host) and L3 (storage) when using HiCache with speculative decoding. This is achieved by a new HiCacheDraftMixin that hooks into HiCache's data movement operations (write, load, evict) to synchronize the draft KV cache with the main model's KV cache. The changes are well-integrated into HiRadixCache and HiMambaRadixCache. My main feedback is to refactor a piece of duplicated logic in the scheduler for better maintainability.

Comment on lines +806 to +814
# Mirror the logic in init_disaggregation() to locate the draft KV pool.
if self.spec_algorithm.supports_spec_v2() and self.enable_overlap:
if self.server_args.enable_multi_layer_eagle:
draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
else:
draft_runner = self.draft_worker.draft_worker.draft_runner
draft_kv_pool = draft_runner.token_to_kv_pool
else:
draft_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to locate the draft KV pool is duplicated from init_disaggregation (lines 946-959). To improve maintainability and avoid future inconsistencies, consider refactoring this logic into a helper method, e.g., _get_draft_kv_pool(), and calling it from both _maybe_register_hicache_draft and init_disaggregation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

1 participant