 initialized randomly with a fixed seed.
 """

+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any

@@ -26,6 +27,7 @@
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer

 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(


 @torch.inference_mode
-def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
-) -> torch.Tensor:
-    if use_compile:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
-            use_cudagraph=True,
-            backend=backend,
-            cudagraph_capture_sizes=[1, 2],
-        )
-        if split_attn:
-            compilation_config.splitting_ops = ["silly::attention"]
-        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-    else:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.NO_COMPILATION,
-        )
-        cudagraph_runtime_mode = CUDAGraphMode.NONE
+def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
+    # Start with a fresh copy to make sure there's no cache dir sharing
+    compile_config = deepcopy(compile_config)
+    cudagraph_runtime_mode = compile_config.cudagraph_mode

     vllm_config = VllmConfig(
-        compilation_config=compilation_config, additional_config=llama_config
+        compilation_config=compile_config, additional_config=llama_config
     )
     with set_current_vllm_config(vllm_config):
         model = (
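For illustration, a minimal sketch (not part of the diff, reusing this test module's imports and a config shaped like the ones built in `test_toy_llama` below) of the refactored calling convention: `run_model` now receives a fully-formed `CompilationConfig` and derives the cudagraph runtime mode from it, instead of assembling a config from `use_compile`/`backend`/`split_attn` flags.

```python
# Sketch only: the new run_model call pattern.
# CompilationConfig / CompilationLevel / CUDAGraphMode come from vllm.config,
# matching this test module's existing imports.
piecewise_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
    backend="inductor",
    cudagraph_capture_sizes=[1, 2],
)
# run_model deepcopies the config, so the same object can be reused across
# calls without sharing a compile cache dir.
output = run_model(llama_config, piecewise_config)
```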
@@ -338,8 +326,25 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize(
+    "backend, use_inductor_graph_partition",
+    [
+        ("eager", False),  # No inductor
+        ("inductor", False),  # Inductor, Dynamo partition
+        ("inductor", True),  # Inductor, Inductor partition
+    ],
+)
+def test_toy_llama(
+    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
+):
+    # We disable the vLLM compile cache and use a new tmp dir for 2 reasons:
+    # 1. To make sure we can properly track the number of Inductor compilations.
+    # 2. Inductor partitioning does not play nicely with the Autograd cache (see below).
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")
+
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -350,6 +355,32 @@ def test_toy_llama(backend: str):
         hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
     )

+    compile_config_no_compile = CompilationConfig(
+        level=CompilationLevel.NO_COMPILATION,
+        cudagraph_mode=CUDAGraphMode.NONE,
+        backend="eager",
+    )
+
+    compile_config_no_split = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=use_inductor_graph_partition,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        backend=backend,
+        cudagraph_capture_sizes=[1, 2],
+    )
+
+    # FIXME(luka/boyuan): the graph from the previous test case
+    # (no inductor partition) gets cached by AotAutograd, so the
+    # compilation with inductor partitioning incorrectly loads an unpartitioned
+    # graph and never partitions. I think this is a bug in custom inductor
+    # partitioning, but it does not affect vLLM more generally, as vLLM uses
+    # its own cache (which takes inductor partitioning into account).
+    if use_inductor_graph_partition:
+        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
+
+    compile_config_split = deepcopy(compile_config_no_split)
+    compile_config_split.splitting_ops = ["silly::attention"]
+
     outputs = []
     with compilation_counter.expect(
         num_graphs_seen=0,
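For reference, the `force_disable_caches` workaround above is routed through `inductor_compile_config`, which vLLM forwards to Inductor as config overrides. A roughly equivalent, hedged sketch (global rather than per-config, and assuming a PyTorch version that exposes the flag) expressed directly against PyTorch:

```python
# Hedged sketch (not part of the diff): set the same flag globally via
# torch._inductor.config instead of through CompilationConfig.
import torch._inductor.config as inductor_config

inductor_config.force_disable_caches = True  # bypass FX-graph/AOTAutograd caches
```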
@@ -358,44 +389,44 @@ def test_toy_llama(backend: str):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-        run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, compile_config_no_compile))
+
+        run_model(tractable_config, compile_config_no_compile)

     if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

     with compilation_counter.expect(
-        # One graph for the model
-        num_graphs_seen=1,
+        num_graphs_seen=1,  # one graph for the model
         num_piecewise_graphs_seen=1,
         num_piecewise_capturable_graphs_seen=1,
-        # num_piecewise_capturable_graphs_seen
-        num_backend_compilations=1,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-        run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(run_model(llama_config, compile_config_no_split))
+
+        run_model(tractable_config, compile_config_no_split)
+
+    if use_inductor_graph_partition:
+        num_piecewise_fx = 1
+        num_piecewise_capturable_fx = 1
+    else:
+        num_piecewise_fx = 2 * llama_config.num_layers + 1
+        num_piecewise_capturable_fx = 1 + llama_config.num_layers

     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
-        num_piecewise_graphs_seen=2 * llama_config.num_layers + 1,  # 2 * num_layers + 1
-        num_piecewise_capturable_graphs_seen=1
-        + llama_config.num_layers,  # 1 + num_layers
-        num_backend_compilations=1
-        + llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=2
-        * (
-            1 + llama_config.num_layers
-        ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        # num_cudagraph_sizes * num_partitions
+        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
     ):
-        outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
-        )
-        run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        outputs.append(run_model(llama_config, compile_config_split))
+        run_model(tractable_config, compile_config_split)

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
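To make the counter expectations above concrete, here is a worked check (a sketch, not part of the diff) of the arithmetic for the configs used in this test, where `num_layers=2` and `cudagraph_capture_sizes=[1, 2]`:

```python
# Worked example of the expected counters for num_layers=2 (sketch only).
num_layers = 2
num_cudagraph_sizes = 2  # cudagraph_capture_sizes=[1, 2]

# Dynamo-level splitting on "silly::attention": each layer contributes one
# attention piece plus the surrounding pieces -> 2 * 2 + 1 = 5 FX subgraphs,
# of which 1 + 2 = 3 contain no attention and are cudagraph-capturable.
assert 2 * num_layers + 1 == 5
assert 1 + num_layers == 3

# Inductor-level partitioning keeps a single piecewise FX graph (the split
# happens inside Inductor), but the runtime partition count is the same, so
# num_cudagraph_captured is unchanged: 2 sizes * 3 partitions = 6.
assert num_cudagraph_sizes * (1 + num_layers) == 6
```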