Understanding KV Cache Bugs in Transformers
January 27, 2026
While working on a transformer-based language model, I encountered subtle bugs in the KV cache implementation that caused gibberish outputs during text generation. Here's a deep dive into what went wrong and how to fix it.
What is KV Cache?
In transformer models, text is generated one token at a time. For each new token, the model needs to "look back" at all previous tokens through the attention mechanism.
Without KV Cache (Naive):
Generate token 1: Process [token1]
Generate token 2: Process [token1, token2] ← recompute token1
Generate token 3: Process [token1, token2, token3] ← recompute token1, token2
...
This is O(N²) - very slow!
With KV Cache:
Generate token 1: Process [token1], cache K1, V1
Generate token 2: Process [token2] only, use cached K1,V1 → get K2,V2
Generate token 3: Process [token3] only, use cached K1,V1,K2,V2 → get K3,V3
...
This is O(N) - much faster!
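To make the caching concrete, here is a minimal PyTorch sketch of the idea: keep the keys and values computed so far and append one new position per step. The shapes and names (cached_k, cached_v) are invented for illustration, not taken from any particular model.

import torch

# Hypothetical sizes: 1 batch, 4 heads, head dimension 16
B, H, D = 1, 4, 16
cached_k = torch.empty(B, H, 0, D)  # cache grows by one position per generated token
cached_v = torch.empty(B, H, 0, D)

for step in range(3):
    # Pretend these came from projecting ONLY the newest token
    new_k = torch.randn(B, H, 1, D)
    new_v = torch.randn(B, H, 1, D)
    # Append to the cache instead of recomputing all past tokens
    cached_k = torch.cat([cached_k, new_k], dim=2)
    cached_v = torch.cat([cached_v, new_v], dim=2)
    print(step, cached_k.shape)  # sequence dimension grows: 1, 2, 3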
How Attention Works (Kinda)
For each token, the model computes:
- Q (Query): "What am I looking for?"
- K (Key): "What do I contain?"
- V (Value): "What information do I have?"
Attention score = Q × K^T (how relevant is each past token?)
Output = softmax(scores) × V (weighted sum of past information)
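As a toy illustration (not the model's actual code), here is that computation in PyTorch. Note that real implementations also scale the scores by 1/√d before the softmax, which the simplified formula above leaves out; the sizes below are made up.

import torch
import torch.nn.functional as F

B, H, T, D = 1, 2, 5, 8          # toy sizes: batch, heads, tokens, head dim
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

scores = q @ k.transpose(-2, -1) / D**0.5   # [B, H, T, T] - how relevant is each past token?
weights = F.softmax(scores, dim=-1)         # each row sums to 1
out = weights @ v                           # weighted sum of values, [B, H, T, D]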
The Causal Mask
In language models, token 5 should only "see" tokens 1-4, not future tokens 6, 7, 8...
This is enforced by a causal mask - a lower triangular matrix:
Token: 1 2 3 4 5
1 [ 1 0 0 0 0 ] ← Token 1 sees only itself
2 [ 1 1 0 0 0 ] ← Token 2 sees tokens 1-2
3 [ 1 1 1 0 0 ] ← Token 3 sees tokens 1-3
4 [ 1 1 1 1 0 ] ← Token 4 sees tokens 1-4
5 [ 1 1 1 1 1 ] ← Token 5 sees tokens 1-5
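A lower-triangular mask like this is typically built with torch.tril. The snippet below is a small sketch of that idea with dummy scores; it is not the model's actual masking code.

import torch

T = 5
mask = torch.tril(torch.ones(T, T))   # lower triangular: row i has ones in columns 0..i
scores = torch.randn(1, 1, T, T)      # dummy attention scores for illustration
masked = scores.masked_fill(mask == 0, float("-inf"))  # future positions get -inf before softmax
print(mask)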
The Two Bugs I Encountered
Bug 1: Position Encoding
Transformers don't inherently know token order. We add position embeddings to tell the model "this is token #5".
Broken Code:
pos = torch.arange(0, T, ...)  # T = number of input tokens
The issue appears when generating token #8 with KV cache:
- We only feed in the NEW token (T=1)
- Position becomes [0] instead of [7]
- The model thinks it's generating the FIRST token!
Fixed Code:
past_length = past_key_values[0][0].size(2) # How many cached tokens?
pos = torch.arange(past_length, past_length + T, ...)  # Start from the correct position
Now when generating token #8: position = [7]
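To see the fix in isolation, here is a standalone sketch. It assumes a GPT-2-style cache layout where past_key_values is a list of per-layer (key, value) tensors shaped [batch, heads, past_len, head_dim]; the numbers are made up.

import torch

B, H, D = 1, 4, 16
# One layer, 7 cached tokens
past_key_values = [(torch.randn(B, H, 7, D), torch.randn(B, H, 7, D))]

T = 1                                              # only the new token is fed in
past_length = past_key_values[0][0].size(2)        # = 7 cached tokens
pos = torch.arange(past_length, past_length + T)   # tensor([7]), not tensor([0])
print(pos)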
Bug 2: Causal Mask Selection
This is where things get really subtle.
How attention shapes work with KV cache:
When generating token #8 (with 7 cached tokens):
Q shape: [batch, heads, 1, dim] ← Only current token's query
K shape: [batch, heads, 8, dim] ← All 8 tokens (7 cached + 1 new)
V shape: [batch, heads, 8, dim] ← All 8 tokens
Attention scores: Q × K^T
Score shape: [batch, heads, 1, 8] ← 1 query attending to 8 keys
The broken mask selection:
# Broken: selects rows 0:1, cols 0:8
mask = self.bias[:, :, :att.size(-2), :att.size(-1)]
# This gives: self.bias[:, :, :1, :8]
# Which is ROW 0 of the causal mask!
What row 0 looks like:
Row 0: [1, 0, 0, 0, 0, 0, 0, 0]
Result: Token #8 can ONLY attend to token #1! It's blind to tokens 2-7!
This is why output was gibberish - each generated token could only see the very first token of the conversation.
Fixed Code:
past_len = past_k.size(2)  # 7 cached tokens
seq_len = k.size(2)        # 8 total tokens (cached + new)
T = q.size(2)              # 1 new query token
# Select the CORRECT row based on current position
mask = self.bias[:, :, past_len:past_len+T, :seq_len]
# This gives: self.bias[:, :, 7:8, :8]
# Which is ROW 7 of the causal mask!
What row 7 looks like:
Row 7: [1, 1, 1, 1, 1, 1, 1, 1]
Now token #8 can attend to ALL previous tokens ✓
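Putting both pieces together, here is a self-contained sketch of one cached decoding step with the corrected mask slicing. It is not the original module, just an illustration with invented shapes; self.bias becomes a local bias tensor built with torch.tril.

import torch
import torch.nn.functional as F

B, H, D, block_size = 1, 2, 8, 16
# Lower-triangular causal mask, shaped like a typical registered buffer
bias = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)

past_k = torch.randn(B, H, 7, D)    # 7 cached keys
past_v = torch.randn(B, H, 7, D)
q = torch.randn(B, H, 1, D)         # query for the single new token (token #8)
new_k = torch.randn(B, H, 1, D)
new_v = torch.randn(B, H, 1, D)

k = torch.cat([past_k, new_k], dim=2)   # [B, H, 8, D]
v = torch.cat([past_v, new_v], dim=2)   # [B, H, 8, D]

past_len = past_k.size(2)               # 7
T = q.size(2)                           # 1
seq_len = k.size(2)                     # 8

att = (q @ k.transpose(-2, -1)) / D**0.5            # [B, H, 1, 8]
mask = bias[:, :, past_len:past_len + T, :seq_len]  # row 7 of the causal mask
att = att.masked_fill(mask == 0, float("-inf"))
out = F.softmax(att, dim=-1) @ v                    # [B, H, 1, D]
print(mask)   # all ones: the new token may attend to every one of the 8 positions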
Visual Summary
BROKEN (Row 0 mask):
┌─────────────────────────────┐
│ Token 8 generating... │
│ Can see: [Token 1] only │ ← Using wrong mask row!
│ Blind to: [2,3,4,5,6,7] │
│ Output: gibberish │
└─────────────────────────────┘
FIXED (Row 7 mask):
┌─────────────────────────────┐
│ Token 8 generating... │
│ Can see: [1,2,3,4,5,6,7] │ ← Using correct mask row!
│ Has full context │
│ Output: coherent │
└─────────────────────────────┘
Why This Bug is Common
Many tutorials and implementations show attention like this:
mask = causal_mask[:, :, :seq_len, :seq_len]
This works without KV cache because Q and K have the same length.
But with KV cache:
- Q length ≠ K length
- Q has only NEW tokens
- K has ALL tokens (cached + new)
The mask indexing must account for this asymmetry.
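One way to handle both cases with a single expression is to always slice by the past length and the query length. The helper below is a hypothetical sketch, not from any particular library; with no cache it degenerates to the familiar square slice.

import torch

def causal_mask_slice(bias, past_len, q_len, k_len):
    # Works in both regimes:
    #   prefill (no cache): past_len = 0, q_len == k_len -> full lower triangle
    #   cached decode:      past_len > 0, q_len = 1      -> the single correct row
    return bias[:, :, past_len:past_len + q_len, :k_len]

block_size = 8
bias = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)

print(causal_mask_slice(bias, past_len=0, q_len=5, k_len=5).shape)  # [1, 1, 5, 5]
print(causal_mask_slice(bias, past_len=7, q_len=1, k_len=8))        # row 7: all ones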