Understanding KV Cache and Paged Attention in LLM Inference
When you run an LLM, it doesn't just process your prompt once and forget about it. Instead, it needs to remember the key and value vectors from every token in the sequence to efficiently generate the next token. This is where KV cache comes in—and it's one of the most critical optimizations in modern inference engines.
The Problem: Why KV Cache Matters
During the prefill phase, when the model processes your entire prompt, it computes attention over all tokens at once. But here's the inefficiency: during the decode phase (generating one token at a time), you would have to recompute the key and value vectors for every previous token just to produce the new token's attention output. That's wasteful.
The solution? Cache the K and V vectors from the prefill phase (and from each decode step), so during decode you only compute the new token's query against the cached keys and values. This is a massive speedup.
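Here is a minimal sketch of a single decode step with a growing cache, written in plain PyTorch for one sequence and one head. The function and variable names are illustrative only, not nano-vLLM's actual code:

```python
import torch

def decode_step(q_new, k_new, v_new, k_cache, v_cache):
    """One decode step: append the new token's K/V to the cache, then attend
    the new token's query against everything cached so far.
    Single-sequence, single-head sketch; real engines batch this and use
    fused attention kernels."""
    # q_new, k_new, v_new: [1, head_dim]; k_cache, v_cache: [seq_len, head_dim]
    k_cache = torch.cat([k_cache, k_new], dim=0)            # [seq_len + 1, head_dim]
    v_cache = torch.cat([v_cache, v_new], dim=0)
    scores = q_new @ k_cache.T / k_cache.shape[-1] ** 0.5   # [1, seq_len + 1]
    weights = torch.softmax(scores, dim=-1)
    out = weights @ v_cache                                  # [1, head_dim]
    return out, k_cache, v_cache
```

Note that only the new token's projections are computed each step; everything else is read straight out of the cache.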
What Does KV Cache Actually Look Like?
In a system like nano-vLLM, the KV cache is a pre-allocated GPU memory pool, not a dynamic list. Here's the actual memory layout:
Shape: [2, 28, 1320, 64, 8, 128]
Dimensions: [KV, Layers, Blocks, Tokens, Heads, Dim]
Breaking this down:
- KV (2): Separate storage for Keys and Values
- Layers (28): Each transformer layer has its own cache
- Blocks (1320): Physical memory blocks (we'll explain this next)
- Tokens (64): Tokens per block
- Heads (8): Number of key-value heads
- Dim (128): Dimension of each head
The total memory works out to 2 × 28 × 1320 × 64 × 8 × 128 ≈ 4.84 billion elements, which at 2 bytes per element (fp16/bf16) is roughly 9 GiB.
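As a sketch, such a pool can be pre-allocated as a single tensor (assuming fp16 elements; the shape values are the ones listed above, and the constant names are illustrative):

```python
import torch

# Layout from above: [KV, Layers, Blocks, Tokens, Heads, Dim]
NUM_KV, NUM_LAYERS, NUM_BLOCKS = 2, 28, 1320
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM = 64, 8, 128

kv_cache = torch.empty(
    (NUM_KV, NUM_LAYERS, NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM),
    dtype=torch.float16,   # 2 bytes per element
    device="cuda",         # requires a GPU; use "cpu" to just check the math
)
print(kv_cache.numel() * kv_cache.element_size() / 2**30)   # ~9.0 GiB
```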
Paged Attention: Treating KV Cache Like Virtual Memory
Here's where it gets clever. Instead of allocating one giant contiguous block for each sequence's KV cache, modern inference engines use paged attention—similar to how operating systems manage virtual memory.
The idea is simple:
- Divide the KV cache into fixed-size blocks (e.g., 64 tokens per block)
- Each sequence gets a block table (like a page table) that maps logical blocks to physical blocks
- Blocks can be scattered in memory, but the block table keeps track of where they are
Why This Matters
- Memory efficiency: You don't waste space on sequences of different lengths
- Block reuse: If two sequences share a common prefix, they can point to the same physical block
- Flexible batching: Sequences can be added/removed without reorganizing memory
How Block Tables Work
Let's say you have two sequences:
- Sequence A: 100 tokens (needs 2 blocks of size 64)
- Sequence B: 50 tokens (needs 1 block of size 64)
The block manager allocates:
- Physical block 0 → Sequence A's first 64 tokens
- Physical block 1 → Sequence A's remaining 36 tokens
- Physical block 2 → Sequence B's 50 tokens
Each sequence stores its block table:
- Sequence A: block_table = [0, 1]
- Sequence B: block_table = [2]
When the attention layer needs to read KV values, it uses the block table to find the physical address.
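A stripped-down block manager that reproduces the allocation above might look like the sketch below. The class and method names are illustrative, not nano-vLLM's actual API; a real manager also frees blocks, reference-counts them, and handles prefix-cache lookups:

```python
BLOCK_SIZE = 64

class SimpleBlockManager:
    """Hands out physical blocks and records each sequence's block table."""
    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}                     # seq_id -> list of physical block ids

    def allocate(self, seq_id, num_tokens):
        num_needed = -(-num_tokens // BLOCK_SIZE)  # ceil division
        blocks = [self.free_blocks.pop(0) for _ in range(num_needed)]
        self.block_tables[seq_id] = blocks
        return blocks

mgr = SimpleBlockManager(num_blocks=1320)
print(mgr.allocate("A", 100))   # [0, 1]  -> two blocks for 100 tokens
print(mgr.allocate("B", 50))    # [2]     -> one block for 50 tokens
```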
Slot Mapping: From Logical to Physical
During inference, the model needs to know exactly where to write new KV values. This is where slot mapping comes in.
For each token being processed, we calculate:
slot = physical_block * block_size + offset_in_block
For example, if we're writing token 5 of sequence A:
- Token 5 is in logical block 0 (5 < 64)
- Physical block 0 is at index 0
- Offset is 5
- Slot = 0 × 64 + 5 = 5
The slot mapping is passed to the GPU, and the attention kernel writes the K and V vectors to that exact memory location.
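In code, the mapping is one multiply and one add per token. The helper below is a hypothetical illustration of the formula above:

```python
BLOCK_SIZE = 64

def slot_for_token(block_table, token_pos):
    """Map a token's logical position in its sequence to a physical slot
    in the KV cache pool."""
    logical_block = token_pos // BLOCK_SIZE
    offset = token_pos % BLOCK_SIZE
    physical_block = block_table[logical_block]
    return physical_block * BLOCK_SIZE + offset

# Token 5 of sequence A (block_table = [0, 1]): physical block 0, offset 5
print(slot_for_token([0, 1], 5))    # 5
# If the second logical block were mapped to physical block 7 instead,
# token 70 would land at 7 * 64 + 6 = 454 even though it is "token 70" logically.
print(slot_for_token([0, 7], 70))   # 454
```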
Prefill vs Decode: Different Batch Structures
The KV cache management is fundamentally different in prefill and decode phases.
During Prefill:
- Multiple sequences with varying lengths
- All uncached tokens are processed in one batch
- Tokens are "packed" together (no padding between sequences)
- Example: sequences of length 10, 15, 20 → one batch of 45 tokens
During Decode:
- Each sequence contributes exactly one token
- All tokens are processed in parallel
- Fixed batch structure (one token per sequence)
- Example: 32 sequences → batch of 32 tokens
This is why the prepare_prefill and prepare_decode functions look so different—they're building fundamentally different batch structures.
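The sketches below illustrate the difference in batch structure. They are simplified stand-ins, not the actual prepare_prefill / prepare_decode implementations (which also build slot mappings, block tables, and position ids):

```python
def prepare_prefill(seqs):
    """Pack all uncached prompt tokens from every sequence into one flat batch,
    with cumulative lengths so the attention kernel knows sequence boundaries."""
    input_ids, cu_seqlens = [], [0]
    for seq in seqs:
        input_ids.extend(seq["token_ids"])
        cu_seqlens.append(cu_seqlens[-1] + len(seq["token_ids"]))
    return input_ids, cu_seqlens            # e.g. lengths 10, 15, 20 -> 45 packed tokens

def prepare_decode(seqs):
    """Each running sequence contributes exactly its last token."""
    return [seq["token_ids"][-1] for seq in seqs]   # 32 sequences -> 32 tokens
```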
Memory Address Calculation
If you want to access the KV cache at a specific location, you need to calculate the memory address. Given the stride information:
address = base_address
          + kv    * 2422210560 * element_size
          + layer * 86507520   * element_size
          + block * 65536      * element_size
          + token * 1024       * element_size
          + head  * 128        * element_size
          + dim   * 1          * element_size
For example, to access the Keys (KV index 0) at Layer 1, Block 0, Token 0, Head 0, Dim 0, with 2-byte elements:
address = 0x7131dc000000 + 1 * 86507520 * 2 = 0x7131e6500000
This is how the GPU kernel knows exactly where to read and write KV values.
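As a quick sanity check, here is the same arithmetic in Python. The strides come from the layout above, the base address is just the example value, and the function name is made up for illustration:

```python
ELEMENT_SIZE = 2           # bytes per element (fp16)
STRIDES = {                # in elements, from the [2, 28, 1320, 64, 8, 128] layout
    "kv": 2_422_210_560, "layer": 86_507_520, "block": 65_536,
    "token": 1_024, "head": 128, "dim": 1,
}

def kv_cache_address(base, kv, layer, block, token, head, dim):
    """Byte address of a single KV cache element (sketch of the formula above)."""
    offset = (kv * STRIDES["kv"] + layer * STRIDES["layer"]
              + block * STRIDES["block"] + token * STRIDES["token"]
              + head * STRIDES["head"] + dim * STRIDES["dim"])
    return base + offset * ELEMENT_SIZE

print(hex(kv_cache_address(0x7131dc000000, 0, 1, 0, 0, 0, 0)))   # 0x7131e6500000
```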
The Big Picture
KV cache with paged attention is what makes LLM inference practical. Without it, you'd recompute attention for every token, making inference impossibly slow. With it, you cache the expensive parts and only compute what's new.
The block-based approach adds another layer of efficiency: you can manage sequences of different lengths without wasting memory, and you can even share blocks between sequences that have common prefixes (a technique called prefix caching).
Understanding this memory layout is crucial for optimizing inference engines and debugging performance issues.