
CUDA Graphs and Inference Optimization: Eliminating Launch Overhead

One of the biggest bottlenecks in LLM inference, especially during small-batch decode, isn't the GPU computation itself but the CPU overhead of launching kernels. Every time you call a PyTorch operation, the CPU has to prepare the kernel, package its parameters, and submit it to the GPU. This overhead adds up quickly. CUDA graphs solve this by recording the entire forward pass once and replaying it.

The Problem: Kernel Launch Overhead

When you call a PyTorch operation, here's what actually happens:

1. Python calls PyTorch operator
2. PyTorch dispatches to CUDA backend
3. CUDA driver prepares the kernel
4. Kernel parameters are packaged
5. Command is inserted into GPU command queue
6. GPU scheduler starts the kernel
7. GPU actually computes

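Steps 1-5 run on the CPU, and step 6 adds scheduling latency on the GPU before any math happens. Even if the GPU work is tiny, this setup still happens. A single kernel launch costs roughly 5–20 microseconds, depending on the framework and hardware.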

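For a model with 32 layers, each with attention and MLP, you might have hundreds of kernels per forward pass. If you're generating 200 tokens, that's 200 decode forward passes, and each one pays the full launch cost again. When each kernel finishes in a few microseconds, as in small-batch decode, the launch overhead can rival the computation itself.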
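Here is a back-of-envelope estimate to make this concrete. It assumes about 10 kernels per layer and 10 µs per launch. Both figures are illustrative assumptions taken from the range above, not measurements.

```python
# Back-of-envelope launch overhead estimate.
# ASSUMPTIONS (illustrative, not measured): kernels per layer and per-launch cost.
NUM_LAYERS = 32
KERNELS_PER_LAYER = 10   # assumed: norms, attention, MLP projections, residual adds
LAUNCH_US = 10.0         # assumed per-kernel launch cost, inside the 5-20 us range
TOKENS = 200             # one decode forward pass per generated token


def launch_overhead_ms(num_layers, kernels_per_layer, launch_us, steps):
    """Return (kernels per step, overhead per step in ms, total overhead in ms)."""
    kernels = num_layers * kernels_per_layer
    per_step_ms = kernels * launch_us / 1000.0
    return kernels, per_step_ms, per_step_ms * steps


kernels, per_step, total = launch_overhead_ms(NUM_LAYERS, KERNELS_PER_LAYER, LAUNCH_US, TOKENS)
print(f"{kernels} kernels/step, {per_step:.1f} ms/step, {total:.0f} ms over {TOKENS} tokens")
# 320 kernels/step, 3.2 ms/step, 640 ms over 200 tokens
```

Launches are asynchronous, so this cost only shows up in wall-clock time when the CPU cannot stay ahead of the GPU. Small-batch decode is exactly that case: each kernel can finish faster than the CPU can launch the next one.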

How CUDA Graphs Work

A CUDA graph records the entire sequence of kernels once, then replays it:

Normal execution:
CPU launches kernel 1 → GPU computes
CPU launches kernel 2 → GPU computes
CPU launches kernel 3 → GPU computes
... (hundreds of times)

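CUDA graph execution:
CPU records once: kernel 1 → kernel 2 → kernel 3 → ...
Each step, CPU replays the entire graph with a single launch call
GPU executes all kernels in sequence

The key insight: the CPU work of preparing each kernel happens once, during capture. Each replay then submits the whole recorded sequence with one call, so the CPU cost per step no longer grows with the number of kernels.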
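A toy model makes the difference visible. This is a pure-Python analogy, not GPU code. `ToyGraph` and `make_kernels` are hypothetical, and a counter stands in for host-side submissions. Both paths do the same work over 200 steps; only the number of submissions differs.

```python
# Toy analogy of record-once, replay-many (pure Python, no GPU involved).


class ToyGraph:
    """Hypothetical stand-in for a CUDA graph: record ops once, replay as one unit."""

    def __init__(self):
        self.ops = []

    def record(self, op):
        self.ops.append(op)

    def replay(self):
        for op in self.ops:
            op()


def make_kernels(log):
    # Three "kernels" that only note that they ran
    return [lambda name=name: log.append(name) for name in ("k1", "k2", "k3")]


STEPS = 200

# Eager: the host submits every kernel on every step
eager_log, eager_submissions = [], 0
kernels = make_kernels(eager_log)
for _ in range(STEPS):
    for k in kernels:
        eager_submissions += 1
        k()

# Graph: record once, then one host submission per step
graph_log, graph_submissions = [], 0
graph = ToyGraph()
for k in make_kernels(graph_log):
    graph.record(k)
for _ in range(STEPS):
    graph_submissions += 1
    graph.replay()

print(eager_submissions, graph_submissions)  # 600 200
print(eager_log == graph_log)                # True: same work, fewer submissions
```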

Capturing CUDA Graphs

During ModelRunner initialization, capture_cudagraph() runs:

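for bs in reversed(self.graph_bs):  # [1, 2, 4, 8, 16, 32] in reverse
    graph = torch.cuda.CUDAGraph()

    # Step 1: Warmup run (eager) so one-time setup happens outside the capture
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

    # Step 2: Record the forward pass, allocating from the shared pool
    with torch.cuda.graph(graph, self.graph_pool):
        outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

    # The first (largest) graph creates the pool that later graphs reuse
    if self.graph_pool is None:
        self.graph_pool = graph.pool()

    # Step 3: Save for later replay
    self.graphs[bs] = graph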

Why Warmup?

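The warmup run is critical. The first call to many operations triggers one-time setup, such as creating library handles (for example cuBLAS), allocating workspaces, or compiling and autotuning kernels. Some of that setup cannot run inside a capture at all, and none of it belongs in a graph that will be replayed thousands of times. Running the model once eagerly gets it out of the way, so the capture records only the steady-state kernels.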

Why Reverse Order?

Graphs are captured from largest to smallest batch size:

for bs in reversed(self.graph_bs):  # 32 → 16 → 8 → 4 → 2 → 1

Why? The largest batch uses the most memory. By capturing it first, you create a memory pool. Smaller graphs then reuse the same memory pool, saving GPU memory and avoiding fragmentation.

Memory Pool Reuse

if self.graph_pool is None:
    self.graph_pool = graph.pool()  # First graph creates pool

All subsequent graphs reuse this pool:

with torch.cuda.graph(graph, self.graph_pool):
    # Reuse memory from larger graphs
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

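This is a large memory saving. Without it, each graph would keep its own private pool for intermediate tensors. Memory use would then grow with every captured batch size and fragment GPU memory. Sharing is safe here because the graphs never replay concurrently: only one batch size runs per decode step.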
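A simplified model shows the size of the saving. It assumes, hypothetically, that each graph needs scratch memory proportional to its batch size. The 50 MB per sequence figure is made up, and the model ignores allocator rounding and fragmentation.

```python
# Simplified memory model for captured graphs (illustrative assumptions only).
GRAPH_BS = [1, 2, 4, 8, 16, 32]
MB_PER_SEQ = 50  # hypothetical scratch memory per sequence, in MB


def scratch_mb(bs):
    # ASSUMPTION: scratch memory grows linearly with batch size
    return bs * MB_PER_SEQ


# Separate pools: every graph keeps its own memory, so the totals add up
separate = sum(scratch_mb(bs) for bs in GRAPH_BS)

# Shared pool: graphs never replay concurrently, so the pool only needs
# to be as large as the biggest graph (captured first, in reverse order)
shared = max(scratch_mb(bs) for bs in GRAPH_BS)

print(separate, shared)  # 3150 1600
```

With separate pools the total grows with every captured size. With a shared pool it stays at the footprint of the largest graph, which is why that graph is captured first.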

The Trick: Static Tensors with Dynamic Values

CUDA graphs require static memory addresses. A captured graph records the exact addresses of every tensor it reads and writes, so you can't hand it "new" tensors at replay time. Every replay uses the same addresses.

The solution: Pre-allocate static tensors and update their values before replay.

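import torch

# Inside ModelRunner; max_batch_size, vocab_size and bs come from the surrounding code
# Pre-allocate static tensors, sized for the largest captured batch
input_ids = torch.zeros(max_batch_size, dtype=torch.int64, device='cuda')  # token ids must be integers
positions = torch.zeros(max_batch_size, dtype=torch.int64, device='cuda')
outputs = torch.empty(max_batch_size, vocab_size, device='cuda')

# Record the graph using slices of these static tensors
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.graph_pool):
    outputs[:bs] = self.model(input_ids[:bs], positions[:bs])

# Save references so the inference path writes into the same buffers
self.graph_vars = dict(
    input_ids=input_ids,
    positions=positions,
    outputs=outputs,
)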

At inference time:

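# Round the batch up to a captured size (see Batch Size Bucketing below)
bs = actual_batch_size
graph_bs = next(x for x in self.graph_bs if x >= bs)

# Update static tensors in place; padding slots get dummy zeros
self.graph_vars['input_ids'].zero_()
self.graph_vars['positions'].zero_()
self.graph_vars['input_ids'][:bs] = new_tokens
self.graph_vars['positions'][:bs] = new_positions

# Replay the graph captured for the padded batch size
self.graphs[graph_bs].replay()

# Results are already in self.graph_vars['outputs']; keep only the real rows
logits = self.graph_vars['outputs'][:bs]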

This is the magic: static graph speed + dynamic inputs.
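The address rule is easy to trip over, so here is a pure-Python analogy with no GPU involved. The recorded op holds on to the exact list objects it was captured with, the way a CUDA graph holds on to memory addresses. `capture_double` is a hypothetical helper for illustration.

```python
# Toy analogy (pure Python, no GPU): the recorded op is bound to specific
# buffer objects, the way a CUDA graph is bound to specific addresses.


def capture_double(src, dst):
    """'Record' an op bound to these exact list objects."""
    def op():
        for i in range(len(src)):
            dst[i] = src[i] * 2
    return op


static_in = [0, 0, 0, 0]   # pre-allocated for a max batch of 4
static_out = [0, 0, 0, 0]
replay = capture_double(static_in, static_out)

# Correct: write new values into the existing buffer, then replay
static_in[:2] = [5, 7]     # batch of 2; slots 2-3 keep dummy zeros
replay()
print(static_out)          # [10, 14, 0, 0]

# Wrong: rebinding the name to a new list is invisible to the recorded op
static_in = [100, 100, 100, 100]
replay()
print(static_out)          # still [10, 14, 0, 0]
```

In PyTorch, the equivalent of the correct path is an in-place write such as `static_tensor[:bs] = new_values` (or `Tensor.copy_`), which is what the inference code above does.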

Why Decode Benefits, Prefill Doesn't

CUDA graphs are perfect for decode but not for prefill. Here's why:

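Decode:
- Each sequence adds exactly one token per step, so the input shape depends only on the batch size
- Per-step GPU work is small, so launch overhead is a large share of each step
- The same shapes repeat every step, so a few captured graphs cover them all

Prefill:
- Prompt lengths vary between requests, so input shapes change on almost every call, and each shape would need its own graph
- Many tokens are processed at once, so kernels are large and compute-bound, and launch overhead is a small fraction of the total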

This is why the code only uses CUDA graphs for decode:

def run_model(self, input_ids, positions, is_prefill):
    # Eager mode for prefill, when graphs are disabled, or when the batch
    # is larger than the biggest captured graph
    if is_prefill or self.enforce_eager or input_ids.size(0) > self.graph_bs[-1]:
        return self.model.compute_logits(self.model(input_ids, positions))
    else:
        # CUDA graph replay for decode
        return self.run_cudagraph(input_ids, positions)
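The routing decision can be pulled out and tested on its own. `choose_execution` is a hypothetical helper that mirrors `run_model`. It assumes the eager cutoff is the largest captured batch size, so every batch that takes the graph path has a captured size to round up to.

```python
# Pure-Python sketch of run_model's routing (no GPU needed).
GRAPH_BS = [1, 2, 4, 8, 16, 32]  # captured batch sizes from this article


def choose_execution(batch_size, is_prefill, enforce_eager, graph_bs=GRAPH_BS):
    """Return ("eager", None) or ("graph", captured_batch_size)."""
    if is_prefill or enforce_eager or batch_size > graph_bs[-1]:
        return ("eager", None)
    # Smallest captured size that can hold the whole batch
    return ("graph", next(x for x in graph_bs if x >= batch_size))


print(choose_execution(5, is_prefill=False, enforce_eager=False))   # ('graph', 8)
print(choose_execution(5, is_prefill=True, enforce_eager=False))    # ('eager', None)
print(choose_execution(40, is_prefill=False, enforce_eager=False))  # ('eager', None)
```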

Batch Size Bucketing

You can't capture a graph for every possible batch size. Instead, you capture for specific sizes and bucket:

self.graph_bs = [1, 2, 4, 8, 16, 32]  # Supported batch sizes

# At inference time:
bs = actual_batch_size  # e.g., 5
graph_bs = next(x for x in self.graph_bs if x >= bs)  # Use graph for 8
graph = self.graphs[graph_bs]

If you have 5 sequences, you use the graph captured for batch size 8. The extra 3 slots are filled with dummy data, and their outputs are simply ignored because the caller keeps only the first 5 rows.

This trades a small amount of wasted computation for the ability to use CUDA graphs.
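The cost of bucketing is easy to quantify. This sketch uses the article's bucket list. The uniform spread of batch sizes in the average is an assumption, not a measured workload.

```python
# Padding waste from rounding batch sizes up to captured graph sizes.
GRAPH_BS = [1, 2, 4, 8, 16, 32]  # bucket list from this article


def bucket(bs, graph_bs=GRAPH_BS):
    """Smallest captured batch size that is >= bs."""
    return next(x for x in graph_bs if x >= bs)


def wasted_fraction(bs):
    """Share of the padded batch spent on dummy slots."""
    padded = bucket(bs)
    return (padded - bs) / padded


print(bucket(5), wasted_fraction(5))    # 8 0.375
print(bucket(17), wasted_fraction(17))  # 32 0.46875

# ASSUMPTION: every batch size from 1 to 32 is equally likely
avg = sum(wasted_fraction(bs) for bs in range(1, 33)) / 32
print(round(avg, 3))                    # 0.203
```

The worst case looks large (a batch of 17 runs a 32-wide graph). However, small-batch decode is usually memory-bandwidth bound, so the padded rows typically add little time. A finer bucket list reduces the waste at the cost of capturing more graphs at startup.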

The Performance Impact

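How much do CUDA graphs help?

For a typical decode step with batch size 32, eliminating launch overhead alone can give around a 3-5x speedup. The exact figure depends on the model, the GPU, and how much of each step was launch-bound to begin with.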

For large-scale inference with thousands of tokens generated, this adds up to significant wall-clock time savings.
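A simple timing model shows where a speedup of the size quoted above can come from, and when it disappears. It assumes host launches overlap with GPU work, so each step takes as long as the slower of the two. `step_time_us` and every number in it are illustrative assumptions.

```python
# Toy timing model for one step, eager vs CUDA graph.
# ASSUMPTION: host launches overlap GPU work, so the slower side sets the pace.


def step_time_us(gpu_us, num_kernels, launch_us, use_graph, graph_launch_us=20.0):
    """Estimated wall-clock time of one step, in microseconds."""
    host_us = graph_launch_us if use_graph else num_kernels * launch_us
    return max(gpu_us, host_us)


KERNELS = 320     # assumed kernels per step (32 layers x 10)
LAUNCH_US = 10.0  # assumed per-kernel launch cost

# Small decode step: GPU work is short, so eager mode is host-bound
eager = step_time_us(1000.0, KERNELS, LAUNCH_US, use_graph=False)
graph = step_time_us(1000.0, KERNELS, LAUNCH_US, use_graph=True)
print(eager, graph, eager / graph)  # 3200.0 1000.0 3.2

# Large step (e.g. a long prefill): GPU-bound either way, so no gain
big_eager = step_time_us(20000.0, KERNELS, LAUNCH_US, use_graph=False)
big_graph = step_time_us(20000.0, KERNELS, LAUNCH_US, use_graph=True)
print(big_eager / big_graph)        # 1.0
```

The same model explains why prefill gains little. Once GPU time exceeds the eager launch time, removing launches no longer changes the step time.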

When CUDA Graphs Don't Help

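CUDA graphs are most effective when:
- Input shapes repeat from step to step, so a few captured graphs cover every call
- Individual kernels are short, so launch overhead is a large share of each step
- The same sequence of operations runs many times, amortizing the capture cost

They don't help when:
- Shapes change on almost every call, as in prefill with variable prompt lengths
- Kernels are large and compute-bound, so launch overhead is already negligible
- The forward pass depends on data-dependent control flow or CPU synchronization (for example, reading a value back with .item()), which a static graph cannot represent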

The Bigger Picture

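CUDA graphs are one piece of the optimization puzzle. They work alongside:
- KV caching, so each decode step processes only the new token
- An efficient KV-cache memory layout, such as paged blocks
- Continuous batching and request scheduling

Together, these techniques let modern inference engines achieve large speedups over naive implementations, on the order of 10-100x depending on the workload.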

The key insight: optimization isn't about one trick. It's about eliminating bottlenecks at every level, from memory layout to kernel launch overhead to batch scheduling.