📌 Introduction Link to heading
CUDA Graphs capture a sequence of GPU operations and replay them as a single unit, eliminating the CPU launch overhead for each individual kernel. This is especially beneficial for workloads with small batches or models with many small kernels, where CPU overhead can become a bottleneck.
When you use torch.compile(mode="reduce-overhead") or enable options={"triton.cudagraphs": True}, PyTorch automatically leverages CUDA Graphs to accelerate your model. But how does this actually work under the hood? This post dives into the internals of torch.compile’s CUDA Graph integration, covering the compilation pipeline, compatibility checks, graph partitioning, and the two implementation approaches (CUDA Graph Trees vs. legacy).
This post is based on PyTorch v2.9.0. All source code links point to this version.
🔧 CUDA Graph Basics Link to heading
Before diving into torch.compile, let’s briefly review how CUDA Graphs work.
CUDA Graphs fundamentally require:
Asynchronous execution model - All captured operations must be asynchronous CUDA calls (no CPU synchronization inside the graph).
Static graph topology - The sequence of operations (nodes and edges) must be fixed at capture time.
Static node parameters - All kernel parameters must remain constant, including:
- Grid/block dimensions
- Kernel arguments
- Memory addresses - Input/output tensor addresses must be identical between capture and replay
The static memory address requirement is particularly important: CUDA graphs “bake in” the actual pointers used during capture. If you replay with different addresses, the graph will read/write to the wrong memory locations.
Here’s how to use CUDA Graphs manually in PyTorch:
# Allocate static input buffer (address must stay fixed)
static_input = torch.randn(batch_size, features, device="cuda")

g = torch.cuda.CUDAGraph()

# Warmup run (JIT compile kernels, stabilize memory allocations)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture
with torch.cuda.graph(g):
    output = model(static_input)

# Replay (copy new data into static buffer first)
static_input.copy_(new_data)
g.replay()  # No CPU overhead per kernel
Key concepts:
- Static input buffers - Input tensors must have fixed memory addresses. Copy new data into static buffers before replay.
- Warmup - Run the model before capture to JIT compile kernels (e.g., Triton, cuBLAS) and stabilize memory allocations.
- Capture - Record all CUDA operations into a graph using torch.cuda.graph(). During capture, PyTorch allocates outputs from a private memory pool dedicated to this graph. The caching allocator guarantees these addresses remain stable across replays - memory is never freed back to the general pool while the graph exists.
- Replay - Execute the entire graph as one unit, eliminating per-kernel CPU launch overhead.
For more details on how CUDA graphs interact with PyTorch’s memory allocator (private pools, memory reservation, replay guarantees), see Note [Interaction with CUDA graph capture] in the caching allocator implementation.
🚀 Using CUDA Graphs with PyTorch Link to heading
As you can see from the manual approach above, using CUDA Graphs correctly requires careful management of static buffers, warmup runs, and memory pools. PyTorch provides two higher-level APIs to simplify this: make_graphed_callables and torch.compile.
make_graphed_callables wraps functions or modules to automatically handle CUDA Graph capture and replay:
# Wrap model for CUDA Graph acceleration
model = torch.cuda.make_graphed_callables(model, sample_inputs)
output = model(x) # Automatically replays captured graph
While simpler than raw CUDA Graph APIs, you still need to ensure CUDA Graph compatibility yourself - providing sample inputs with correct shapes and handling operations that can't be captured. When wrapping multiple callables, they share a private memory pool and form a sequential chain - you must replay them in exactly the same order they were captured. Different orders or non-sequential structures (e.g., branching) are not supported and cause memory corruption.
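Here's a minimal sketch of wrapping two modules that always run in this order (module sizes and inputs are made up for illustration):

import torch

m1 = torch.nn.Linear(128, 256).cuda()
m2 = torch.nn.Linear(256, 64).cuda()
# Sample inputs require grad so backward graphs can be captured as well
x1 = torch.randn(32, 128, device="cuda", requires_grad=True)
x2 = torch.randn(32, 256, device="cuda", requires_grad=True)

# Captured in this order, sharing one private memory pool
g1, g2 = torch.cuda.make_graphed_callables((m1, m2), ((x1,), (x2,)))

inp = torch.randn(32, 128, device="cuda", requires_grad=True)
out = g2(g1(inp))  # Replay must follow the capture order: g1, then g2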
torch.compile takes this further by handling all the complexity automatically - static buffer allocation, warmup, capture, replay, and even graph partitioning for incompatible ops are all managed for you:
# Use reduce-overhead mode (recommended)
@torch.compile(mode="reduce-overhead")
def forward(x):
    return model(x)
The "max-autotune" mode also enables CUDA Graphs. You can also use options={"triton.cudagraphs": True} for explicit control.
The rest of this post dives into how torch.compile implements CUDA Graph integration internally.
🔍 Compilation Flow Overview Link to heading
When you call torch.compile(), your code goes through a multi-stage compilation pipeline. CUDA Graph integration happens in the final Inductor stage, after all kernel optimizations are complete.
Stage breakdown:
TorchDynamo - Analyzes Python bytecode and captures FX Graphs (an intermediate representation of PyTorch operations as a directed acyclic graph). Graph breaks (e.g., data-dependent control flow) produce multiple FX Graphs, each processed independently.
AOT Autograd - Traces joint forward + backward graph, decomposes high-level operators into low-level operators (PrimTorch), and partitions into separate forward and backward graphs. Each graph goes through Inductor independently.
Inductor - Lowers each graph to optimized code, performs CUDA Graph compatibility checks, partitions the graph at incompatible op boundaries, generates Triton kernels, and finally wraps the compiled code with CUDA Graph logic.
Now let’s dive deeper into how Inductor handles CUDA Graph integration. The key call stack starting from codegen_and_compile() is:
codegen_and_compile()                        # compile_fx.py:1154
├── graph.run()                              # Lower FX → Inductor IR
├── graph.compile_to_module()                # graph.py:2312
│   └── scheduler.codegen()                  # Partitioning + Triton/wrapper codegen
├── get_first_incompatible_cudagraph_node()  # FX-level compat check
└── return CompiledFxGraph
    └── post_compile()                       # CUDA Graph wrapping
The diagram below shows these stages visually:
Lower to Inductor IR (⚠️ skip CUDA graph for torch.cond/while_loop)
  → Codegen [Partitioning (optional) → Gen Triton Kernel → Gen Python Wrapper]
  → FX-level Compat Check
  → CUDA Graph Wrapping
  → 🚀 Executable
Here’s what each stage does:
- Lower to Inductor IR - GraphLowering.run() converts FX graph nodes to Inductor's internal representation. Notably, torch.cond and torch.while_loop set disable_cudagraphs_reason during lowering, which disables CUDA Graphs for the entire compiled graph (not just a partition). These higher-order control flow operators avoid Dynamo graph breaks but prevent any CUDA Graph capture.
- Codegen - Scheduler.codegen() performs optional graph partitioning (splitting at incompatible IR nodes via should_partition()), generates Triton kernels via TritonScheduling.define_kernel(), and generates Python wrapper code via PythonWrapperCodegen.
- FX-level Compat Check - get_first_incompatible_cudagraph_node() scans the original FX graph for forbidden_set ops (e.g., aten._local_scalar_dense) that always disable CUDA Graphs.
- CUDA Graph Wrapping - CompiledFxGraph.post_compile() wraps compiled code with cudagraphify() to enable CUDA Graph capture at runtime.
Note that there are two levels of compatibility checking: the IR-level check during partitioning (via should_partition()) determines partition boundaries, while the FX-level check (via get_first_incompatible_cudagraph_node()) can disable CUDA Graphs entirely for certain ops. If the FX-level check finds a forbidden_set op, it disables CUDA Graphs for all partitions within that compiled graph—the partitioning work still happens, but none of the partitions will be cudagraphified.
The following sections explain these stages in detail.
🔨 Codegen: Partitioning & Triton Generation Link to heading
The key entry point is Scheduler.codegen(), which dispatches based on the optional graph_partition config:
def codegen(self) -> None:
    return (
        self._codegen_partitions()  # graph_partition=True (default)
        if config.graph_partition
        else self._codegen(self.nodes)  # graph_partition=False
    )
With graph_partition=True (default), _codegen_partitions() first calls graph_partition() to split the Inductor IR at incompatible node boundaries. The splitting decision is made by should_partition(). The following operations will introduce a partition boundary:
- Non-GPU ops - Operations that run on CPU or other non-CUDA devices
- Device copies (ir.DeviceCopy) - Cross-device transfers like .cpu() or .cuda()
- Conditional ops (ir.Conditional) - Control flow that can't be captured in a static graph, like torch.cond() or torch.while_loop()
- Unbacked symbolic bindings - Dynamic shapes that aren't backed by concrete values at capture time, e.g. x[x > 0]. Note: this triggers partitioning, but the FX-level check (described below) will also detect unbacked symbols and disable CUDA Graphs entirely, making this partitioning ineffective for unbacked symbols in practice.
- cudagraph_unsafe tagged ops - Custom operators explicitly marked as incompatible with CUDA Graphs
After partitioning, _codegen() generates Triton kernels and wrapper code for each partition. Cudagraphable partitions become separate wrapper functions; non-cudagraphable ops are inlined in the main call function to execute in eager mode. If no cudagraphable partitions exist, CUDA Graph is disabled entirely.
With graph_partition=False, no partitioning occurs—all code goes into a single call function. The decision to disable CUDA Graph for incompatible ops is handled later in the FX-level compatibility check.
Note that partitioning only affects the wrapper code structure—the same _codegen() generates identical Triton kernels regardless of the partition setting.
Here’s an example that triggers partitioning:
def fn(x):
    x = torch.relu(x)
    cpu_val = x.sum().cpu()  # Device copy triggers partition
    x = torch.softmax(x, dim=-1)
    return x, cpu_val

compiled_fn = torch.compile(fn, mode="reduce-overhead")
The .cpu() call becomes prims.device_put in the FX graph, then DeviceCopy in Inductor IR. Running with TORCH_COMPILE_DEBUG=1 shows:
cudagraph partition due to non gpu ops
cudagraph partition into 2 partitions
The diagram below illustrates how DeviceCopy splits the graph into two cudagraphable partitions:
Original: relu (Pointwise) → sum (Reduction) → .cpu() (DeviceCopy) → softmax (Reduction + Pointwise)

After partitioning:
  🟢 CUDA Graph 1 [relu → sum] → ⚠️ .cpu() (eager) → 🟢 CUDA Graph 2 [softmax]
🔍 FX-level Compatibility Check Link to heading
After code generation completes, several checks determine whether CUDA Graphs should be disabled entirely:
1. Dynamic shapes check - If config.triton.cudagraph_skip_dynamic_graphs=True (default: False) and the graph has symbolic shape inputs, CUDA Graphs are disabled entirely. When False (default), dynamic shapes are supported by re-recording a new CUDA graph for each unique set of input sizes—this is the int_key dispatch mechanism in CUDA Graph Trees.
2. Incompatible node check - get_first_incompatible_cudagraph_node() scans the original FX graph for ops that disable CUDA Graphs:
- forbidden_set ops always disable CUDA Graphs for the entire compiled function:
  - aten._local_scalar_dense from .item(), which only appears in the Dynamo graph when capture_scalar_outputs=True; otherwise .item() breaks the Dynamo graph
  - run_and_save_rng_state and run_with_rng_state from activation checkpointing (torch.utils.checkpoint) with random ops
  - Non-deterministic ops: when torch.are_deterministic_algorithms_enabled(), scatter/index_put ops are also added to forbidden_set
- cudagraph_unsafe tagged ops disable CUDA Graphs only when graph_partition=False. With partitioning enabled (default), these ops trigger IR-level partitioning instead.
- Unbacked symbols always disable CUDA Graphs. If any node has an output with unbacked symbols (in shape or storage offset), CUDA Graphs are disabled—even if the final output has a backed shape. Common sources of unbacked symbols include:
  - Unbacked shape (data-dependent output size):
    - torch.nonzero(), torch.masked_select() - output size depends on tensor values
    - torch.unique(), torch.unique_consecutive() - number of unique elements is data-dependent (fake_impls.py#L324)
    - torch.repeat_interleave() - output size depends on repeat counts
    - torch.bincount() - output size depends on input values
    - torch.nn.utils.rnn.pack_padded_sequence() - packed batch size is data-dependent
    - Nested tensors (torch.nested) with variable-length sequences
  - Unbacked storage offset (data-dependent memory location):
    - torch.select(x, dim, idx) or x[:, idx] when idx is a data-dependent value whose sign can't be determined at compile time—shape is known, but storage offset is unbacked

For example, x[x > 0].sum() disables CUDA Graphs because the intermediate indexing result has data-dependent shape, even though .sum() produces a 0-d tensor with static shape [].
3. Device check - check_lowering_disable_cudagraph() disables CUDA Graphs if the graph spans multiple CUDA devices. CPU nodes also disable CUDA Graphs, but only when graph_partition=False; with partitioning enabled (default), CPU ops are handled by IR-level partitioning instead.
📌 .cpu() vs .item(): These behave very differently with CUDA Graphs:
- .cpu() becomes DeviceCopy in Inductor IR → partitionable. Surrounding GPU ops can still be cudagraphified; the .cpu() runs in eager mode between graph replays.
- .item() dispatches to aten._local_scalar_dense → forbidden. Disables CUDA Graph entirely because it requires CPU-GPU synchronization and the returned value typically affects subsequent computation.
The key difference from IR-level partitioning: these FX-level checks always disable CUDA Graphs for the entire compiled graph, regardless of the graph_partition setting.
📦 Post-Compile Stage: CUDA Graph Wrapping Link to heading
After code generation and compatibility checks, CompiledFxGraph.post_compile() wraps the compiled code with CUDA Graph logic. If disable_cudagraphs_reason was set by the FX-level check, CUDA Graphs are skipped entirely. Otherwise, the wrapping proceeds based on whether partitioning is enabled:
[Diagram: post_compile() branches on graph_partition (per-partition vs. whole-graph wrapping) and then on cudagraph_trees — True: tree-based impl (memory pool sharing, re-recording support); False: legacy impl (single recording, simpler)]
With graph_partition=True (default), cudagraph_partition_post_compile() iterates over each cudagraphable partition and wraps it with cudagraphify(). Partition functions are guaranteed cudagraphable since non-cudagraphable ops were already inlined in the main call function during codegen.
With graph_partition=False, cudagraph_post_compile() performs additional runtime checks before wrapping the entire callable. These checks populate cudagraph_fail_reasons and include: mutated inputs, complex memory overlap, and non-tensor inputs. If any check fails, CUDA Graphs are skipped but the compiled Triton kernels still run. There’s no try-and-fallback: the decision is made via static analysis before capture.
Notice that CUDA Graph capture doesn’t happen at compile time—it happens at first runtime execution (warmup → record → replay).
Once wrapping proceeds, cudagraphify() selects between two implementations based on the cudagraph_trees config:
def cudagraphify(model, static_input_idxs, *, device_index, ...):
    if config.triton.cudagraph_trees:
        cudagraphify_fn = new_cudagraphify_impl  # Tree-based (default)
    else:
        cudagraphify_fn = cudagraphify_impl  # Legacy (no pool sharing)
The tree-based approach (default) supports memory pool sharing across forward/backward graphs and can re-record new branches when execution paths change. The legacy approach is simpler but uses more memory since each graph has its own pool.
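If you want to compare the two implementations on your own workload, the config flag from the snippet above can be flipped before compiling. A minimal sketch:

import torch

# Opt out of CUDA Graph Trees and fall back to the legacy capture/replay path
torch._inductor.config.triton.cudagraph_trees = False

@torch.compile(mode="reduce-overhead")
def f(x):
    return torch.relu(x) + 1

f(torch.randn(16, device="cuda"))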
🔄 Legacy Implementation Link to heading
The legacy cudagraphify_impl() follows the classic CUDA Graph pattern. Note that warmup and capture happen when the function is first called at runtime, not at compile time:
def cudagraphify_impl(model, inputs, static_input_idxs):
    # 1. Allocate static buffers for inputs
    static_inputs = [
        static_input(x) if idx not in static_input_idxs else x
        for idx, x in enumerate(inputs)
    ]

    # 2. Warmup on separate stream
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        model(list(static_inputs))

    # 3. Record the graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(list(static_inputs))

    # 4. Return replay function
    def run(new_inputs):
        # Copy new data into static buffers
        for idx in copy_indices:
            static_inputs[idx].copy_(new_inputs[idx])
        graph.replay()
        return static_outputs

    return run  # Called at runtime for subsequent executions
1. Input Setup (allocate static buffers, handle expanded dims)
  → 2. Warmup Phase (separate stream, use static inputs)
  → 3. Record Phase (torch.cuda.graph(g), no pool sharing)
  → 4. Return run() Function (copy inputs → buffers, graph.replay())
The first step allocates static buffers for inputs. Here “static” and “dynamic” refer to whether the input tensor’s memory address stays the same across calls:
| Input Type | Description | Behavior |
|---|---|---|
| Static | Tensor reuses the same memory address (e.g., model parameters, outputs from previous CUDA Graph) | No copy needed; pointer captured directly |
| Dynamic | Tensor may have different memory address each call (e.g., batch data from dataloader) | Allocate a static buffer; copy data on each replay |
Shape changes trigger full Dynamo recompilation with the legacy implementation. Since the legacy approach records a single CUDA graph with fixed tensor shapes, it forces specialization on all symbolic inputs by calling int(t) on each SymInt. This creates static guards that fail on any shape change, requiring a new Dynamo+Inductor compilation for each unique shape.
While simple, the legacy implementation has one key limitation: no memory pool sharing. Each CUDA Graph allocates its own memory, which can’t be reused across graphs. This increases memory consumption, especially when multiple graphs (e.g., forward and backward) are captured.
To address the memory sharing limitation, PyTorch introduced CUDA Graph Trees as the default implementation.
🔗 The Memory Pool Sharing Problem Link to heading
The natural solution to reduce memory consumption is to let multiple CUDA Graphs share a memory pool. PyTorch supports this via the pool parameter:
graph1 = torch.cuda.CUDAGraph()
graph2 = torch.cuda.CUDAGraph()

with torch.cuda.graph(graph1):
    y1 = func1(x1)

with torch.cuda.graph(graph2, pool=graph1.pool()):  # Share pool!
    y2 = func2(x2)
However, sharing a memory pool introduces a critical constraint: graphs must be replayed in the same order they were recorded. Let’s see what happens when this rule is violated:
def func1(x1):
    t1 = x1 * 3   # x1=1.0 → t1=3.0
    y1 = t1 + 5   # y1=8.0
    return y1

def func2(x2):
    y2 = x2 ** 2  # x2=2.0 → y2=4.0
    return y2

x1 = torch.tensor([1.0], device='cuda')
x2 = torch.tensor([2.0], device='cuda')

# Record graph1 first, then graph2 (sharing pool)
with torch.cuda.graph(graph1):
    y1 = func1(x1)
with torch.cuda.graph(graph2, pool=graph1.pool()):
    y2 = func2(x2)

# Replay in WRONG order: graph2 first, then graph1
graph2.replay()
graph1.replay()
print(f"y1={y1.item()}, y2={y2.item()}")
Running this produces:
# During capture:
x1.data_ptr()=0x7f8743000000, t1.data_ptr()=0x7f8728600000, y1.data_ptr()=0x7f8728600200
x2.data_ptr()=0x7f8743000200, y2.data_ptr()=0x7f8728600000 ← Same as t1!
# Results:
Correct: y1.item()=8.0, y2.item()=4.0
Actual: y1.item()=8.0, y2.item()=3.0 ← y2 is WRONG!
Notice that y2 and t1 share the same memory address (0x7f8728600000). When graph2 was recorded, t1 was already dead, so the allocator reused its memory for y2. But when we replay in the wrong order—graph2 then graph1—graph1 writes t1=3.0 to that address after graph2 wrote y2=4.0, corrupting the result.
For sequential graphs that always execute in the same order, this constraint is easy to follow—just replay in recording order. But real-world workloads often have more complex execution patterns:
- Graph breaks produce multiple graphs that may execute in varying sequences
- Training loops interleave forward graphs, backward graphs, and optimizer steps—especially with pipeline parallelism where micro-batches overlap
- Conditional branches may skip certain graphs entirely on some iterations
In these cases, enforcing a strict linear replay order becomes impractical. What we really need is a way to share memory pools while supporting flexible execution paths.
⚠️ The Stale Allocator State Problem Link to heading
Beyond replay order, there’s another subtle issue when mixing CUDA Graph replay with new recordings. When you replay a CUDA graph, only GPU operations execute—the CPU-side allocator bookkeeping is not updated.
Consider a function with conditional branches:
@torch.compile(mode="reduce-overhead")
def foo(x):
    y = x * x        # Graph A
    if y.sum() > 0:
        z = y + 1    # Graph B
    else:
        z = y - 1    # Graph C
    out = z * 2      # Graph D
    return out
This creates a diamond-shaped graph where different inputs take different paths:
            ┌→ Graph B (z = y + 1) ─┐
Graph A ────┤                       ├──→ Graph D (out = z * 2)
(y = x * x) └→ Graph C (z = y - 1) ─┘
Now suppose we’ve already recorded path A → B → D. On a new call, we replay A, but the input takes the else branch—we need to record Graph C for the first time.
CPU Allocator State (assuming y→0x1000, z→0x2000, out→0x3000; y.sum() ignored for simplicity):
| Stage | 0x1000 | 0x2000 | 0x3000 | Note |
|---|---|---|---|---|
| After record A | y | free | free | |
| After record B | y | z | free | |
| After record D | free | free | out | y, z recycled |
| After replay A | free (has y data) | free | out | Stale state! |
| Try to record C | 💥 z | free | out | Overwrites y! |
The problem: replay executes GPU operations but doesn’t update CPU-side allocator bookkeeping. After replaying A, the allocator still thinks 0x1000 is free (its state from end of D’s recording). When we record Graph C, the allocator assigns z to 0x1000—overwriting y which is still needed for z = y - 1.
This is the core challenge: how do we share memory pools (for efficiency) while correctly tracking tensor liveness across replay and new recordings?
🌳 CUDA Graph Trees Link to heading
The tree-based implementation (torch/_inductor/cudagraph_trees.py) solves these limitations through a sophisticated system of checkpointing and tensor liveness tracking. From the module docstring:
CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, which share the same memory pool. Sharing a memory pool is an extremely important optimization when chaining multiple CUDA graphs together, as it prevents you from needing to copy intermediate tensors from one graph to the next, and reduces overall memory usage by allowing dead memory from the first pool to be reused in the second.
Tree Structure Link to heading
As mentioned earlier, traditional CUDA Graphs with shared pools have two fundamental problems:
- Strict replay order—if you record graphs A then B, you must replay them in order A → B, otherwise memory corruption occurs;
- Memory clobbering—when graphs share a pool, later graphs reuse memory from earlier graphs, and if you still hold references to earlier outputs, they can be silently corrupted.
CUDA Graph Trees solve both problems by organizing recordings into a tree structure where each path represents a valid execution sequence. For the foo(x) example above, the tree structure looks like:
Graph A (y = x * x, state: {y@0x1000})
├── Graph B (z = y + 1, state: {z@0x2000})
│   └── Graph D1 (out = z * 2, state: {out@0x3000})
└── Graph C (z = y - 1, state: {z@0x2000})
    └── Graph D2 (out = z * 2, state: {out@0x3000})
Note that Graph D appears twice in the tree—once as a child of B (path A→B→D1) and once as a child of C (path A→C→D2). Each path through the tree represents a valid execution sequence, and after each recording we save a checkpoint of the caching allocator state.
To understand how checkpointing enables new branch recording, consider this scenario: we've already recorded path A→B→D1, saved checkpoints after each recording, and replayed them multiple times. Now a new input takes the else branch, requiring us to record Graph C for the first time:
- After replaying Graph A: GPU has computed y at 0x1000, but CPU allocator state is stale (still reflects the state after D1 was recorded, where 0x1000 was marked free)
- Before recording Graph C: The tree manager restores Graph A's checkpointed allocator state, which correctly shows 0x1000 is allocated to y
- Recording Graph C: Now when C computes z = y - 1, the allocator knows 0x1000 is occupied, so it allocates z at a different address (e.g., 0x2000)
- After recording Graph C: A new checkpoint is saved capturing this state, enabling future branches from C
- Recording Graph D2: The tree now has two paths to Graph D—one from B (already recorded as D1) and one from C (not yet recorded). Since we came through C, we need to record a new D as a child of C. The tree manager restores Graph C's checkpoint, which shows y at 0x1000 and z at 0x2000
- After recording Graph D2: D2 computes out = z * 2 and is added as a child of C. A new checkpoint is saved, creating a second path A→C→D2 alongside the existing A→B→D1
This is why the tree structure is essential—each node’s checkpoint captures the allocator state at that point, allowing safe branching in any direction.
Key Components Link to heading
The implementation centers on three key components:
- CUDAGraphTreeManager is the per-device manager that maintains the tree structure of recordings, tracks tensor liveness and generations, handles warmup → recording → execution transitions, and manages memory pool sharing and checkpointing.
- CUDAGraphNode represents a single recording—it stores the captured CUDA Graph, tracks parent/child relationships, and checkpoints allocator state after recording for later restoration.
- CUDAWarmupNode is a simplified wrapper used during warmup runs. Unlike CUDAGraphNode, it doesn't record a CUDA Graph but runs the function eagerly within the CUDA graph memory pool.
The tree structure is actually a forest—there’s one tree per device, but each tree can have multiple roots. The roots dictionary maps each FunctionID to a list of CUDAGraphNode objects. A compiled function can have multiple root nodes when called with different invariants that can’t be replayed—for example, different input shapes, different static input addresses, or different tensor properties. Each root represents a distinct “entry point” for that function, and the tree branches from there based on execution paths.
All recordings share a single memory pool per device (torch.cuda.graph_pool_handle()). With CUDA Graph Trees, this enables efficient memory usage—for example, if you have paths A→B and A→B’, the total memory required is only max(mem(A,B), mem(A,B')) rather than mem(A,B) + mem(A,B'). The checkpointing mechanism makes this safe by tracking exactly which memory regions are live at each point in the tree.
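The same pool-sharing primitive is exposed in the public API. A minimal manual sketch (toy tensors, no warmup shown):

import torch

pool = torch.cuda.graph_pool_handle()  # one shared pool, like the per-device tree pool
g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

x = torch.randn(1024, device="cuda")
with torch.cuda.graph(g1, pool=pool):
    y = x * 2
with torch.cuda.graph(g2, pool=pool):
    z = y + 1  # allocations come from the same pool as g1

g1.replay()
g2.replay()  # manual sharing still requires replaying in capture order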
Warmup Phase Link to heading
Before recording a CUDA Graph, the function must be “warmed up” to JIT compile kernels (e.g., Triton, cuBLAS) and stabilize memory allocations. Unlike the legacy implementation which warms up using global memory pool, CUDA Graph Trees runs warmup within the shared memory pool using CUDAWarmupNode.
This design choice is important for memory efficiency. If warmup ran outside the pool and then we kept a copy of inputs for recording, we’d incur a memory penalty. By running warmup in the pool, the allocations are immediately tracked and can be reused. From the source comments:
We run graph warmups in the cudagraph memory pool and return the result on the first invocation of a function. For many models it is important to reclaim activations as you run the backward. If we were to warm up the model and keep an extra copy of the inputs around to subsequently use for recording, we would incur a memory penalty.
Key differences between CUDAWarmupNode and CUDAGraphNode:
- No graph recording—warmup runs eagerly, just tracking output storages
- No input copying—inputs don’t need to be copied into static buffers
- Transient—CUDAWarmupNode is not stored in the tree; it's discarded after warmup completes, and the real tree is built only from CUDAGraphNode instances
The warmup-to-recording transition happens automatically: after a function completes warmup, try_end_curr_warmup() sets current_node = None, discarding the CUDAWarmupNode. The next invocation will then record a CUDAGraphNode. The system clears cuBLAS workspace caches via clear_cublass_cache() before warmup and recording to prevent persistent allocations from conflicting with the CUDA graph pool. For models with stochastic operations (e.g., dropout), RNG generator states are registered with the graph via graph.register_generator_state() so that random values advance correctly on each replay.
Checkpoint and Restore Link to heading
As illustrated above, checkpoint save and restore is the key to enabling new branch recording. Each CUDAGraphNode saves a checkpoint of the allocator state after each recording using torch._C._cuda_getCheckpointState() (see cudagraph_trees.py#L1373-L1375):
# At end of recording (in CUDAGraphNode._record)
self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
    self.device, self.cuda_graphs_pool
)
When recording a new graph that follows an existing node (i.e., current_node is not None), apply_checkpoint_execution_state_in_allocator() restores the allocator state using torch._C._cuda_setCheckpointPoolState(). Note that for the very first recording (root node), no restore is needed since the allocator is already in a clean state:
def apply_checkpoint_execution_state_in_allocator(self):
    # 1. Get checkpoint saved when this node was recorded
    state = self.current_node.checkpointed_caching_state

    # 2. Find tensors that are CURRENTLY live (via weakrefs)
    live_storages = list(self.current_node.path_live_weakrefs())

    # 3. Find tensors that DIED since replay (in eager code between graphs)
    ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation()

    # 4. Restore allocator state, telling it what's actually live
    torch._C._cuda_setCheckpointPoolState(device, state, [], live_storages)

    # 5. Free memory for tensors that died in eager region
    for ptr in ptrs_to_deallocate:
        torch._C._cuda_cudaCachingAllocator_raw_delete(ptr)
Liveness Tracking Link to heading
Simply restoring the allocator checkpoint from the parent CUDAGraphNode is not enough. Consider what happens between graph recordings: Graph A replays → eager code runs → some output tensors go out of scope and are garbage collected → new Graph B needs to be recorded. The parent’s checkpoint reflects the allocator state at the end of recording, when all outputs were still alive. If we blindly restore that state, the allocator would believe those now-dead tensors are still occupying memory, preventing their memory from being reused.
To solve this, CUDA Graph Trees tracks tensor liveness using weak references. Each CUDAGraphNode maintains path_weakrefs—weak references to the storage of all outputs along the path from root to current CUDAGraphNode. When path_live_weakrefs() is called, it iterates through these weak references and returns only those that are still alive (i.e., the underlying storage hasn’t been garbage collected). Similarly, data_ptrs_dead_since_invocation() compares the recorded liveness at the end of graph execution against the current liveness, identifying which tensors have died in the interim.
With this liveness information, the restore process can accurately reclaim memory: it first restores the allocator checkpoint via torch._C._cuda_setCheckpointPoolState(), then explicitly frees the dead tensors’ memory via torch._C._cuda_cudaCachingAllocator_raw_delete(). This ensures the allocator state precisely reflects what’s actually in use, making it safe to record a new branch. For more details on the underlying allocator checkpointing mechanism, see Note [Checkpointing PrivatePoolState].
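Conceptually, the liveness query can be sketched with plain Python weak references on output tensors (the real implementation wraps the underlying storages, not the tensor objects):

import weakref
import torch

path_weakrefs = []  # one weak reference per output along the recorded path

def track_outputs(outputs):
    path_weakrefs.extend(weakref.ref(t) for t in outputs)

def live_outputs():
    # Only references whose target has not been garbage collected
    return [r for r in path_weakrefs if r() is not None]

a = torch.randn(4, device="cuda")
b = torch.randn(4, device="cuda")
track_outputs([a, b])
print(len(live_outputs()))  # 2

del b                       # output dies in eager code between graphs
print(len(live_outputs()))  # 1 → that memory can be reclaimed during restore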
Execution Flow Link to heading
With checkpoint and restore in place, CUDAGraphTreeManager._run() orchestrates the complete lifecycle—warmup, recording, and replay. The simplified logic is shown below. The function first checks for input mutations that would require eager fallback. If a function hasn’t been warmed up yet, it restores the checkpoint and runs eagerly. For warmed-up functions, it tries to find a matching child node for replay (the fast path). If no match exists, it restores the checkpoint and records a new branch.
def _run(self, new_inputs, function_id):
    # End current recording/warmup if needed
    if self.in_recording:
        self.try_end_curr_recording(function_id)

    # Check for input mutations → fall back to eager execution
    if self.non_cudagraph_managed_mutation_hint[...]:
        return self.ids_to_funcs[function_id].model(new_inputs)

    # Need warmup? Restore checkpoint first if we were in execution state
    if function_id not in self.warmed_up_functions:
        if self.path_state == ExecutionState.EXECUTION:
            self.apply_checkpoint_execution_state_in_allocator()
        return self.run_eager(new_inputs, function_id)

    # Try to find matching child node for replay (fast path)
    for child in child_nodes[function_id]:
        if child.check_invariants(new_inputs) == SUCCESS:
            return self.execute_node(child, new_inputs)

    # No match found → restore checkpoint and record a new branch
    if self.current_node is not None:
        self.apply_checkpoint_execution_state_in_allocator()
    return self.record_function(new_inputs, function_id)
Input Mutation Handling Link to heading
Notice the code above has a specific check for input mutations (non_cudagraph_managed_mutation_hint). CUDA Graph Trees handle input mutations carefully. Consider this example:
@torch.compile(mode="reduce-overhead")
def mutating_fn(x):
    x.add_(1)  # In-place mutation
    return x * 2
Note that CUDA graphs require static memory addresses for inputs. For user-provided tensors that change each iteration, CUDA Graph Trees creates a static copy (e.g., static_x) and copies data into it before each replay. When a compiled function mutates its inputs, the system checks via check_for_mutation() whether the mutated inputs are “CUDA graph managed”—either static inputs (parameters/buffers) or outputs from previous CUDA graphs in the tree. Mutations on these managed tensors are safe because their memory addresses are stable across replays. However, if a function mutates a dynamic input (like x above), the recorded graph would mutate static_x, not the user’s original x. This means the mutation wouldn’t be visible to the caller, leading to incorrect results. In this case, the system falls back to eager execution to ensure correctness.
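As a sketch of the distinction (the module and buffer below are made up for illustration): mutating a cudagraph-managed tensor such as a registered buffer can stay on the CUDA Graph path, while mutating a fresh user input like x above falls back to eager.

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer has a stable address → cudagraph-managed
        self.register_buffer("calls", torch.zeros(1, device="cuda"))

    def forward(self, x):
        self.calls.add_(1)  # mutation of a static tensor: safe to capture
        return x * 2

model = torch.compile(Counter().cuda(), mode="reduce-overhead")
for _ in range(5):
    model(torch.randn(8, device="cuda"))  # fresh, unmutated inputs are fine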
Generation Tracking Link to heading
While checkpointing handles memory consistency for branch recording, CUDA Graph Trees also need to manage memory reuse across iterations. Without this, the system would have to keep all previous outputs alive forever (since Python references might still exist), preventing memory reuse between iterations. Consider this example:
@torch.compile(mode="reduce-overhead")
def my_model(x):
    return torch.matmul(x, x)

x = torch.randn(10, 10, device="cuda")
y1 = my_model(x)  # First invocation: output at 0x1000
y2 = my_model(x)  # Second invocation: reuses 0x1000, overwrites y1's data
print(y1)  # Problem: y1 now contains y2's data (corrupted!)
The second invocation reuses the same memory address 0x1000 for y2, overwriting y1’s data. If the user still holds a reference to y1 and accesses it, they would get corrupted data. Rather than returning corrupted data silently, CUDA Graph Trees detects this situation and raises a clear error:
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
If you need to preserve outputs across multiple iterations, follow the error message’s advice and clone them outside of torch.compile:
y1 = my_model(x).clone() # Clone to preserve across iterations
y2 = my_model(x) # Now y1 is safe
This automatic detection in CUDA Graph Trees is implemented using “generations” (self.current_gen) tracked via GenerationTracker.generation in TorchDynamo. When a new torch.compile invocation (i.e., a new call to a function wrapped with torch.compile) occurs, the generation increments. The decision to start a new generation is made by can_start_new_generation():
def can_start_new_generation(self) -> bool:
    if not self.in_new_torch_compile_invocation():
        return False
    if self.user_invoked_mark_step():
        return True
    return not self.running_forwards_with_pending_backwards
When this returns True, dealloc_current_path_weakrefs() is called to invalidate previous outputs before allowing memory reuse. Now when the user tries to access y1, PyTorch raises a clear RuntimeError: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run rather than returning corrupted data.
The automatic generation detection logic handles inference and training differently, as their memory reuse requirements differ:
- Inference Mode. For inference (or when using torch.no_grad()), each new invocation starts a new generation immediately. This means outputs from the previous invocation are invalidated and their memory is reused.
- Training Mode. Both forward and backward passes are recorded as nodes in the same CUDA graph tree, sharing the same memory pool per device. Each function is registered with its CompilationMode (FORWARD, BACKWARD, or INFERENCE), and a typical training iteration forms a path like: Forward₁ → Forward₂ → ... → Backwardₙ → ... → Backward₁ (backward runs in roughly reverse order of forward). For training, the generation heuristic must ensure forward outputs survive until backward completes. The system tracks running_forwards_with_pending_backwards: when a forward runs, this flag is set to True; when a backward runs, it's cleared (see Note: [Backward Generation Handling]). While this flag is True, can_start_new_generation() returns False, preventing premature memory reuse.
- Manual Control. If the automatic heuristics above don't fit your use case, you can explicitly mark iteration boundaries by calling torch.compiler.cudagraph_mark_step_begin() before each model invocation (see the sketch below). This increments MarkStepBox.mark_step_counter, causing user_invoked_mark_step() to return True and bypassing the running_forwards_with_pending_backwards check.

Note that cudagraph_mark_step_begin() doesn't prevent the memory overwrite—it enables earlier detection. When a new generation starts, previous outputs are invalidated so that accessing them raises a clear error instead of returning corrupted data. If you need to preserve outputs across iterations, you must still clone them as shown in the example above.
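A minimal usage sketch of the manual control path:

import torch

@torch.compile(mode="reduce-overhead")
def step(x):
    return x * 2 + 1

for _ in range(4):
    # Mark the iteration boundary so outputs from the previous step are
    # invalidated (accessing them later raises a clear error, not garbage)
    torch.compiler.cudagraph_mark_step_begin()
    out = step(torch.randn(8, device="cuda"))
    kept = out.clone()  # clone anything that must survive into the next step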
Re-recording Limits Link to heading
Before replaying an existing graph, the system calls CUDAGraphNode.check_invariants() to verify:
- CUDA graph managed tensor addresses—outputs from previous graphs in the path must have stable addresses
- Tensor liveness pattern—tensors expected to be dead before this graph must still be dead
- Static input addresses (when rerecord_if_static_inputs_change=True)—parameter/buffer addresses must be stable
If any check fails, the system returns a CheckInvariantStatus indicating the mismatch type. When no child passes all checks, a new graph is recorded as a sibling branch. Common causes include:
- Static input address changes: When inline_inbuilt_nn_modules=False (legacy behavior), parameter tensors are passed as inputs with expected stable addresses. If you reassign model.param.data = ... or the optimizer moves tensors, re-recording is triggered. See test_rerecord_if_static_input_address_changed.
- Tensor liveness pattern changes: Consider the sequence: Graph A replays → eager code runs and deallocates some of A's outputs → Graph B needs to be recorded. Before recording B, the system restores A's allocator checkpoint and frees the tensors that died in eager mode. Graph B is then recorded with expected_dead_indices capturing which tensors were dead at that point. On subsequent runs, if a tensor that was dead during recording is now alive, replaying B would clobber that tensor's memory—so re-recording is required. See check_invariants(). Note that the reverse case is safe: if a tensor was alive during recording but dead before replay, the graph simply won't access memory that's now free.
To illustrate, consider this example where foo() returns two outputs and bar() consumes the first:
@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1, x + 2  # Two outputs: y1, y2

@torch.compile(mode="reduce-overhead")
def bar(y):
    return y * 2

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    x = torch.randn(4, device="cuda")
    y1, y2 = foo(x)
    if i == 1:
        del y2  # In iter 1 only: second output dies before bar()
    z = bar(y1)
The tree structure evolves as follows:
- Iter 0 (y2 alive): Warmup runs for both functions—no graphs recorded yet.
- Iter 1 (y2 dead): foo() is recorded as Graph[0] with 2 outputs. Then y2 is deleted, so when bar() is recorded as Graph[1], it captures expected_dead_indices = [(0, 1)]—meaning output index 1 from Graph[0] (i.e., y2) must be dead before replay.
- Iter 2 (y2 alive): foo() replays Graph[0]. But now y2 is kept alive, so when bar() tries to replay Graph[1], the liveness check fails—y2 should be dead but isn't. A new Graph[2] is recorded as a sibling branch without the dead expectation.
══════════════════════════════════════════════════
CUDA Graph Tree (device 0)
Graphs in tree: 3
══════════════════════════════════════════════════
└── Graph[0] foo() outputs=2
    ├── Graph[1] bar() outputs=1 [expects_dead: [(0, 1)]]
    └── Graph[2] bar() outputs=1
The [expects_dead: [(0, 1)]] annotation shows that Graph[1] expects foo()’s second output to be dead. Since Graph[2] has no such expectation, it can replay when both outputs are alive.
Excessive re-recording indicates instability and hurts performance. The system tracks re-record counts via num_rerecord and falls back to eager execution when exceed_rerecord_limit() returns True:
def exceed_rerecord_limit(self, node_id, function_id) -> bool:
    if torch._dynamo.config.inline_inbuilt_nn_modules:
        return False  # Skip limit when inlining builtin nn modules
    return (
        self.num_rerecord[node_id][function_id]
        > config.triton.cudagraph_unexpected_rerecord_limit  # Default: 128
    )
The limit is configurable via torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit. If you see frequent re-recording warnings, investigate the root cause rather than simply increasing the limit.
Dynamic Shape Handling Link to heading
Dynamic shapes are handled at multiple levels in the torch.compile stack. Understanding when each mechanism applies is important:
| Level | When | What Happens | Config |
|---|---|---|---|
| Dynamo guards | Guard failure (dtype, device, requires_grad, or first shape change) | Full recompilation: new FX graph → new Inductor compilation → new CUDA graph | automatic_dynamic_shapes |
| FX-level check | Compile time, unbacked symbolic shapes present | CUDA Graph disabled entirely for that graph | cudagraph_skip_dynamic_graphs |
| int_key dispatch | Runtime, same compiled graph with varying backed symints | New FunctionID per unique int_key, new root node in tree | cudagraph_capture_sizes |
Key distinctions:
- Level 1 (Dynamo guards): With automatic_dynamic_shapes=True (default), the first compilation uses static shapes (hardcoded in the generated Triton kernel). When the shape changes for the first time, a guard fails (e.g., "tensor 'x' size mismatch at index 0. expected 4, actual 8"), triggering recompilation with that dimension marked as symbolic. Subsequent shape changes on that dimension won't cause recompilation. Non-shape guards (dtype, device, requires_grad) always trigger recompilation since they are not symbolic. This happens before CUDA Graph Trees.
- Level 2 (FX-level): At compile time, unbacked symbolic shapes always disable CUDA Graphs. These are dimensions whose sizes cannot be determined until the code actually runs—for example, x[x > 0], where the output size depends on how many elements satisfy the condition, not just the input shape. This is different from backed symints, which have concrete runtime values.
- Level 3 (int_key dispatch): At runtime, when the compiled graph has backed symbolic shapes (concrete integer values like batch size), the shape values are passed as scalar kernel arguments. Since CUDA Graphs capture kernel arguments at record time, each unique combination needs its own CUDA Graph recording. This is where int_key dispatch comes in.
After passing Dynamo guards (Level 1) and FX-level checks (Level 2), CUDA Graph Trees handles varying backed symints via cudagraphify_impl(), which extracts integer inputs and uses them as a cache key:
fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}

def deferred_cudagraphify(inputs):
    int_key = get_ints(inputs)  # Extract integer inputs (e.g., batch size)
    if not is_cudagraph_capture_sizes(int_key):
        return model(inputs)  # Fall back to eager if not in capture set

    fn = fn_cache.get(int_key)
    if fn is not None:
        return fn(inputs)  # Reuse existing function for this int_key

    fn, out = cudagraphify(...)  # Register new FunctionID in tree manager
    fn_cache[int_key] = fn
    return out
Each unique int_key creates a new FunctionID in the same CUDAGraphTreeManager. When recording starts (at the beginning of an iteration), this becomes a new root node for that FunctionID—effectively a sibling root alongside roots from other int_key values. The fn_cache is a local optimization to avoid re-registering the same function.
If too many distinct sizes are recorded, a warning is emitted suggesting to pad inputs or disable CUDA graphs for dynamic shapes via cudagraph_skip_dynamic_graphs=True. The set of sizes to capture can be restricted via torch._inductor.config.triton.cudagraph_capture_sizes—any int_key not in this set falls back to eager execution.
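A sketch of restricting capture sizes (treat the exact config value type as illustrative; check the config.py linked below for the accepted format):

import torch

# Only these sizes get their own CUDA graph recording; other sizes run the
# compiled Triton kernels without CUDA Graphs.
torch._inductor.config.triton.cudagraph_capture_sizes = (8, 16, 32)

@torch.compile(mode="reduce-overhead")
def f(x):
    return torch.relu(x) + 1

for bs in (8, 16, 20, 32):            # bs=20 is not in the capture set
    x = torch.randn(bs, 64, device="cuda")
    torch._dynamo.mark_dynamic(x, 0)  # keep dim 0 symbolic to avoid recompiles
    f(x)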
📋 Summary Link to heading
CUDA Graph integration in torch.compile happens at the end of Inductor, after Triton kernel codegen. Minimizing graph breaks is important to get larger CUDA graphs with greater performance benefits.
A single Inductor graph can be partitioned into multiple CUDA graphs at incompatible boundaries such as device copies (.cpu()), other CPU/non-GPU ops, and cudagraph_unsafe custom ops. Some ops disable CUDA graphs entirely:
- .item() (dispatches to aten._local_scalar_dense) - in the FX-level forbidden set, along with RNG state save/restore ops from activation checkpointing and, under deterministic mode, scatter/index_put ops
- torch.cond / torch.while_loop - higher-order control flow ops set disable_cudagraphs_reason during IR lowering
- Unbacked symbolic shapes - data-dependent output sizes like x[x > 0]
There are two implementations:
- Legacy: Simple capture/replay with no memory pool sharing. Forces shape specialization via int(t) on symbolic inputs, causing full Dynamo recompilation for each unique shape.
- CUDA Graph Trees (default): Single shared memory pool with tree-structured branching support. The structure is actually a forest—one tree per device, where each compiled function can have multiple root nodes for different input invariants. Handles dynamic shapes via int_key dispatch—records separate CUDA graphs per shape without Dynamo recompilation.
CUDA Graph Trees provides several key mechanisms:
- Memory pool sharing: All graphs share one pool per device, requiring only max(path₁, path₂) memory instead of the sum
- Warmup in pool: Warmup runs inside the shared memory pool via CUDAWarmupNode, avoiding extra memory overhead
- Checkpointing: Save allocator state after each recording and restore it before recording a new branch, handling stale CPU allocator state during replay and enabling safe branch recording (e.g., if/else paths)
- Liveness tracking: Track tensor liveness using weak references to detect tensors deallocated between recordings
- Input mutation handling: Falls back to eager execution if a user-provided input is mutated
- Generation tracking: Detects cross-iteration access to overwritten outputs, raising RuntimeError instead of silent corruption
- Re-recording limits: Falls back to eager execution if a function exceeds the re-recording limit (default: 128)
To debug CUDA graph behavior, use TORCH_LOGS="+cudagraphs" or TORCH_COMPILE_DEBUG=1 environment variables.
🔗 References Link to heading
Documentation:
- PyTorch CUDA Graphs Documentation
- CUDA Graph Trees Design Doc
- torch.compile API Reference
- NVIDIA CUDA Graphs Guide
Key Source Files (PyTorch v2.9.0):
- cudagraph_trees.py - Tree-based CUDA Graph implementation
- compile_fx.py - cudagraphify() and legacy implementation
- output_code.py - CompiledFxGraph and post-compile logic
- scheduler.py - Graph partitioning logic
- utils.py - Compatibility checks (get_first_incompatible_cudagraph_node)
- config.py - Configuration options
