Tensor Flow Through the Model: From Tokens to Logits
When you pass tokens through an LLM, they undergo a series of transformations. Each layer refines the representation, adding context and meaning. But most of these transformations happen implicitly—you call a module like a function, and PyTorch automatically invokes its forward() method. Let's trace exactly what happens to a tensor as it flows through the model.
The Big Picture
Here's the overall flow:
Input tokens (IDs)
↓
Embedding layer (token IDs → vectors)
↓
Transformer layers (24-32 layers of attention + MLP)
↓
Final normalization
↓
Linear projection to vocabulary (logits)
↓
Sampling (logits → next token)
But what does this actually look like with real numbers? Let's walk through it.
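Before diving into the real model, the shape bookkeeping of the whole pipeline can be sketched end to end with toy dimensions (a minimal sketch: the transformer layers are omitted, and all sizes are illustrative, not Qwen3's real ones):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size = 100, 16

embed = nn.Embedding(vocab_size, hidden_size)             # token IDs -> vectors
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # vectors -> logits

input_ids = torch.tensor([12, 34])            # shape: [2]
hidden = embed(input_ids)                     # shape: [2, 16]
# (the transformer layers would refine `hidden` here)
logits = lm_head(hidden)                      # shape: [2, 100]
next_token = torch.argmax(logits[-1]).item()  # greedy pick from the last position

print(hidden.shape, logits.shape, next_token)
```

Everything that follows is a more detailed look at the steps this sketch compresses.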
Starting Point: ModelRunner.run_model()
You have 2 tokens to process:
input_ids = torch.tensor([12, 34]) # Token IDs: "Hello", "world"
positions = torch.tensor([0, 1]) # Positions in sequence
The ModelRunner calls:
logits = self.model.compute_logits(self.model(input_ids, positions))
This invokes Qwen3ForCausalLM.forward(), which is the entry point.
Step 1: Qwen3ForCausalLM.forward()
def forward(self, input_ids, positions):
    # Call the core model
    hidden_states = self.model(input_ids, positions)
    # Apply masking if needed (for mixed batching; mask comes from the batching context)
    if isinstance(mask, torch.Tensor):
        flat_mask = mask.reshape(-1)
        # ... flat_mask selects which hidden states to keep ...
    # Return the (possibly filtered) hidden states
    return hidden_states
This is a wrapper that delegates to the core transformer model.
Step 2: Qwen3Model.forward() - The Core Transformer
Now we enter the actual transformer:
# Input:
input_ids = tensor([12, 34]) # shape: [2]
positions = tensor([0, 1]) # shape: [2]
# Operation 1: Embedding lookup
hidden_states = self.embed_tokens(input_ids)
# Result: hidden_states = tensor([
# [0.1, -0.2, 0.3, ...], # Embedding for token 12
# [0.4, 0.1, -0.1, ...] # Embedding for token 34
# ]) # shape: [2, hidden_size=2048]
The embedding layer converts token IDs to dense vectors. Each token gets a 2048-dimensional vector.
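This lookup is literally row indexing into a learned weight matrix. A minimal sketch with toy sizes (the 2048 here is shrunk to 8; `nn.Embedding` is the standard PyTorch module for this):

```python
import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=100, embedding_dim=8)  # vocab=100, hidden=8

input_ids = torch.tensor([12, 34])
hidden_states = embed(input_ids)          # shape: [2, 8]

# The lookup is equivalent to indexing rows of the weight matrix:
same = embed.weight[input_ids]
print(torch.equal(hidden_states, same))   # True
```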
# Operation 2: Loop through transformer layers
residual = None
for layer in self.layers:  # e.g. 24 layers
    hidden_states, residual = layer(positions, hidden_states, residual)
Each layer takes the hidden states and refines them. After the first layer:
# hidden_states: [2, 2048] - transformed by attention+MLP
# residual: [2, 2048] - original embedding for residual connection
After all 24 layers, the hidden states have been transformed many times, with each layer adding more context and meaning.
# Operation 3: Final normalization
hidden_states, _ = self.norm(hidden_states, residual)
# Result: hidden_states = tensor([
# [0.8, 0.2, -0.4, ...], # Final representation for token 12
# [-0.1, 0.7, 0.3, ...] # Final representation for token 34
# ]) # shape: [2, 2048]
Step 3: Inside Qwen3DecoderLayer - The Transformer Block
Each decoder layer does the same thing: apply attention, then apply MLP. Let's trace one layer:
# Input to first decoder layer:
positions = tensor([0, 1]) # shape: [2]
hidden_states = tensor([[0.1, -0.2, ...], # shape: [2, 2048]
[0.4, 0.1, ...]])
residual = None
# Operation 1: Input layernorm (keep the un-normalized input as the residual)
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
# Result:
# hidden_states: [2, 2048] - normalized version
# residual: [2, 2048] - original for residual connection
Normalization stabilizes training and keeps activations in a healthy range as information flows through the layers.
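Qwen-family models use RMSNorm rather than classic LayerNorm. A minimal sketch of the normalization step (the fused variant in the model also adds the residual and returns both tensors, which is what the `(hidden_states, residual)` pairs above reflect):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each vector to unit root-mean-square, then apply a learned per-dim gain.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

hidden = torch.randn(2, 8)
weight = torch.ones(8)          # learned gain; ones here for illustration
normed = rms_norm(hidden, weight)

# Each row now has (approximately) unit mean of squares.
print(normed.pow(2).mean(-1))
```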
# Operation 2: Self-attention
hidden_states = self.self_attn(positions, hidden_states)
# This calls Qwen3Attention.forward():
# qkv = self.qkv_proj(hidden_states)      # [2, 2048] → [2, 2304] (fused Q+K+V)
# q, k, v = split qkv                     # q: [2, num_heads, head_dim]; k, v use the (fewer) KV heads
# o = self.attn(q, k, v)                  # [2, num_heads, head_dim]
# output = self.o_proj(o.flatten(1, -1))  # [2, 2048] → [2, 2048]
# Result: [2, 2048] - attention output
Attention is where the model learns relationships between tokens. The Q (query), K (key), and V (value) projections are linear transformations that prepare the data for the attention mechanism.
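The core of self.attn can be sketched as scaled dot-product attention with a causal mask (a minimal single-sequence sketch; the real implementation reads K and V from the KV cache and uses fused kernels such as FlashAttention):

```python
import math
import torch

def attention(q, k, v):
    # q, k, v: [num_tokens, num_heads, head_dim]
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))        # -> [heads, tokens, dim]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Causal mask: token i may only attend to tokens 0..i.
    n = scores.size(-1)
    causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return (weights @ v).transpose(0, 1)                    # -> [tokens, heads, dim]

q = torch.randn(2, 4, 16)  # 2 tokens, 4 heads, head_dim 16 (toy sizes)
o = attention(q, torch.randn(2, 4, 16), torch.randn(2, 4, 16))
print(o.shape)
```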
# Operation 3: Post-attention layernorm
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Result:
# hidden_states: [2, 2048] - normalized attention output + residual
# residual: [2, 2048] - updated residual
Residual connections allow gradients to flow directly through layers, preventing the vanishing gradient problem.
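The skip connection is just an addition, but it gives gradients an identity path around each sublayer. A tiny sketch:

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)                 # stands in for attention or MLP
x = torch.randn(2, 8, requires_grad=True)

out = x + layer(x)                      # residual: output = input + sublayer(input)
out.sum().backward()

# Gradients reach x both through the sublayer and directly through the skip path.
print(x.grad.shape)
```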
# Operation 4: MLP (Feed-Forward Network)
hidden_states = self.mlp(hidden_states)
# This calls Qwen3MLP.forward():
# gate_up = self.gate_up_proj(x)   # [2, 2048] → [2, 4096*2] (fused gate and up halves)
# x = self.act_fn(gate_up)         # SiLU(gate) * up (gated activation)
# x = self.down_proj(x)            # [2, 4096] → [2, 2048]
# Result: [2, 2048] - final layer output
The MLP adds non-linear reasoning capacity. It expands the representation (here each of the gate and up halves is 4096 wide, twice the hidden size), applies a gated activation, then projects back down to 2048.
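The gated MLP can be sketched as follows (a minimal sketch with toy sizes; gate_up_proj is the fused gate-plus-up projection, and the activation applies SiLU to the gate half before multiplying by the up half):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16   # toy sizes

gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)  # two [2, 16] halves
x = F.silu(gate) * up                        # gated activation (SwiGLU-style)
x = down_proj(x)                             # back to [2, 8]
print(x.shape)
```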
Step 4: Back to Qwen3ForCausalLM.forward()
After all layers have processed the tokens:
# After Qwen3Model returns:
hidden_states = tensor([[0.8, 0.2, -0.4, ...], # shape: [2, 2048]
[-0.1, 0.7, 0.3, ...]])
# Mask processing (for prefill, no filtering needed)
if isinstance(mask, torch.Tensor):
    flat_mask = mask.reshape(-1)          # here: [-1, -1]
else:
    flat_mask = torch.tensor([-1, -1])    # shape: [2]
# Since flat_mask is all -1 (prefill), no filtering is applied
return hidden_states                      # [2, 2048]
Step 5: Qwen3ForCausalLM.compute_logits() - The Final Projection
Now we have the final hidden states. But we need to convert them to vocabulary logits (one unnormalized score per vocabulary entry; a softmax later turns these into probabilities):
# Input:
hidden_states = tensor([[0.8, 0.2, -0.4, ...], # shape: [2, 2048]
[-0.1, 0.7, 0.3, ...]])
# This calls: return self.lm_head(hidden_states)
# Which calls ParallelLMHead.forward():
# Operation 1: Handle prefill context
if context.is_prefill and context.context_lens is None:
    last_indices = context.cu_seqlens_q[1:] - 1  # [1] (last token of each sequence)
    x = hidden_states[last_indices]              # select last tokens only
# Result: x = tensor([[-0.1, 0.7, 0.3, ...]])    # shape: [1, 2048]
# Operation 2: Final linear projection
logits = F.linear(x, self.weight)  # x @ weight.T; weight shape: [vocab_size, 2048]
# Result: logits = tensor([[0.2, -0.1, 0.8, ...]]) # shape: [1, vocab_size=152064]
This is the moment of truth. The F.linear() operation projects the 2048-dimensional hidden state to 152,064 dimensions (the vocabulary size). Each dimension represents the model's confidence that the next token is that word.
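The projection is a single matrix multiply against the LM head weight, which PyTorch stores as [out_features, in_features], i.e. [vocab_size, hidden_size] (toy sizes below):

```python
import torch
import torch.nn.functional as F

vocab_size, hidden_size = 100, 8
weight = torch.randn(vocab_size, hidden_size)  # lm_head weight: [vocab, hidden]

x = torch.randn(1, hidden_size)                # the last token's hidden state
logits = F.linear(x, weight)                   # computes x @ weight.T -> [1, vocab]

probs = torch.softmax(logits, dim=-1)          # logits become probabilities only after softmax
print(logits.shape, probs.sum())
```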
# Operation 3: Tensor parallel gathering (if needed)
if tp_size > 1:
    # Gather each GPU's vocabulary shard and concatenate along the vocab dim
    logits = torch.cat(all_logits, -1)  # [1, 152064]
return logits
Why Most Calls Are Implicit
You might notice that we never explicitly call forward(). Instead, we just use the module like a function:
hidden_states = self.model(input_ids, positions) # Calls model.forward()
hidden_states = self.self_attn(positions, hidden_states) # Calls self_attn.forward()
hidden_states = self.mlp(hidden_states) # Calls mlp.forward()
This is PyTorch magic. When you call a nn.Module instance, it automatically invokes the forward() method. This is why the code looks clean—all the complexity is hidden behind these function calls.
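You can see the dispatch with a trivial module: calling the instance goes through nn.Module.__call__, which runs any registered hooks and then invokes forward():

```python
import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return x * 2

m = Doubler()
x = torch.tensor([1.0, 2.0])
out = m(x)                                # calling the instance dispatches to forward()
print(torch.equal(out, m.forward(x)))     # True
```

Calling `m(x)` rather than `m.forward(x)` is the idiomatic form, since it keeps hooks and other module machinery working.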
The Key Insight
The final F.linear(x, self.weight) operation is simple—just a matrix multiplication. But it works because its input hidden_states contains all the complex contextual information extracted by the transformer layers.
Without the transformer: F.linear() would map raw token embeddings straight to the vocabulary, with no context.
With the transformer: F.linear() maps contextually rich representations to vocabulary scores.
Each layer in the transformer refines the representation:
- Embedding: Token IDs → semantic vectors
- Attention layers: Add contextual understanding
- MLP layers: Add non-linear reasoning
- Normalization: Keep training stable
By the time you reach the final linear layer, the hidden states contain everything the model has learned about the input. The linear layer just needs to project that knowledge to vocabulary space.
From Logits to Tokens
After you have the logits, the sampler converts them to actual token IDs:
token_ids = self.sampler(logits, temperatures).tolist()
# Result: e.g. [13986] - the sampled next-token ID
The sampler applies temperature scaling (to control randomness) and then samples from the probability distribution.
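Temperature sampling can be sketched in a few lines (a minimal sketch: divide the logits by the temperature, softmax, then draw from the resulting distribution; temperature 0 is treated as greedy argmax):

```python
import torch

def sample(logits: torch.Tensor, temperature: float) -> int:
    if temperature == 0.0:
        return torch.argmax(logits).item()  # greedy: always the top-scoring token
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([0.2, -0.1, 3.0])
print(sample(logits, temperature=0.0))      # 2 (the index of the highest logit)
```

Lower temperatures sharpen the distribution toward the top token; higher temperatures flatten it, making unlikely tokens more probable.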
This completes one forward pass. The token is added to the sequence, and the process repeats for the next token.