Back to LLM Inference

Tensor Flow Through the Model: From Tokens to Logits

When you pass tokens through an LLM, they undergo a series of transformations. Each layer refines the representation, adding context and meaning. But most of these transformations happen implicitly—you call a module like a function, and PyTorch automatically invokes its forward() method. Let's trace exactly what happens to a tensor as it flows through the model.

The Big Picture

Here's the overall flow:

Input tokens (IDs)
    ↓
Embedding layer (token IDs → vectors)
    ↓
Transformer layers (24-32 layers of attention + MLP)
    ↓
Final normalization
    ↓
Linear projection to vocabulary (logits)
    ↓
Sampling (logits → next token)

But what does this actually look like with real numbers? Let's walk through it.
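Before diving into the real model, the five stages can be sketched end to end with a toy stand-in. Everything here is illustrative (tiny sizes, plain linear layers in place of transformer blocks), but the dataflow matches the pipeline above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size = 100, 16

embed = nn.Embedding(vocab_size, hidden_size)             # token IDs -> vectors
blocks = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(4))
norm = nn.LayerNorm(hidden_size)                          # final normalization
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # projection to vocabulary

input_ids = torch.tensor([12, 34])
hidden = embed(input_ids)                        # [2, 16]
for block in blocks:
    hidden = hidden + torch.tanh(block(hidden))  # residual + placeholder "layer"
hidden = norm(hidden)
logits = lm_head(hidden)                         # [2, 100]
next_token = int(torch.argmax(logits[-1]))       # greedy pick at the last position
```

The real model swaps the placeholder blocks for attention+MLP layers and the greedy argmax for a proper sampler, but the shape of the computation is the same.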

Starting Point: ModelRunner.run_model()

You have 2 tokens to process:

input_ids = torch.tensor([12, 34])   # illustrative token IDs for "Hello", "world"
positions = torch.tensor([0, 1])     # positions in the sequence

The ModelRunner calls:

logits = self.model.compute_logits(self.model(input_ids, positions))

This invokes Qwen3ForCausalLM.forward(), which is the entry point.

Step 1: Qwen3ForCausalLM.forward()

def forward(self, input_ids, positions):
    # Call the core model
    hidden_states = self.model(input_ids, positions)
    
    # Apply masking if needed (for mixed batching).
    # `mask` comes from the batching context set up by the ModelRunner;
    # flattening it lets us filter hidden states row by row.
    if isinstance(mask, torch.Tensor):
        flat_mask = mask.reshape(-1)
    
    # Return the (possibly filtered) hidden states
    return hidden_states

This is a wrapper that delegates to the core transformer model.

Step 2: Qwen3Model.forward() - The Core Transformer

Now we enter the actual transformer:

# Input:
input_ids = tensor([12, 34])     # shape: [2]
positions = tensor([0, 1])       # shape: [2]

# Operation 1: Embedding lookup
hidden_states = self.embed_tokens(input_ids)
# Result: hidden_states = tensor([
#   [0.1, -0.2, 0.3, ...],  # Embedding for token 12
#   [0.4, 0.1, -0.1, ...]   # Embedding for token 34
# ])  # shape: [2, hidden_size=2048]

The embedding layer converts token IDs to dense vectors. Each token gets a 2048-dimensional vector.
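This lookup can be reproduced directly with nn.Embedding. A sketch with the same hidden size as the walkthrough (but a small vocabulary, to keep the demo light):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same hidden size as the walkthrough; a small vocabulary for the demo.
embed = nn.Embedding(num_embeddings=1000, embedding_dim=2048)

input_ids = torch.tensor([12, 34])
hidden_states = embed(input_ids)   # shape: [2, 2048]

# The lookup is literally row indexing into the weight table:
assert torch.equal(hidden_states, embed.weight[input_ids])
```

There is no matrix multiplication here: each token ID just selects one row of the embedding weight matrix.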

# Operation 2: Loop through transformer layers
residual = None
for layer in self.layers:  # 24 layers for example
    hidden_states, residual = layer(positions, hidden_states, residual)

Each layer takes the hidden states and refines them. After the first layer:

# hidden_states: [2, 2048] - transformed by attention+MLP
# residual: [2, 2048] - original embedding for residual connection

After all 24 layers, the hidden states have been transformed many times, with each layer adding more context and meaning.
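The loop threads two tensors through every layer. A stand-in layer with the same (positions, hidden_states, residual) signature makes the dataflow explicit; the internals here are placeholders, not the real attention+MLP block, and the layer count is reduced:

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Same (positions, hidden_states, residual) -> (hidden_states, residual)
    signature as the real decoder layer; the internals are placeholders."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, positions, hidden_states, residual):
        if residual is None:
            residual = hidden_states      # first layer seeds the residual stream
        hidden_states = torch.tanh(self.proj(hidden_states + residual))
        return hidden_states, residual

torch.manual_seed(0)
layers = nn.ModuleList(ToyLayer(2048) for _ in range(4))  # 24+ in the real model

positions = torch.tensor([0, 1])
hidden_states = torch.randn(2, 2048)
residual = None
for layer in layers:
    hidden_states, residual = layer(positions, hidden_states, residual)
# hidden_states: [2, 2048]; residual: [2, 2048]
```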

# Operation 3: Final normalization
hidden_states, _ = self.norm(hidden_states, residual)
# Result: hidden_states = tensor([
#   [0.8, 0.2, -0.4, ...],  # Final representation for token 12
#   [-0.1, 0.7, 0.3, ...]   # Final representation for token 34
# ])  # shape: [2, 2048]

Step 3: Inside Qwen3DecoderLayer - The Transformer Block

Each decoder layer does the same thing: apply attention, then apply MLP. Let's trace one layer:

# Input to first decoder layer:
positions = tensor([0, 1])                    # shape: [2]
hidden_states = tensor([[0.1, -0.2, ...],     # shape: [2, 2048]
                       [0.4, 0.1, ...]])
residual = None

# Operation 1: Input layernorm
# (the right-hand side is evaluated before assignment, so `residual`
# captures the pre-norm hidden states)
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
# Result:
# hidden_states: [2, 2048] - normalized version
# residual: [2, 2048] - original, kept for the residual connection

Normalization stabilizes training and keeps activations in a consistent range, which helps information flow through a deep stack of layers.
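Qwen-family models use RMSNorm rather than classic LayerNorm. The sketch below is my reading of the two-tensor call pattern that appears throughout this walkthrough (`hidden_states, residual = norm(x, residual)`): the norm optionally folds the residual add in before normalizing.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMS normalization with an optional fused residual add."""
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, residual=None):
        if residual is not None:
            x = x + residual      # fold the skip connection in first
        residual = x              # the pre-norm sum becomes the next residual
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight, residual

torch.manual_seed(0)
norm = RMSNorm(2048)
x, res = norm(torch.randn(2, 2048))
# Each row of x now has (approximately) unit root-mean-square.
```

Unlike LayerNorm, RMSNorm skips the mean subtraction and the bias, which makes it slightly cheaper while working just as well in practice.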

# Operation 2: Self-attention
hidden_states = self.self_attn(positions, hidden_states)
# This calls Qwen3Attention.forward():
#   qkv = self.qkv_proj(hidden_states)       # [2, 2048] → [2, 2304] (fused Q+K+V)
#   q, k, v = split qkv                      # q: [2, num_heads, head_dim];
#                                            # k, v: [2, num_kv_heads, head_dim]
#   o = self.attn(q, k, v)                   # [2, num_heads, head_dim]
#   output = self.o_proj(o.flatten(1, -1))   # [2, 2048] → [2, 2048]
# Result: [2, 2048] - attention output

Attention is where the model learns relationships between tokens. The Q (query), K (key), and V (value) projections are linear transformations that prepare the data for the attention mechanism.
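The attention core itself (what `self.attn(q, k, v)` computes, ignoring KV caching and grouped-query sharing) is scaled dot-product attention. A sketch using PyTorch's built-in, with head counts chosen so that num_heads * head_dim matches the 2048 hidden size:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, num_heads, head_dim = 2, 16, 128   # 16 * 128 = 2048 = hidden size

q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)

# scaled_dot_product_attention wants [heads, seq, dim], so transpose in and out
o = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
    is_causal=True,            # token i may only attend to tokens 0..i
).transpose(0, 1)              # back to [seq, heads, dim]

out = o.flatten(1)             # [2, 2048], ready for o_proj
```

A sanity check on causality: the first token can only attend to itself, so its attention output is exactly its own value vector.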

# Operation 3: Post-attention layernorm
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Result:
# hidden_states: [2, 2048] - normalized attention output + residual
# residual: [2, 2048] - updated residual

Residual connections allow gradients to flow directly through layers, preventing the vanishing gradient problem.

# Operation 4: MLP (Feed-Forward Network)
hidden_states = self.mlp(hidden_states)
# This calls Qwen3MLP.forward():
#   gate_up = self.gate_up_proj(x)           # [2, 2048] → [2, 4096*2] (gate and up fused)
#   x = self.act_fn(gate_up)                 # SiLU(gate) * up → [2, 4096]
#   x = self.down_proj(x)                    # [2, 4096] → [2, 2048]
# Result: [2, 2048] - final layer output

The MLP adds non-linear capacity. It expands to the intermediate size (4096 here), applies a gated SiLU activation (the gate half is activated and multiplied element-wise with the up half), then projects back down to the hidden size.
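A minimal version of this gated MLP, with the gate and up projections fused into a single matmul as in the trace above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # gate and up projections fused into one matrix multiply
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)   # SiLU-and-mul, then project back

torch.manual_seed(0)
mlp = GatedMLP(hidden_size=2048, intermediate_size=4096)
y = mlp(torch.randn(2, 2048))    # [2, 2048]
```

Fusing gate and up into one projection halves the number of kernel launches without changing the math.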

Step 4: Back to Qwen3ForCausalLM.forward()

After all layers have processed the tokens:

# After Qwen3Model returns:
hidden_states = tensor([[0.8, 0.2, -0.4, ...],     # shape: [2, 2048]
                       [-0.1, 0.7, 0.3, ...]])

# Mask processing (during prefill, no filtering is needed)
if isinstance(mask, torch.Tensor):
    flat_mask = mask.reshape(-1)          # here: tensor([-1, -1])
else:
    flat_mask = torch.tensor([-1, -1])    # shape: [2]

# Since the mask is all -1 (prefill), no rows are filtered out
return hidden_states  # [2, 2048]

Step 5: Qwen3ForCausalLM.compute_logits() - The Final Projection

Now we have the final hidden states. But we need to convert them to vocabulary logits: one unnormalized score per token in the vocabulary (they only become probabilities after a softmax):

# Input:
hidden_states = tensor([[0.8, 0.2, -0.4, ...],     # shape: [2, 2048]
                       [-0.1, 0.7, 0.3, ...]])

# This calls: return self.lm_head(hidden_states)
# Which calls ParallelLMHead.forward():

# Operation 1: Handle prefill context
if context.is_prefill and context.context_lens is None:
    last_indices = context.cu_seqlens_q[1:] - 1  # [1] (last token of each sequence)
    x = hidden_states[last_indices]              # Select last tokens only
# Result: x = tensor([[-0.1, 0.7, 0.3, ...]])    # shape: [1, 2048]

# Operation 2: Final linear projection
logits = F.linear(x, self.weight)  # self.weight: [vocab_size, 2048]; computes x @ weight.T
# Result: logits = tensor([[0.2, -0.1, 0.8, ...]])  # shape: [1, vocab_size=152064]

This is the moment of truth. The F.linear() operation projects the 2048-dimensional hidden state to 152,064 dimensions (the vocabulary size). Each dimension holds an unnormalized score for one vocabulary entry being the next token.
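One detail worth noting is the weight orientation: `F.linear` stores its weight as [out_features, in_features] and computes `x @ weight.T`, not `x @ weight`. A quick check (with a small vocabulary so the demo stays light):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, vocab_size = 2048, 1000   # small vocab for the demo

x = torch.randn(1, hidden_size)                 # last-token hidden state
weight = torch.randn(vocab_size, hidden_size)   # lm_head weight: [out, in]

logits = F.linear(x, weight)                    # computes x @ weight.T
assert logits.shape == (1, vocab_size)
assert torch.allclose(logits, x @ weight.T, atol=1e-4)
```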

# Operation 3: Tensor parallel gathering (if needed)
if tp_size > 1:
    # Gather logits from all GPUs and concatenate
    logits = torch.cat(all_logits, -1)          # [1, 152064]

return logits

Why Most Calls Are Implicit

You might notice that we never explicitly call forward(). Instead, we just use the module like a function:

hidden_states = self.model(input_ids, positions)  # Calls model.forward()
hidden_states = self.self_attn(positions, hidden_states)  # Calls self_attn.forward()
hidden_states = self.mlp(hidden_states)  # Calls mlp.forward()

This is PyTorch magic. When you call a nn.Module instance, it automatically invokes the forward() method. This is why the code looks clean—all the complexity is hidden behind these function calls.
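The mechanism behind this is nn.Module.__call__, which wraps forward() and runs any registered hooks around it. A tiny demonstration:

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return 2 * x

m = Doubler()
x = torch.tensor([1.0, 2.0])

# Calling the module and calling forward() directly give the same result;
# the call syntax additionally runs any registered pre/post hooks.
assert torch.equal(m(x), m.forward(x))
```

This is also why you should call modules rather than invoking forward() yourself: bypassing __call__ skips hooks that tools like profilers and quantizers rely on.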

The Key Insight

The final F.linear(x, self.weight) operation is simple: just a matrix multiplication. But it works because its input hidden_states contains all the complex contextual information extracted by the transformer layers.

Without the transformer, you would be projecting a raw token embedding to vocabulary space, and the prediction could only reflect the current token in isolation.

With the transformer, the hidden state at each position encodes the entire preceding context, so the same simple projection can express context-dependent predictions.

Each layer in the transformer refines this representation a little further.

By the time you reach the final linear layer, the hidden states contain everything the model has learned about the input. The linear layer just needs to project that knowledge to vocabulary space.

From Logits to Tokens

After you have the logits, the sampler converts them to actual token IDs:

token_ids = self.sampler(logits, temperatures).tolist()
# Result: e.g. [13986] - the ID of the sampled next token

The sampler applies temperature scaling (to control randomness) and then samples from the probability distribution.
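A minimal sampler along these lines is a sketch, not the real implementation: temperature scaling, softmax, then a categorical draw, with greedy decoding as the zero-temperature limit.

```python
import torch

def sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """logits: [batch, vocab_size]. Returns one sampled token ID per row."""
    if temperature < 1e-5:
        return logits.argmax(dim=-1)              # greedy decoding
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

torch.manual_seed(0)
logits = torch.tensor([[1.0, 3.0, 0.5]])
token_ids = sample(logits, temperature=0.8).tolist()  # a list with one token ID
```

Lower temperatures sharpen the distribution toward the highest-scoring token; higher temperatures flatten it, making the output more random.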

This completes one forward pass. The token is added to the sequence, and the process repeats for the next token.