BatchPrefillWithPagedKVCacheWrapper.plan() got an unexpected keyword argument 'head_dim' #3165

@ruckc

Description

Given this code in text-generation-inference:

state.plan(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
kv_data_type=kv_dtype,
q_data_type=q_dtype,
page_size=page_size,
)

and this plan() signature from flashinfer:

https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/prefill.py#L1164-L1188

def plan(
    self,
    qo_indptr: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim_qk: int,
    page_size: int,
    head_dim_vo: Optional[int] = None,
    custom_mask: Optional[torch.Tensor] = None,
    packed_custom_mask: Optional[torch.Tensor] = None,
    causal: bool = False,
    pos_encoding_mode: str = "NONE",
    use_fp16_qk_reduction: bool = False,
    sm_scale: Optional[float] = None,
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    q_data_type: Union[str, torch.dtype] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    non_blocking: bool = True,
) -> None:

I'm getting:

2025-04-11T13:41:22.164186Z ERROR warmup{max_input_length=None max_prefill_tokens=4096 max_total_tokens=None max_batch_size=None}:warmup: text_generation_router_v3::client: backends/v3/src/client/mod.rs:45: Server error: BatchPrefillWithPagedKVCacheWrapper.plan() got an unexpected keyword argument 'head_dim'
2025-04-11T13:41:22.164259Z ERROR text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
  File "/app/venv/bin/text-generation-server", line 10, in <module>
    sys.exit(app())
  File "/app/venv/lib/python3.12/site-packages/typer/main.py", line 323, in __call__
    return get_command(self)(*args, **kwargs)
  File "/app/venv/lib/python3.12/site-packages/click/core.py", line 1161, in __call__
    return self.main(*args, **kwargs)
  File "/app/venv/lib/python3.12/site-packages/typer/core.py", line 743, in main
    return _main(
  File "/app/venv/lib/python3.12/site-packages/typer/core.py", line 198, in _main
    rv = self.invoke(ctx)
  File "/app/venv/lib/python3.12/site-packages/click/core.py", line 1697, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/app/venv/lib/python3.12/site-packages/click/core.py", line 1443, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/app/venv/lib/python3.12/site-packages/click/core.py", line 788, in invoke
    return __callback(*args, **kwargs)
  File "/app/venv/lib/python3.12/site-packages/typer/main.py", line 698, in wrapper
    return callback(**use_params)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/cli.py", line 119, in serve
    server.serve(
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/server.py", line 315, in serve
    asyncio.run(
  File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
  File "/usr/lib/python3.12/asyncio/base_events.py", line 674, in run_until_complete
    self.run_forever()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 641, in run_forever
    self._run_once()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 1987, in _run_once
    handle._run()
  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/app/venv/lib/python3.12/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
> File "/app/venv/lib/python3.12/site-packages/text_generation_server/interceptor.py", line 24, in intercept
    return await response
  File "/app/venv/lib/python3.12/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 120, in _unary_interceptor
    raise error
  File "/app/venv/lib/python3.12/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 111, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/server.py", line 144, in Warmup
    self.model.warmup(batch, max_input_tokens, max_total_tokens)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/models/flash_causal_lm.py", line 1548, in warmup
    _, _batch, _ = self.generate_token(batch)
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/models/flash_causal_lm.py", line 1928, in generate_token
    out, speculative_logits = self.forward(batch, adapter_data)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/models/flash_causal_lm.py", line 1810, in forward
    with self._forward_context(
  File "/usr/lib/python3.12/contextlib.py", line 137, in __enter__
    return next(self.gen)
  File "/app/venv/lib/python3.12/site-packages/text_generation_server/layers/attention/flashinfer.py", line 86, in use_prefill_with_paged_kv_state
    state.plan(
TypeError: BatchPrefillWithPagedKVCacheWrapper.plan() got an unexpected keyword argument 'head_dim'
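For context, the error line matches the signature above: flashinfer no longer accepts a single head_dim keyword; it was split into head_dim_qk plus an optional head_dim_vo, which per the default of None appears to fall back to the QK dimension. A minimal sketch of the adjusted call, assuming head_size should map to both dimensions (as it does for standard attention, where they are equal):

state.plan(
    qo_indptr=cu_seqlens,
    paged_kv_indptr=indptr,
    paged_kv_indices=block_tables,
    paged_kv_last_page_len=last_page_len,
    num_qo_heads=num_heads,
    num_kv_heads=num_kv_heads,
    head_dim_qk=head_size,  # renamed from head_dim in newer flashinfer
    head_dim_vo=head_size,  # optional; assumed to default to head_dim_qk when omitted
    kv_data_type=kv_dtype,
    q_data_type=q_dtype,
    page_size=page_size,
)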

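If text-generation-inference needs to keep working against older flashinfer releases that still expect head_dim, one hedged workaround is to pick the keyword at runtime from the installed wrapper's signature. The head_dim_kwarg helper below is hypothetical and not part of either project:

import inspect

def head_dim_kwarg(plan_fn, head_size: int) -> dict:
    # Inspect the installed plan() signature and return the head-dimension
    # argument under whichever name this flashinfer version accepts.
    params = inspect.signature(plan_fn).parameters
    if "head_dim_qk" in params:
        return {"head_dim_qk": head_size}  # newer API; head_dim_vo defaults to this
    return {"head_dim": head_size}         # older API with a single head_dim

state.plan(
    qo_indptr=cu_seqlens,
    paged_kv_indptr=indptr,
    paged_kv_indices=block_tables,
    paged_kv_last_page_len=last_page_len,
    num_qo_heads=num_heads,
    num_kv_heads=num_kv_heads,
    kv_data_type=kv_dtype,
    q_data_type=q_dtype,
    page_size=page_size,
    **head_dim_kwarg(state.plan, head_size),
)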