Commit 2dcd12d

[torch.compile] Fix tests for torch==2.9 inductor partition (#26116)
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
1 parent 579d2e5 commit 2dcd12d

File tree: 8 files changed, +138 / -72 lines changed
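Every test file below applies the same gate: parametrize over use_inductor_graph_partition and skip the partitioned variant when torch is older than 2.9. A minimal sketch of that pattern follows; the test name and body here are hypothetical, while is_torch_equal_or_newer is the vLLM helper imported in the diffs below.

import pytest

from vllm.utils import is_torch_equal_or_newer


@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
def test_partition_gate_sketch(use_inductor_graph_partition: bool):
    # Hypothetical test body: only the skip gate mirrors the commit.
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")
    # ... a real test would build and exercise a compiled model here ...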


tests/compile/piecewise/test_full_cudagraph.py

Lines changed: 21 additions & 8 deletions
@@ -11,6 +11,7 @@
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer


 @contextlib.contextmanager
@@ -32,28 +33,32 @@ def temporary_environ(env_vars):
             os.environ[k] = v


-test_params_full_cudagraph = []
+model_backends_full_cudagraph = []

 # deepseek-ai/DeepSeek-V2-Lite with MLA
 MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
 for mla_backend in MLA_backends:
-    test_params_full_cudagraph.append(
-        pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
+    model_backends_full_cudagraph.append(
+        ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
     )

 # Qwen/Qwen2-1.5B-Instruct with other backends
 other_backend_configs = [
     backend_configs[c] for c in backend_configs if c not in MLA_backends
 ]
 for backend_config in other_backend_configs:
-    test_params_full_cudagraph.append(
-        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
-    )
+    model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))


 @pytest.fixture(scope="class")
 def llm_pair(request):
-    model, backend_config = request.param
+    model, backend_config, use_inductor_graph_partition = request.param
+    backend_config.comp_config["use_inductor_graph_partition"] = (
+        use_inductor_graph_partition
+    )
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")

     # Dynamically skip test if GPU capability is not met
     if (
@@ -104,7 +109,15 @@ def llm_pair(request):
     )


-@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
+@pytest.mark.parametrize(
+    "llm_pair",
+    [
+        pytest.param((model, backend_config, use_inductor_graph_partition))
+        for model, backend_config in model_backends_full_cudagraph
+        for use_inductor_graph_partition in [True, False]
+    ],
+    indirect=True,
+)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
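The parametrization above feeds three-element tuples into the class-scoped llm_pair fixture through request.param. A self-contained sketch of that indirect-parametrization pattern (fixture and test names here are hypothetical, not part of the commit):

import pytest


@pytest.fixture(scope="class")
def resource_pair(request):
    # Tuples from parametrize(..., indirect=True) arrive as request.param;
    # class scope builds the resource once per parameter set and shares it
    # across every test method in the class.
    model, use_partition = request.param
    yield f"{model}|partition={use_partition}"  # stand-in for the real LLM pair


@pytest.mark.parametrize(
    "resource_pair",
    [(model, flag) for model in ("model-a", "model-b") for flag in (True, False)],
    indirect=True,
)
class TestResourcePairSketch:
    def test_flag_is_recorded(self, resource_pair):
        assert "partition=" in resource_pair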

tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 27 additions & 11 deletions
@@ -5,6 +5,7 @@
 are compiled and graph captured separately.
 """

+import pytest
 import torch
 from torch import nn

@@ -190,7 +191,12 @@ def run_model(
     return output.cpu()


-def test_multi_graph_piecewise_compile_outputs_equal():
+@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
+def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
+    if use_inductor_graph_partition:
+        # FIXME(luka/boyuan): this currently fails
+        pytest.skip("Inductor graph partition not supported with multi-graph")
+
     outputs = []

     # piecewise compile
@@ -200,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,16 +227,24 @@ def test_multi_graph_piecewise_compile_outputs_equal():
     # static tensor addresses
     inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

-    with compilation_counter.expect(
-        num_graphs_seen=2,  # two graphs for the model
-        num_piecewise_graphs_seen=6,
+    if use_inductor_graph_partition:
+        # Splitting happens at Inductor lowering level,
+        # total piecewise fx graphs is equal to total graphs
+        num_piecewise_fx = 2
+        num_piecewise_capturable_fx = 2
+    else:
         # attn_one, attn_two each has 3 piecewise graphs
         # (pre attn, post attn, silly_attention) each
-        num_piecewise_capturable_graphs_seen=4,
+        num_piecewise_fx = 6
         # attn_one, attn_two has pre attn and post attn each, total=4
-        num_backend_compilations=4,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_capturable_fx = 4
+
+    with compilation_counter.expect(
+        num_graphs_seen=2,  # two graphs for the model
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        num_cudagraph_captured=8,  # num_cudagraph_sizes * num_partitions
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))

@@ -268,6 +283,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
             splitting_ops=["silly::attention"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal():

     with compilation_counter.expect(
         num_graphs_seen=2,
-        num_piecewise_graphs_seen=6,
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
         num_cudagraph_captured=0,  # no cudagraph captured
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
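The branch above encodes a counting argument: with Dynamo-level splitting, each of the two graphs is cut around its attention op into pre-attn, attention, and post-attn pieces, while with Inductor-level partitioning the FX graph stays whole. A hypothetical helper (not part of the commit) that reproduces those expectations:

def expected_piecewise_counts(num_graphs: int, use_inductor_graph_partition: bool):
    """Return (piecewise graphs seen, capturable piecewise graphs)."""
    if use_inductor_graph_partition:
        # Splitting happens at Inductor lowering, so each model keeps one FX graph.
        return num_graphs, num_graphs
    pieces_per_graph = 3  # pre-attn, silly_attention, post-attn
    capturable_per_graph = 2  # the attention piece itself is not capturable
    return num_graphs * pieces_per_graph, num_graphs * capturable_per_graph


# Matches the num_piecewise_* expectations in the hunk above.
assert expected_piecewise_counts(2, use_inductor_graph_partition=False) == (6, 4)
assert expected_piecewise_counts(2, use_inductor_graph_partition=True) == (2, 2)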

tests/compile/piecewise/test_toy_llama.py

Lines changed: 74 additions & 43 deletions
@@ -9,6 +9,7 @@
 initialized randomly with a fixed seed.
 """

+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any

@@ -26,6 +27,7 @@
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer

 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(


 @torch.inference_mode
-def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
-) -> torch.Tensor:
-    if use_compile:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
-            use_cudagraph=True,
-            backend=backend,
-            cudagraph_capture_sizes=[1, 2],
-        )
-        if split_attn:
-            compilation_config.splitting_ops = ["silly::attention"]
-        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-    else:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.NO_COMPILATION,
-        )
-        cudagraph_runtime_mode = CUDAGraphMode.NONE
+def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
+    # Start with a fresh copy to make sure there's no cache dir sharing
+    compile_config = deepcopy(compile_config)
+    cudagraph_runtime_mode = compile_config.cudagraph_mode

     vllm_config = VllmConfig(
-        compilation_config=compilation_config, additional_config=llama_config
+        compilation_config=compile_config, additional_config=llama_config
     )
     with set_current_vllm_config(vllm_config):
         model = (
@@ -338,8 +326,25 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize(
+    "backend, use_inductor_graph_partition",
+    [
+        ("eager", False),  # No inductor
+        ("inductor", False),  # Inductor, Dynamo partition
+        ("inductor", True),  # Inductor, Inductor partition
+    ],
+)
+def test_toy_llama(
+    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
+):
+    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # 1. To make sure we can properly track the number of Inductor compilations.
+    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")
+
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -350,6 +355,32 @@ def test_toy_llama(backend: str):
         hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
     )

+    compile_config_no_compile = CompilationConfig(
+        level=CompilationLevel.NO_COMPILATION,
+        cudagraph_mode=CUDAGraphMode.NONE,
+        backend="eager",
+    )
+
+    compile_config_no_split = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=use_inductor_graph_partition,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        backend=backend,
+        cudagraph_capture_sizes=[1, 2],
+    )
+
+    # FIXME(luka/boyuan): the graph from the previous test case
+    # (no inductor partition) gets cached by AotAutograd so then the
+    # compilation with inductor partitioning incorrectly loads an unpartitioned
+    # graph and never partitions. I think this is a bug with custom inductor
+    # partitioning but does not affect vLLM more generally as vLLM uses its own
+    # cache (which takes inductor partitioning into account).
+    if use_inductor_graph_partition:
+        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
+
+    compile_config_split = deepcopy(compile_config_no_split)
+    compile_config_split.splitting_ops = ["silly::attention"]
+
     outputs = []
     with compilation_counter.expect(
         num_graphs_seen=0,
@@ -358,44 +389,44 @@ def test_toy_llama(backend: str):
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-    run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, compile_config_no_compile))
+
+    run_model(tractable_config, compile_config_no_compile)

     if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
     else:
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

     with compilation_counter.expect(
-        # One graph for the model
-        num_graphs_seen=1,
+        num_graphs_seen=1,  # one graph for the model
         num_piecewise_graphs_seen=1,
         num_piecewise_capturable_graphs_seen=1,
-        # num_piecewise_capturable_graphs_seen
-        num_backend_compilations=1,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-    run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(run_model(llama_config, compile_config_no_split))
+
+    run_model(tractable_config, compile_config_no_split)
+
+    if use_inductor_graph_partition:
+        num_piecewise_fx = 1
+        num_piecewise_capturable_fx = 1
+    else:
+        num_piecewise_fx = 2 * llama_config.num_layers + 1
+        num_piecewise_capturable_fx = 1 + llama_config.num_layers

     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
-        num_piecewise_graphs_seen=2 * llama_config.num_layers + 1,  # 2 * num_layers + 1
-        num_piecewise_capturable_graphs_seen=1
-        + llama_config.num_layers,  # 1 + num_layers
-        num_backend_compilations=1
-        + llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=2
-        * (
-            1 + llama_config.num_layers
-        ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        # num_cudagraph_sizes * num_partitions
+        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
    ):
-        outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
-        )
-    run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        outputs.append(run_model(llama_config, compile_config_split))
+    run_model(tractable_config, compile_config_split)

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
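The toy-llama expectations follow the same arithmetic: splitting one graph around num_layers attention ops yields 2 * num_layers + 1 pieces, 1 + num_layers of which are capturable, and each capturable piece is captured once per cudagraph size. A hypothetical check (not part of the commit) of the numbers used with num_layers=2 and cudagraph_capture_sizes=[1, 2]:

num_layers = 2  # llama_config.num_layers in the test
num_cudagraph_sizes = 2  # len(cudagraph_capture_sizes)

num_piecewise_fx = 2 * num_layers + 1  # 5 pieces after Dynamo-level splitting
num_piecewise_capturable_fx = 1 + num_layers  # 3 capturable pieces
num_cudagraph_captured = num_cudagraph_sizes * num_piecewise_capturable_fx  # 6

assert (num_piecewise_fx, num_piecewise_capturable_fx, num_cudagraph_captured) == (5, 3, 6)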

tests/compile/silly_attention.py

Lines changed: 0 additions & 1 deletion
@@ -62,5 +62,4 @@ def silly_attention_fake(
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
     target_lib=silly_lib,
-    tags=(torch._C.Tag.cudagraph_unsafe,),
 )

tests/compile/test_decorator.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ def test_ignore_torch_compile_decorator():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -188,6 +189,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both
         ),
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,6 +222,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         ),
     )

vllm/attention/layer.py

Lines changed: 0 additions & 6 deletions
@@ -38,10 +38,6 @@

 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None
-try:
-    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,)
-except AttributeError:
-    tag_cudagraph_unsafe = ()  # type: ignore[assignment]


 def check_xformers_availability():
@@ -879,7 +875,6 @@ def unified_attention_fake(
     op_name="unified_attention",
     op_func=unified_attention,
     fake_impl=unified_attention_fake,
-    tags=tag_cudagraph_unsafe,
 )


@@ -931,7 +926,6 @@ def unified_attention_with_output_fake(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    tags=tag_cudagraph_unsafe,
 )
0 commit comments

Comments
 (0)