CUDA Graphs and Inference Optimization: Eliminating Launch Overhead
One of the biggest bottlenecks in LLM inference isn't the GPU computation itself—it's the CPU overhead of launching kernels. Every time you call a PyTorch operation, the CPU has to prepare the kernel, package parameters, and submit it to the GPU. This overhead adds up quickly. CUDA graphs solve this by pre-recording the entire forward pass and replaying it.
The Problem: Kernel Launch Overhead
When you call a PyTorch operation, here's what actually happens:
1. Python calls PyTorch operator
2. PyTorch dispatches to CUDA backend
3. CUDA driver prepares the kernel
4. Kernel parameters are packaged
5. Command is inserted into GPU command queue
6. GPU scheduler starts the kernel
7. GPU actually computes
Steps 1-5 are CPU-side overhead. Even if the GPU work is tiny, this setup still happens. A single kernel launch costs roughly 5–20 microseconds.
For a model with 32 layers, each with attention and MLP, you might have hundreds of kernels per forward pass. If you're generating 200 tokens, that's 200 forward passes. The launch overhead becomes significant.
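To get a feel for the scale, here is a rough back-of-the-envelope calculation. The per-launch cost and kernel count are illustrative assumptions (10 µs per launch, 500 kernels per pass), not measurements:

```python
# Rough estimate of cumulative kernel-launch overhead during decode.
# All three constants are illustrative assumptions.
LAUNCH_OVERHEAD_US = 10   # ~5-20 us per kernel launch
KERNELS_PER_PASS = 500    # hundreds of kernels for a 32-layer model
TOKENS_GENERATED = 200    # one forward pass per decoded token

overhead_ms_per_pass = LAUNCH_OVERHEAD_US * KERNELS_PER_PASS / 1000
total_overhead_ms = overhead_ms_per_pass * TOKENS_GENERATED

print(overhead_ms_per_pass)  # 5.0 -> 5 ms of pure launch overhead per pass
print(total_overhead_ms)     # 1000.0 -> a full second across the generation
```

A whole second of the generation spent on CPU bookkeeping, before the GPU computes anything.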
How CUDA Graphs Work
A CUDA graph records the entire sequence of kernels once, then replays it:
Normal execution:
CPU launches kernel 1 → GPU computes
CPU launches kernel 2 → GPU computes
CPU launches kernel 3 → GPU computes
... (hundreds of times)
CUDA graph execution:
CPU records: kernel 1 → kernel 2 → kernel 3 → ...
CPU replays the entire graph
GPU executes all kernels in sequence
The key insight: You only pay the launch overhead once during recording, not during replay.
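A minimal, self-contained sketch of this record-once/replay-many pattern using PyTorch's public `torch.cuda.CUDAGraph` API. The tensor names and the toy computation are illustrative, and the block only exercises the GPU path when one is available:

```python
import torch

if torch.cuda.is_available():
    # Static buffers: their addresses must stay fixed across replays.
    static_x = torch.zeros(8, device="cuda")
    static_y = torch.zeros(8, device="cuda")

    # Warmup on a side stream so capture doesn't record first-run setup.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_y.copy_(static_x * 2 + 1)
    torch.cuda.current_stream().wait_stream(s)

    # Record the computation once...
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_y.copy_(static_x * 2 + 1)

    # ...then update the static input in place and replay.
    static_x.fill_(3.0)
    g.replay()
    torch.cuda.synchronize()
    print(static_y[0].item())  # 7.0
```

Each `replay()` re-runs the captured kernels with a single submission, reading whatever values currently sit in the static buffers.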
Capturing CUDA Graphs
During ModelRunner initialization, capture_cudagraph() runs:
for bs in reversed(self.graph_bs):  # [1, 2, 4, 8, 16, 32] in reverse
    graph = torch.cuda.CUDAGraph()
    # Step 1: Warmup run
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
    # Step 2: Record the forward pass
    with torch.cuda.graph(graph, self.graph_pool):
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
    # Step 3: Save for later replay
    self.graphs[bs] = graph
Why Warmup?
The warmup run is critical:
- Initializes CUDA kernels (first run is always slower)
- Stabilizes GPU memory allocation
- Prevents memory fragmentation during graph capture
Without warmup, the graph might capture memory allocation operations, making replay slower.
Why Reverse Order?
Graphs are captured from largest to smallest batch size:
for bs in reversed(self.graph_bs): # 32 → 16 → 8 → 4 → 2 → 1
Why? The largest batch uses the most memory. By capturing it first, you create a memory pool. Smaller graphs then reuse the same memory pool, saving GPU memory and avoiding fragmentation.
Memory Pool Reuse
if self.graph_pool is None:
    self.graph_pool = graph.pool()  # First graph creates pool
All subsequent graphs reuse this pool:
with torch.cuda.graph(graph, self.graph_pool):
    # Reuse memory already reserved by the larger graphs
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
This is a massive memory optimization. Without it, each graph would allocate its own memory, fragmenting the GPU.
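The effect of capture order on the pool can be illustrated with a toy bump allocator (`ToyPool` is a hypothetical stand-in for CUDA's caching allocator, not the real mechanism):

```python
# Toy sketch: a pool that only grows when a request exceeds its
# current high-water mark. Largest-first capture sizes it exactly once.
class ToyPool:
    def __init__(self):
        self.high_water = 0   # bytes reserved by the pool
        self.grow_events = 0  # how many times the pool had to expand

    def alloc(self, nbytes):
        if nbytes > self.high_water:
            self.high_water = nbytes
            self.grow_events += 1

largest_first = ToyPool()
for bs in reversed([1, 2, 4, 8, 16, 32]):   # capture order: 32 -> ... -> 1
    largest_first.alloc(bs * 1024)          # pretend each seq needs 1 KiB

smallest_first = ToyPool()
for bs in [1, 2, 4, 8, 16, 32]:             # the order capture avoids
    smallest_first.alloc(bs * 1024)

print(largest_first.grow_events)   # 1 -- pool sized once, then reused
print(smallest_first.grow_events)  # 6 -- pool regrows on every capture
```

Both orders end at the same pool size, but largest-first reserves the memory in one contiguous step instead of six incremental growths that can fragment the GPU.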
The Trick: Static Tensors with Dynamic Values
CUDA graphs require static memory addresses. You can't give the graph "new" tensors—the addresses must be the same every time.
The solution: Pre-allocate static tensors and update their values before replay.
# Pre-allocate static tensors
input_ids = torch.zeros(max_batch_size, dtype=torch.int64, device='cuda')  # token ids are integers
positions = torch.zeros(max_batch_size, dtype=torch.int64, device='cuda')
outputs = torch.empty(max_batch_size, vocab_size, device='cuda')
# Record graph using these static tensors
with torch.cuda.graph(graph):
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
# Save references
self.graph_vars = dict(
    input_ids=input_ids,
    positions=positions,
    outputs=outputs,
)
At inference time:
# Update static tensors with new data
bs = actual_batch_size
self.graph_vars['input_ids'][:bs] = new_tokens
self.graph_vars['positions'][:bs] = new_positions
# Replay the graph
self.graphs[bs].replay()
# Results are already in self.graph_vars['outputs']
logits = self.graph_vars['outputs'][:bs]
This is the magic: static graph speed + dynamic inputs.
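The pattern can be mimicked in plain Python: a "graph" is a closure recorded once over fixed buffers, so replaying it sees whatever values those buffers hold at the moment. (The lists and the doubling "forward pass" are purely illustrative stand-ins for GPU tensors and kernels.)

```python
# Pure-Python sketch of the static-buffer pattern behind graph replay.
input_buf = [0, 0, 0, 0]    # stands in for a pre-allocated GPU tensor
output_buf = [0, 0, 0, 0]

def record_graph():
    # The recorded work reads/writes the captured buffers in place,
    # never allocating new ones -- mirroring CUDA graphs' requirement
    # that memory addresses stay fixed across replays.
    def replay():
        for i in range(len(input_buf)):
            output_buf[i] = input_buf[i] * 2   # toy "forward pass"
    return replay

graph = record_graph()

input_buf[:2] = [3, 4]   # update the static buffers with new data
graph()                  # replay: no re-recording, no new allocation
print(output_buf[:2])    # [6, 8]
```

The recorded computation never changes; only the contents of the fixed buffers do.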
Why Decode Benefits, Prefill Doesn't
CUDA graphs are perfect for decode but not for prefill. Here's why:
Decode:
- Fixed batch structure (one token per sequence)
- Same number of sequences each step
- Same batch size each step
- Perfect for CUDA graphs
Prefill:
- Variable batch structure (different sequences, different lengths)
- Number of sequences changes
- Batch size changes
- Can't use CUDA graphs effectively
This is why the code only uses CUDA graphs for decode:
def run_model(self, input_ids, positions, is_prefill):
    if is_prefill or self.enforce_eager or input_ids.size(0) > self.graph_bs[-1]:
        # Eager execution
        return self.model.compute_logits(self.model(input_ids, positions))
    else:
        # CUDA graph replay for decode
        return self.run_cudagraph(input_ids, positions)
Batch Size Bucketing
You can't capture a graph for every possible batch size. Instead, you capture for specific sizes and bucket:
self.graph_bs = [1, 2, 4, 8, 16, 32] # Supported batch sizes
# At inference time:
bs = actual_batch_size # e.g., 5
graph_bs = next(x for x in self.graph_bs if x >= bs) # Use graph for 8
graph = self.graphs[graph_bs]
If you have 5 sequences, you use the graph captured for batch size 8. The extra 3 slots are filled with dummy data and masked out.
This trades a small amount of wasted computation for the ability to use CUDA graphs.
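The bucketing logic above fits in a few lines. The helper name `pick_graph_bs` is hypothetical, but it mirrors the `next(...)` lookup shown earlier, with an explicit fallback for batches larger than the biggest captured graph:

```python
def pick_graph_bs(bs, graph_bs=(1, 2, 4, 8, 16, 32)):
    """Return the smallest captured batch size that fits `bs`.

    Returns None when the batch exceeds the largest captured graph,
    signalling a fall back to eager execution.
    """
    return next((x for x in graph_bs if x >= bs), None)

print(pick_graph_bs(5))    # 8  -> 3 slots padded with dummy data
print(pick_graph_bs(32))   # 32 -> exact fit, no padding
print(pick_graph_bs(33))   # None -> fall back to eager execution
```

Because `graph_bs` is sorted ascending, the first size that fits is also the tightest bucket, minimizing wasted slots.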
The Performance Impact
How much does CUDA graph help?
For a typical decode step with batch size 32:
- Eager execution: ~5-20 microseconds of launch overhead per kernel × several hundred kernels ≈ 5-10ms
- CUDA graph: ~1-2ms total (one replay operation)
That's a 3-5x speedup just from eliminating launch overhead.
For large-scale inference with thousands of tokens generated, this adds up to significant wall-clock time savings.
When CUDA Graphs Don't Help
CUDA graphs are most effective when:
- The same computation repeats many times
- The batch structure is fixed
- Launch overhead is a meaningful fraction of step time (many small kernels, as in decode)
They don't help when:
- The computation changes each step (prefill)
- Individual kernels run long enough to keep the GPU saturated, so launch overhead is negligible
The Bigger Picture
CUDA graphs are one piece of the optimization puzzle. Combined with:
- Paged attention (efficient KV cache)
- Packed-ragged batching (prefill efficiency)
- Continuous batching (GPU utilization)
They enable modern inference engines to achieve 10-100x speedups over naive implementations.
The key insight: Optimization isn't about one trick—it's about eliminating bottlenecks at every level, from memory layout to kernel launch overhead to batch scheduling.