
Understanding KV Cache Bugs in Transformers

January 27, 2026

While working on a transformer-based language model, I ran into subtle bugs in the KV cache implementation that made text generation produce gibberish. Here's a deep dive into what went wrong and how to fix it.

What is KV Cache?

During text generation, a transformer language model produces one token at a time. For each new token, the model needs to "look back" at all previous tokens through the attention mechanism.

Without KV Cache (Naive):

Generate token 1: Process [token1]
Generate token 2: Process [token1, token2]        ← recompute token1
Generate token 3: Process [token1, token2, token3] ← recompute token1, token2
...

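Step t reprocesses all t tokens, so generating N tokens pushes 1 + 2 + ... + N = N(N+1)/2 token positions through the model. This is O(N²), very slow!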

With KV Cache:

Generate token 1: Process [token1], cache K1, V1
Generate token 2: Process [token2] only, use cached K1,V1 → get K2,V2
Generate token 3: Process [token3] only, use cached K1,V1,K2,V2 → get K3,V3
...

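Each step pushes only the one new token through the model, so generating N tokens processes N positions in total. This is O(N), much faster! (Each new token still attends over every cached key, so the attention work per step grows with the context length; what the cache saves is recomputing every layer for the old tokens.)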
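To make the step counting concrete, here is a tiny pure-Python sketch (the function name is hypothetical) that counts how many token positions pass through the model while generating `n_tokens`. It deliberately ignores the per-step attention cost, which still grows with the number of cached keys.

```python
def positions_processed(n_tokens: int, use_cache: bool) -> int:
    """Count token positions run through the model while generating n_tokens."""
    total = 0
    for t in range(1, n_tokens + 1):
        # Without a cache, step t re-runs all t tokens; with a cache, only the new one.
        total += 1 if use_cache else t
    return total


print(positions_processed(100, use_cache=False))  # 5050 = 100 * 101 / 2
print(positions_processed(100, use_cache=True))   # 100
```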

How Attention Works (Kinda)

For each token, the model computes:

  • Q (Query): "What am I looking for?"
  • K (Key): "What do I contain?"
  • V (Value): "What information do I have?"
Attention score = Q × K^T / √d (how relevant is each past token? d is the head dimension)
Output = softmax(scores) × V (weighted sum of past information)
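Here is a minimal pure-Python sketch of these formulas for a single query, using plain lists instead of tensors. The function name `attend` and the optional 0/1 `mask_row` argument are illustrative assumptions, not part of any library; the sketch includes the usual 1/√d scaling from standard scaled dot-product attention.

```python
import math


def attend(q, keys, values, mask_row=None):
    """Single-query scaled dot-product attention.

    q: query vector; keys, values: one vector per past position.
    mask_row: optional 0/1 list, positions marked 0 are excluded.
    """
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    if mask_row is not None:
        # Masked positions get -inf, so their softmax weight becomes exactly 0.
        scores = [s if m else float("-inf") for s, m in zip(scores, mask_row)]
    # Numerically stable softmax: subtract the max score before exponentiating.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of the value vectors, one output component at a time.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


keys = [[1.0, 0.0], [1.0, 0.0]]
values = [[2.0], [4.0]]
print(attend([1.0, 0.0], keys, values))          # [3.0]: equal scores, plain average
print(attend([1.0, 0.0], keys, values, [1, 0]))  # [2.0]: second position masked out
```

Setting masked scores to -inf before the softmax is the usual way a causal mask is applied in practice.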

The Causal Mask

In language models, token 5 should only "see" tokens 1-5 (itself and everything before it), never future tokens 6, 7, 8...

This is enforced by a causal mask, a lower-triangular matrix where row i marks which tokens token i may attend to:

Token:    1  2  3  4  5
      1 [ 1  0  0  0  0 ]  ← Token 1 sees only itself
      2 [ 1  1  0  0  0 ]  ← Token 2 sees tokens 1-2
      3 [ 1  1  1  0  0 ]  ← Token 3 sees tokens 1-3
      4 [ 1  1  1  1  0 ]  ← Token 4 sees tokens 1-4
      5 [ 1  1  1  1  1 ]  ← Token 5 sees tokens 1-5
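The same matrix takes a couple of lines to build. This pure-Python sketch (function name hypothetical) mirrors what `torch.tril(torch.ones(n, n))` produces in PyTorch. GPT-style codebases commonly precompute it once for the maximum context length and keep it as a buffer, which appears to be what the `self.bias` tensor later in this post is.

```python
def causal_mask(n: int) -> list[list[int]]:
    """Lower-triangular 0/1 mask: row i (query) may see columns 0..i (keys)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


for row in causal_mask(5):
    print(row)  # first row [1, 0, 0, 0, 0], last row [1, 1, 1, 1, 1]
```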

The Two Bugs I Encountered

Bug 1: Position Encoding

Transformers don't inherently know token order, so the model adds position embeddings to tell it "this is token #5".

Broken Code:

```python
pos = torch.arange(0, T, ...)  # T = number of input tokens
```

The problem appears when generating token #8 with a KV cache:

  • Only the NEW token is fed in (T=1)
  • Position becomes [0] instead of [7]
  • The model thinks it is looking at the FIRST token of the sequence!

Fixed Code:

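```python
# Number of cached tokens: layer 0's cached keys are [batch, heads, past_len, head_dim].
# On the very first forward pass there is no cache yet, so start at 0.
past_length = 0 if past_key_values is None else past_key_values[0][0].size(2)
pos = torch.arange(past_length, past_length + T, ...)  # Start from correct position
```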

Now when generating token #8: position = [7]
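A small sketch of how positions should evolve across a generation loop, assuming a 7-token prompt that is prefilled in a single pass (the helper name `position_ids` is hypothetical). It computes the same range as `torch.arange(past_length, past_length + T)`, just with plain lists:

```python
def position_ids(past_length: int, T: int) -> list[int]:
    """Absolute positions for T new tokens that follow past_length cached ones."""
    return list(range(past_length, past_length + T))


print(position_ids(0, 7))  # prefill, nothing cached: [0, 1, 2, 3, 4, 5, 6]
print(position_ids(7, 1))  # generating token #8: [7]
print(position_ids(8, 1))  # generating token #9: [8]
# The broken version behaves like position_ids(0, 1) on every decode step: [0]
```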

Bug 2: Causal Mask Selection

This is where things get really subtle.

How attention shapes work with KV cache:

When generating token #8 (with 7 cached tokens):

Q shape: [batch, heads, 1, dim]      ← Only current token's query
K shape: [batch, heads, 8, dim]      ← All 8 tokens (7 cached + 1 new)
V shape: [batch, heads, 8, dim]      ← All 8 tokens

Attention scores: Q × K^T
Score shape: [batch, heads, 1, 8]    ← 1 query attending to 8 keys
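The asymmetry comes straight from how the cache is updated: new keys and values are appended, while the query covers only the new tokens. A minimal list-based sketch, assuming a single layer and a hypothetical `KVCache` class (real implementations typically store one tensor pair per layer and concatenate along the sequence dimension):

```python
class KVCache:
    """Hypothetical single-layer cache holding one key and one value vector per position."""

    def __init__(self):
        self.keys = []
        self.values = []

    def update(self, new_keys, new_values):
        """Append the new positions; return (past_len, all_keys, all_values)."""
        past_len = len(self.keys)  # cached positions BEFORE this step
        self.keys.extend(new_keys)
        self.values.extend(new_values)
        return past_len, self.keys, self.values


head_dim = 4
cache = KVCache()

# Prefill: the 7-token prompt contributes 7 (dummy) key/value vectors.
prompt = [[0.0] * head_dim for _ in range(7)]
cache.update(prompt, prompt)

# Decode step for token #8: T = 1 new query, key and value.
T = 1
new = [[0.0] * head_dim for _ in range(T)]
past_len, K, V = cache.update(new, new)

# Q covers T positions, K covers past_len + T, so scores are T x (past_len + T).
print(past_len, len(K), (T, len(K)))  # 7 8 (1, 8)
```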

The broken mask selection:

```python
# Broken: selects rows 0:1, cols 0:8
mask = self.bias[:, :, :att.size(-2), :att.size(-1)]
# This gives: self.bias[:, :, :1, :8]
# Which is ROW 0 of the causal mask!
```

What row 0 looks like:

Row 0: [1, 0, 0, 0, 0, 0, 0, 0]

Result: Token #8 can ONLY attend to token #1! It's blind to tokens 2-7, and even to itself (column 8 is masked too)!

This is why the output was gibberish: each generated token could only see the very first token of the sequence.

Fixed Code:

```python
past_len = past_k.size(2)  # 7 cached tokens (size of the cache before this step)
seq_len = k.size(2)        # 8 total tokens (k already includes the cached keys)

# Select the CORRECT row based on current position:
# query i sits at absolute position past_len + i
mask = self.bias[:, :, past_len:past_len+T, :seq_len]
# This gives: self.bias[:, :, 7:8, :8]
# Which is ROW 7 of the causal mask!
```

What row 7 looks like:

Row 7: [1, 1, 1, 1, 1, 1, 1, 1]

Now token #8 can attend to ALL previous tokens and to itself ✓
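Both slicing expressions can be reproduced with plain nested lists, where `bias[a:b]` plays the role of the third index and `row[:seq_len]` the fourth. This sketch assumes a precomputed 16×16 mask (the size is arbitrary) and a hypothetical `select_mask` helper:

```python
def causal_mask(n):
    """n x n lower-triangular 0/1 mask (row = query position, column = key position)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def select_mask(bias, past_len, T, seq_len, broken=False):
    """Slice the mask rows for T query tokens that follow past_len cached tokens."""
    if broken:
        rows = bias[:T]                     # mirrors self.bias[:, :, :att.size(-2), :att.size(-1)]
    else:
        rows = bias[past_len:past_len + T]  # mirrors self.bias[:, :, past_len:past_len+T, :seq_len]
    return [row[:seq_len] for row in rows]


bias = causal_mask(16)  # assumed maximum context length of 16
broken = select_mask(bias, past_len=7, T=1, seq_len=8, broken=True)
fixed = select_mask(bias, past_len=7, T=1, seq_len=8)
print(broken)  # [[1, 0, 0, 0, 0, 0, 0, 0]] -> token #8 sees only token #1
print(fixed)   # [[1, 1, 1, 1, 1, 1, 1, 1]] -> token #8 sees tokens #1-#8
```

Without a cache (`past_len = 0`) both branches return the same rows, which is why the broken version survives any test that never exercises cached decoding.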

Visual Summary

BROKEN (Row 0 mask):
┌─────────────────────────────┐
│ Token 8 generating...       │
│ Can see: [Token 1] only     │  ← Using wrong mask row!
│ Blind to: [2,3,4,5,6,7,8]   │
│ Output: gibberish           │
└─────────────────────────────┘

FIXED (Row 7 mask):
┌─────────────────────────────┐
│ Token 8 generating...       │
│ Can see: [1,2,3,4,5,6,7,8]  │  ← Using correct mask row!
│ Has full context            │
│ Output: coherent            │
└─────────────────────────────┘

Why This Bug is Common

Many tutorials and implementations show attention like this:

```python
mask = causal_mask[:, :, :seq_len, :seq_len]
```

This works without a KV cache because Q and K then have the same length.

But with KV cache:

  • Q length ≠ K length
  • Q has only NEW tokens
  • K has ALL tokens (cached + new)

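The mask indexing must account for this asymmetry: the query rows have to be offset by the number of cached tokens (past_len:past_len+T) instead of starting at row 0.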
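One way to make the asymmetry explicit is to build the mask from the offset directly instead of slicing a precomputed buffer. In this sketch (function name hypothetical), query row `i` sits at absolute position `past_len + i` and may see keys `0..past_len + i`. In PyTorch the same matrix can be obtained with `torch.tril(torch.ones(T, past_len + T), diagonal=past_len)`.

```python
def offset_causal_mask(T: int, past_len: int) -> list[list[int]]:
    """Mask for T queries at positions past_len..past_len+T-1 over past_len+T keys."""
    seq_len = past_len + T
    return [[1 if j <= past_len + i else 0 for j in range(seq_len)] for i in range(T)]


print(offset_causal_mask(3, 0))  # no cache: the familiar square lower triangle
print(offset_causal_mask(1, 7))  # one cached decode step: a single all-ones row
print(offset_causal_mask(3, 5))  # 3 new tokens on top of 5 cached: rows 5..7 of the full mask
```

With `past_len = 0` this reduces to the square mask from the tutorials, and it also handles feeding several new tokens on top of an existing cache, where slicing from row 0 would be wrong even though T > 1.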