Introduction
Every LLM (GPT-4, Claude, Gemini) runs on self-attention. Yet most explanations hand-wave past the math. This article gives a compact, verifiable explanation of how queries (Q), keys (K) and values (V) interact inside a transformer layer and includes interactive visualizations so you can experiment with real matrices. By the end of this article you'll be able to compute attention weights, reason about softmax scaling, and relate the algebra to practical model behavior.
Scope: linear algebra recap, attention formula, visualization of Q·Kᵀ/√d, and examples that connect math to model output.
🔬 Interactive Attention Calculator
Click any word in the demo to watch attention flow through the Q, K, and V matrices: the query asks "What am I looking for?", each key answers "What do I contain?", and each value carries "What information do I pass forward?". The widget shows the attention weights softmax(Q·Kᵀ/√d) together with a step-by-step calculation for the selected token.
🧩 Real Simulations
Run larger attention simulations with a configurable number of tokens, embedding size, and head count. Heavy runs execute in a Web Worker so they don't block the page.
📐 The Attention Formula
Self-attention can be written in one beautiful equation:

Attention(Q, K, V) = softmax(Q·Kᵀ / √dk) · V
Let's break this down piece by piece.
Step 1: Create Q, K, V Matrices
Each input token is first embedded as a vector (say, 512 dimensions). Then we project through three learned weight matrices:
Q = X · W_Q # Query: "What am I looking for?"
K = X · W_K # Key: "What do I contain?"
V = X · W_V # Value: "What information to pass forward?"
For a sequence of 5 tokens with embedding size 512 and attention dimension 64:
- X is (5, 512) - 5 tokens, 512 dimensions each
- W_Q, W_K, W_V are (512, 64) - projection matrices
- Q, K, V are (5, 64) - projected representations
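To make the shapes concrete, here is a minimal NumPy sketch of this projection step. The random matrices stand in for learned parameters; the dimensions match the example above.

import numpy as np

np.random.seed(0)
seq_len, embed_dim, d_k = 5, 512, 64

X = np.random.randn(seq_len, embed_dim)          # token embeddings, (5, 512)
W_Q = np.random.randn(embed_dim, d_k) * 0.02     # learned in a real model; random here
W_K = np.random.randn(embed_dim, d_k) * 0.02
W_V = np.random.randn(embed_dim, d_k) * 0.02

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)                 # (5, 64) (5, 64) (5, 64)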
Step 2: Compute Attention Scores
For each query, compute the dot product with all keys: scores = Q · Kᵀ (a complete NumPy version appears after Step 5).
This gives a (5, 5) matrix where scores[i][j] measures how strongly token i attends to token j.
Step 3: Scale by √dk
Why scale? Dot products can get very large with high dimensions, pushing softmax into saturation (all attention on one token). Scaling by √dk keeps gradients healthy.
# Without scaling: softmax([8, 2, 3, 4, 5]) ≈ [0.93, 0.00, 0.01, 0.02, 0.05]
# With scaling by √dk = 8: softmax([1.0, 0.25, 0.38, 0.5, 0.63]) ≈ [0.30, 0.14, 0.16, 0.18, 0.21]
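You can check the effect of scaling in a few lines of NumPy (a small sketch; the score vector is made up for illustration):

import numpy as np

def softmax(x):
    x = x - x.max()                       # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum()

scores = np.array([8.0, 2.0, 3.0, 4.0, 5.0])
print(softmax(scores))                    # saturated: almost all mass on one token
print(softmax(scores / np.sqrt(64)))      # scaled by √dk = 8: a much softer distribution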
Step 4: Apply Softmax
Convert the scaled scores to probabilities with a row-wise softmax, so that the weights for each query sum to 1.
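In NumPy this is a softmax applied along the last axis, one row per query (continuing the sketch above; the max subtraction keeps the exponentials from overflowing):

def softmax_rows(scores):
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize each row
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)                # each row sums to 1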
Step 5: Weighted Sum of Values
Finally, compute each token's output as the attention-weighted combination of the value vectors: output = weights · V.
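Putting Steps 2-5 together, a minimal single-head implementation looks like this (a sketch reusing softmax_rows and the Q, K, V from the Step 1 sketch, not a drop-in for any particular library):

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # Steps 2-3: scaled pairwise scores, (5, 5)
    weights = softmax_rows(scores)        # Step 4: probabilities per query
    return weights @ V                    # Step 5: weighted sum of values, (5, 64)

out = attention(Q, K, V)
print(out.shape)                          # (5, 64)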
🎭 Multi-Head Attention
One attention head captures one type of relationship, so large transformers run many heads in parallel - GPT-3, for example, uses 96 attention heads per layer - each free to learn a different kind of pattern:
- Head 1: Syntactic relationships (subject-verb)
- Head 2: Positional patterns (nearby words)
- Head 3: Semantic similarity
- Head 4: Coreference (pronouns to nouns)
- ... and 92 more!
# Multi-head attention
heads = []
for i in range(num_heads):
    Q_i = X @ W_Q[i]
    K_i = X @ W_K[i]
    V_i = X @ W_V[i]
    head_i = attention(Q_i, K_i, V_i)
    heads.append(head_i)

# Concatenate and project
multi_head_output = concat(heads) @ W_O
🔍 What Attention Learns
Researchers have found attention heads that specialize in:
- "Previous token" heads: Always attend to the immediately preceding token
- "Induction" heads: Pattern match: if "A B" appeared before, when seeing "A" again, attend to what followed "B"
- "Subject-verb" heads: Connect verbs to their subjects across long distances
- "Bracket matching" heads: Track nested structures like parentheses
💻 Implementation in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape  # batch, sequence length, embedding dim
        # Project to Q, K, V
        Q = self.W_q(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Compute attention scores
        scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply softmax
        attn_weights = F.softmax(scores, dim=-1)
        # Weighted sum of values
        out = attn_weights @ V
        # Reshape and project output
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)
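A quick smoke test, assuming the class above is defined in the current session:

model = SelfAttention(embed_dim=128, num_heads=4)
x = torch.randn(2, 16, 128)   # batch of 2 sequences, 16 tokens, 128-dim embeddings
y = model(x)
print(y.shape)                # torch.Size([2, 16, 128]) - same shape as the input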
PyTorch Simulation: Visualize Per-Head Attention
The script below runs the SelfAttention module on a random batch, extracts per-head attention matrices, and saves PNG heatmaps per head. Run locally with Python + PyTorch.
"""
Run this script locally to generate attention heatmaps per head.
Requirements:
pip install torch numpy matplotlib
Usage:
python run_attention_sim.py
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
def run_attention_sim(embed_dim=128, num_heads=4, seq_len=16, seed=0):
torch.manual_seed(seed)
model = SelfAttention(embed_dim=embed_dim, num_heads=num_heads)
model.eval()
# Single batch example
B = 1
x = torch.randn(B, seq_len, embed_dim)
with torch.no_grad():
# Project to Q/K/V and compute attention exactly as in the class
Q = model.W_q(x).view(B, seq_len, model.num_heads, model.head_dim).transpose(1,2)
K = model.W_k(x).view(B, seq_len, model.num_heads, model.head_dim).transpose(1,2)
V = model.W_v(x).view(B, seq_len, model.num_heads, model.head_dim).transpose(1,2)
# scores: (B, heads, T, T)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (model.head_dim ** 0.5)
attn = torch.softmax(scores, dim=-1)
attn_np = attn[0].cpu().numpy() # shape (heads, T, T)
# save heatmaps
for h in range(attn_np.shape[0]):
plt.figure(figsize=(6,5))
plt.imshow(attn_np[h], cmap='viridis', vmin=0, vmax=1)
plt.title(f'Head {h} Attention')
plt.xlabel('Key position (j)')
plt.ylabel('Query position (i)')
plt.colorbar(label='Attention weight')
plt.tight_layout()
fname = f'attn_head_{h}.png'
plt.savefig(fname, dpi=150)
plt.close()
print(f'Saved {fname}')
# Save raw array for further analysis
np.save('attention_heads.npy', attn_np)
print('Saved attention_heads.npy')
if __name__ == '__main__':
run_attention_sim(embed_dim=128, num_heads=4, seq_len=16)
Beyond Vanilla Attention
Flash Attention
Standard attention materializes an O(n²) score matrix. FlashAttention computes exact attention in tiles that fit in fast on-chip memory, reducing memory to O(n) and typically running 2-4x faster.
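The core trick can be sketched in a few lines: process keys and values in chunks while maintaining a running (online) softmax, so the full (T, T) matrix is never materialized. This is a didactic toy in PyTorch, not the fused GPU kernel that FlashAttention actually is:

import torch

def chunked_attention(q, k, v, chunk=128):
    # q, k, v: (T, d) tensors for a single head
    T, d = q.shape
    scale = d ** -0.5
    m = torch.full((T, 1), float('-inf'))    # running row-wise max of the scores
    l = torch.zeros(T, 1)                    # running softmax denominator
    acc = torch.zeros(T, d)                  # running weighted sum of values
    for start in range(0, T, chunk):
        k_c, v_c = k[start:start + chunk], v[start:start + chunk]
        s = (q @ k_c.T) * scale              # (T, C) scores for this chunk only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)         # rescale previously accumulated stats
        p = torch.exp(s - m_new)
        l = l * alpha + p.sum(dim=-1, keepdim=True)
        acc = acc * alpha + p @ v_c
        m = m_new
    return acc / l

# Matches full attention up to floating-point error:
# ref = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, dim=-1) @ v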
Sparse Attention
Instead of attending to all tokens, use patterns like local windows + global tokens.
Linear Attention
Replace the softmax with a kernel feature map so attention can be computed in O(n) time and memory, as in linear transformers and Performer-style models.
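As a rough illustration, here is the non-causal kernelized form from the linear-transformer line of work, using an elu(x) + 1 feature map in place of the softmax's exponential:

import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    # q, k, v: (T, d) tensors; cost is O(T·d²) rather than O(T²·d)
    phi = lambda t: F.elu(t) + 1          # positive feature map
    q, k = phi(q), phi(k)
    kv = k.T @ v                          # (d, d) summary of all keys and values
    z = q @ k.sum(dim=0)                  # (T,) per-query normalizer
    return (q @ kv) / z.unsqueeze(-1)     # (T, d)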
📚 Further Reading
- Vaswani et al. (2017). "Attention Is All You Need" - arXiv: https://arxiv.org/abs/1706.03762
- Bahdanau et al. (2014). "Neural Machine Translation by Jointly Learning to Align and Translate" - arXiv: https://arxiv.org/abs/1409.0473
- Dao et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" - arXiv: https://arxiv.org/abs/2205.14135
- Jay Alammar. "The Illustrated Transformer" - blog: https://jalammar.github.io/illustrated-transformer/
Conclusion
Self-attention is a compact, powerful operation that lets models compare tokens pairwise and route information via learned projections. The key steps - forming Q, K, V, computing scaled dot products, applying softmax, and taking a weighted sum of values - are simple linear algebra. With multi-head attention and optimized implementations like FlashAttention, the same operation scales to very large models while remaining interpretable at the level of individual heads.
Try the interactive demo above: experiment with the scaling factor and the token vectors, and watch how the attention weights shift. For larger experiments, run the PyTorch simulation script locally to generate per-head heatmaps.
References
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. https://arxiv.org/abs/1706.03762
- Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. https://arxiv.org/abs/1409.0473
- Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/abs/2205.14135
- Alammar, J. The Illustrated Transformer. https://jalammar.github.io/illustrated-transformer/
- For practical PyTorch implementations and tutorials, see the official PyTorch docs: https://pytorch.org/docs/stable/nn.html#multiheadattention