In 2017, a paper with the provocative title "Attention Is All You Need" changed everything. The transformer architecture it introduced has since become the foundation of GPT, BERT, and virtually every large language model that followed. At the heart of this architecture is a deceptively simple idea: attention.
The word "attention" might conjure images of human cognition, of focusing on one thing while ignoring others. The mathematical mechanism shares this intuition. When a transformer processes the sentence "The cat sat on the mat because it was tired", it needs to figure out that "it" refers to "cat". Attention is how it makes that connection, how it learns to focus on relevant parts of the input when producing each part of the output.
But intuition only gets you so far. To truly understand attention, and to appreciate both its elegance and its limitations, we need to dive into the mathematics. Let's build it from scratch.
The Core Intuition
Before the equations, let's establish what we're trying to achieve. Imagine you're translating "The cat sat on the mat" to French. When generating the word for "cat" in French ("chat"), you need to pay attention to "cat" in the English sentence. When generating "tapis" (mat), you need to focus on "mat".
Traditional sequence models processed input one token at a time, passing information through hidden states. This worked, but it created a bottleneck. Information from the beginning of a long sequence had to survive many processing steps to influence the end, and much was lost along the way.
Attention sidesteps this problem entirely. Instead of relying on a chain of hidden states, every output position can directly look at every input position and decide how much to focus on each. It's like having a spotlight that can illuminate any part of the input, and the model learns where to point that spotlight.
Queries, Keys, and Values
The attention mechanism operates on three types of vectors, which the original paper called queries, keys, and values. The naming comes from database retrieval, which is a useful analogy.
Think of it this way: you have a collection of information stored as key-value pairs. When you want to retrieve information, you provide a query. The system compares your query to all the keys, determines how well each key matches, and returns a weighted combination of the corresponding values.
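To make the analogy concrete, here is a minimal sketch of a "soft" lookup; the keys, values, and query are made-up numbers purely for illustration. Rather than returning the single best match, we compare the query against every key and blend the values by match strength.

import numpy as np

def soft_lookup(query, keys, values):
    """Soft retrieval: blend every value by how well its key matches the query."""
    scores = keys @ query                                # one similarity score per key
    weights = np.exp(scores) / np.exp(scores).sum()      # softmax over the scores
    return weights @ values                              # weighted mix of the values

keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [0.5, 0.5]])             # three stored keys
values = np.array([10.0, 20.0, 30.0])     # the value attached to each key
query = np.array([3.0, 0.0])              # matches the first key most strongly

print(soft_lookup(query, keys, values))   # ~13.9: dominated by the first value,
                                          # with a little of the others mixed in

This is exactly the shape of what attention does, except that the queries, keys, and values are all produced from the same input by learned projections.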
In the transformer, each input position generates all three: a query vector (what am I looking for?), a key vector (what information do I contain?), and a value vector (what information should I provide if selected?). These are computed by multiplying the input embeddings by learned weight matrices.
Q = X · W_q
K = X · W_k
V = X · W_v
Where X is the input matrix (sequence length × embedding dimension) and W_q, W_k, W_v are learned weight matrices. The resulting Q, K, V matrices each have shape (sequence length × d_k), where d_k is typically smaller than the embedding dimension (64 in the original paper).
Scaled Dot-Product Attention
Now for the heart of the mechanism. Given queries, keys, and values, how do we compute attention? The formula is elegant:

Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V
Let's break this down step by step. First, we compute Q · K^T, the dot product of queries with keys. This produces a matrix where each entry (i, j) represents how much query i should attend to key j. High values mean high compatibility.
The division by √d_k is crucial and easy to overlook. Without it, the dot products can become very large when d_k is big, pushing the softmax into regions where its gradients are vanishingly small. This scaling keeps the values in a reasonable range where learning can proceed smoothly.
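A quick numerical illustration (not from the paper, just a sanity check with random vectors) shows the effect: for vectors with unit-variance components, the raw dot products spread out roughly as √d_k, while the scaled scores stay in a stable range.

import numpy as np

np.random.seed(0)
for d_k in [4, 64, 512]:
    q = np.random.randn(10_000, d_k)
    k = np.random.randn(10_000, d_k)
    dots = np.sum(q * k, axis=1)                 # 10,000 sample dot products
    scaled = dots / np.sqrt(d_k)
    print(f"d_k={d_k:4d}  std of raw scores: {dots.std():6.2f}  "
          f"after /sqrt(d_k): {scaled.std():5.2f}")
# The raw scores spread out as d_k grows (std roughly sqrt(d_k)); the scaled scores
# stay near 1, which keeps the softmax out of its saturated, near-zero-gradient regime.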
The softmax converts these raw compatibility scores into a probability distribution. For each query position, the attention weights across all key positions now sum to 1. Finally, we multiply by V to get a weighted combination of values.
Implementing scaled dot-product attention:

import numpy as np
def softmax(x, axis=-1):
"""Numerically stable softmax."""
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Compute scaled dot-product attention.
Args:
Q: Queries, shape (batch, seq_len, d_k)
K: Keys, shape (batch, seq_len, d_k)
V: Values, shape (batch, seq_len, d_v)
mask: Optional mask for positions to ignore
Returns:
Output and attention weights
"""
d_k = Q.shape[-1]
# Compute attention scores
scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
# Apply mask if provided (for causal attention)
if mask is not None:
scores = np.where(mask == 0, -1e9, scores)
# Convert to probabilities
attention_weights = softmax(scores, axis=-1)
# Weighted sum of values
output = np.matmul(attention_weights, V)
return output, attention_weights
Let's trace through a concrete example to see this in action.
Example: attention in action

# Simple example: 3 positions, embedding dimension 4
np.random.seed(42)
# Input: 3 tokens, each embedded as 4-dimensional vector
X = np.array([[
[1.0, 0.0, 1.0, 0.0], # Token 1
[0.0, 1.0, 0.0, 1.0], # Token 2
[1.0, 1.0, 0.0, 0.0], # Token 3
]])
# In practice, W_q, W_k, W_v are learned. Here we use simple projections
d_model, d_k = 4, 2
W_q = np.random.randn(d_model, d_k) * 0.1
W_k = np.random.randn(d_model, d_k) * 0.1
W_v = np.random.randn(d_model, d_k) * 0.1
# Compute Q, K, V
Q = np.matmul(X, W_q)
K = np.matmul(X, W_k)
V = np.matmul(X, W_v)
# Apply attention
output, weights = scaled_dot_product_attention(Q, K, V)
print("Attention weights (which positions each position attends to):")
print(np.round(weights[0], 3))
print("\nEach row shows how much that position attends to positions 0, 1, 2")
Why Self-Attention Works
The magic of attention is that it learns what to attend to. In the example above, we used random weights. In a trained transformer, those weights have been optimised so that queries and keys align when they should.
Consider processing "The cat sat on the mat because it was tired." When computing the representation for "it", the query vector for "it" should be similar to the key vectors for "cat" and "mat" (both plausible antecedents), but through training on many examples, the model learns that "it" in this context matches "cat" better.
This is self-attention: every position in a sequence attends to every other position in the same sequence. The "self" distinguishes it from cross-attention, where queries come from one sequence (like a translation) and keys/values come from another (like the source sentence).
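The scaled_dot_product_attention function above handles cross-attention unchanged; only the source of Q versus K and V differs (the maths goes through even when the two sequences have different lengths). Here is a minimal sketch with random stand-in projections; the shapes and names are illustrative, not taken from a real model.

# Cross-attention: queries come from the target sequence, keys/values from the source.
np.random.seed(1)
d_model, d_k = 4, 2
encoder_out = np.random.randn(1, 6, d_model)   # e.g. a 6-token source sentence
decoder_in  = np.random.randn(1, 3, d_model)   # e.g. a 3-token partial translation

W_q = np.random.randn(d_model, d_k) * 0.1      # random stand-ins for learned weights
W_k = np.random.randn(d_model, d_k) * 0.1
W_v = np.random.randn(d_model, d_k) * 0.1

Q = np.matmul(decoder_in, W_q)     # queries: what does each target token need?
K = np.matmul(encoder_out, W_k)    # keys: what does each source token offer?
V = np.matmul(encoder_out, W_v)

output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape, weights.shape)   # (1, 3, 2) and (1, 3, 6): 3 target positions,
                                     # each attending over 6 source positions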
Multi-Head Attention
A single attention head produces just one weighting pattern over the sequence for each position, so it tends to capture one kind of relationship at a time. But language is complex. When processing a word, you might need to attend to its syntactic role, its semantic meaning, and its position relative to other words, all simultaneously.
Multi-head attention runs several attention operations in parallel, each with different learned projections. Each "head" can learn to focus on different types of relationships.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
where head_i = Attention(Q·W_q^i, K·W_k^i, V·W_v^i)
class MultiHeadAttention:
def __init__(self, d_model, num_heads):
"""
Multi-head attention layer.
Args:
d_model: Model dimension (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
"""
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Initialize projection matrices for all heads at once
self.W_q = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_k = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_v = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
self.W_o = np.random.randn(d_model, d_model) * np.sqrt(2.0 / d_model)
def split_heads(self, x):
"""Reshape for multi-head computation."""
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(0, 2, 1, 3) # (batch, heads, seq, d_k)
def combine_heads(self, x):
"""Reverse the head split."""
batch_size, _, seq_len, _ = x.shape
x = x.transpose(0, 2, 1, 3)
return x.reshape(batch_size, seq_len, self.d_model)
def forward(self, X, mask=None):
"""
Forward pass of multi-head attention.
Args:
X: Input, shape (batch, seq_len, d_model)
mask: Optional attention mask
Returns:
Output of same shape as input
"""
# Linear projections
Q = np.matmul(X, self.W_q)
K = np.matmul(X, self.W_k)
V = np.matmul(X, self.W_v)
# Split into multiple heads
Q = self.split_heads(Q)
K = self.split_heads(K)
V = self.split_heads(V)
# Attention for each head
d_k = Q.shape[-1]
scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
if mask is not None:
scores = np.where(mask == 0, -1e9, scores)
attention_weights = softmax(scores, axis=-1)
attention_output = np.matmul(attention_weights, V)
# Combine heads and final projection
combined = self.combine_heads(attention_output)
output = np.matmul(combined, self.W_o)
return output, attention_weights
In practice, researchers have found that different heads do learn to specialise. Some heads focus on syntactic relationships (subject-verb agreement), others on semantic relationships (word meaning), and others on positional patterns (attending to nearby words).
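As a quick sanity check of the class above (with its random, untrained weights), we can push a small batch through it and inspect the per-head attention maps; in trained models, examining these per-head matrices is how that specialisation is observed. The sizes below are arbitrary illustrations.

np.random.seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=4)
X = np.random.randn(2, 5, 64)                 # batch of 2 sequences, 5 tokens each

output, weights = mha.forward(X)
print(output.shape)                 # (2, 5, 64): same shape as the input
print(weights.shape)                # (2, 4, 5, 5): one 5x5 attention map per head
print(np.round(weights[0, 0], 3))   # head 0's attention pattern for the first sequence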
Causal Masking for Language Models
When training a language model to predict the next word, there's a critical constraint: the model shouldn't be able to see future words. When predicting what comes after "The cat", the model can't look at "sat" because that's what it's trying to predict.
Causal masking enforces this constraint by setting attention weights to zero for all future positions. The mask is a lower triangular matrix of ones.
Causal (autoregressive) masking:

def create_causal_mask(seq_len):
"""
Create a causal mask for autoregressive attention.
Position i can only attend to positions <= i.
"""
mask = np.tril(np.ones((seq_len, seq_len)))
return mask
# Example for sequence length 5
mask = create_causal_mask(5)
print("Causal mask:")
print(mask)
# Output:
# [[1. 0. 0. 0. 0.] Position 0 can only see position 0
# [1. 1. 0. 0. 0.] Position 1 can see positions 0, 1
# [1. 1. 1. 0. 0.] Position 2 can see positions 0, 1, 2
# [1. 1. 1. 1. 0.] ...
# [1. 1. 1. 1. 1.]] Position 4 can see all positions
When this mask is applied during attention computation, masked positions receive a large negative score (effectively negative infinity), which the softmax turns into a probability of zero. The model literally cannot attend to future tokens.
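Putting the pieces together, here is a small demonstration with random Q, K, V (four tokens, values chosen arbitrarily) passed through the earlier attention function along with the causal mask: every weight above the diagonal comes out as zero.

np.random.seed(0)
seq_len, d_k = 4, 2
Q = np.random.randn(1, seq_len, d_k)
K = np.random.randn(1, seq_len, d_k)
V = np.random.randn(1, seq_len, d_k)

mask = create_causal_mask(seq_len)            # (4, 4) lower-triangular matrix of ones
output, weights = scaled_dot_product_attention(Q, K, V, mask=mask)
print(np.round(weights[0], 3))
# Row 0 puts all its weight on position 0, row 1 spreads over positions 0-1,
# and every entry above the diagonal is exactly 0: future tokens are invisible.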
Positional Encoding
There's a subtle problem with attention as we've described it: it's completely position-agnostic. If you shuffle the words in your input, the attention mechanism produces the same outputs (just reordered). But word order matters enormously in language.
The original transformer solved this by adding positional encodings to the input embeddings. These are vectors that encode the position of each token in the sequence.
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
def positional_encoding(max_len, d_model):
"""
Generate sinusoidal positional encodings.
Args:
max_len: Maximum sequence length
d_model: Model dimension
Returns:
Positional encoding matrix (max_len, d_model)
"""
position = np.arange(max_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe = np.zeros((max_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return pe
# Generate encodings for 100 positions, dimension 64
pe = positional_encoding(100, 64)
print(f"Positional encoding shape: {pe.shape}")
print(f"First position encoding (first 8 dims): {pe[0, :8].round(3)}")
print(f"Second position encoding (first 8 dims): {pe[1, :8].round(3)}")
The sinusoidal functions have a beautiful property: the encoding for position p+k can be expressed as a linear function of the encoding for position p. This allows the model to learn relative positions through linear operations, which attention can perform.
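As a small check of this property (an illustration, not code from the paper), each sin/cos pair at position p+k equals the pair at position p rotated by an angle that depends only on the offset k, i.e. a fixed linear map. Using the pe matrix generated above:

p, k, i = 10, 7, 4                          # position, offset, and a sin/cos pair index
w = 10000 ** (-2 * i / 64)                  # that pair's frequency (matches div_term above)
R = np.array([[ np.cos(k * w), np.sin(k * w)],
              [-np.sin(k * w), np.cos(k * w)]])       # rotation by the offset's angle
predicted = R @ pe[p, 2 * i:2 * i + 2]                # rotate the encoding at position p
print(np.allclose(predicted, pe[p + k, 2 * i:2 * i + 2]))   # True: PE(p+k) = R(k) · PE(p)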
Putting It All Together
A complete transformer encoder layer combines multi-head attention with feed-forward networks, layer normalisation, and residual connections. Here's how the pieces fit:
Complete transformer encoder layer:

def layer_norm(x, eps=1e-6):
"""Layer normalisation."""
mean = np.mean(x, axis=-1, keepdims=True)
std = np.std(x, axis=-1, keepdims=True)
return (x - mean) / (std + eps)
def feed_forward(x, W1, b1, W2, b2):
"""Position-wise feed-forward network."""
hidden = np.maximum(0, np.matmul(x, W1) + b1) # ReLU activation
return np.matmul(hidden, W2) + b2
class TransformerEncoderLayer:
def __init__(self, d_model, num_heads, d_ff):
"""
Single transformer encoder layer.
Args:
d_model: Model dimension (512)
num_heads: Number of attention heads (8)
d_ff: Feed-forward hidden dimension (2048)
"""
self.attention = MultiHeadAttention(d_model, num_heads)
# Feed-forward weights
self.W1 = np.random.randn(d_model, d_ff) * np.sqrt(2.0 / d_model)
self.b1 = np.zeros(d_ff)
self.W2 = np.random.randn(d_ff, d_model) * np.sqrt(2.0 / d_ff)
self.b2 = np.zeros(d_model)
def forward(self, x, mask=None):
"""Forward pass with residual connections and layer norm."""
# Multi-head attention with residual
attn_output, _ = self.attention.forward(x, mask)
x = layer_norm(x + attn_output)
# Feed-forward with residual
ff_output = feed_forward(x, self.W1, self.b1, self.W2, self.b2)
x = layer_norm(x + ff_output)
return x
# Create an encoder layer
encoder_layer = TransformerEncoderLayer(
d_model=512,
num_heads=8,
d_ff=2048
)
# Process a batch
batch_size, seq_len, d_model = 2, 10, 512
x = np.random.randn(batch_size, seq_len, d_model)
output = encoder_layer.forward(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
The residual connections (x + attn_output) are crucial for training deep transformers. They allow gradients to flow directly through the network, preventing the vanishing gradient problem. Layer normalisation stabilises training by keeping activations in a reasonable range.
The Computational Cost
Attention has a significant limitation: it scales quadratically with sequence length. For a sequence of n tokens, we compute an n×n attention matrix. Double your sequence length, and you quadruple your memory and compute requirements.
This is why context lengths were limited in early models. GPT-2's context of 1024 tokens required attention matrices of about 1 million elements per head per layer. More recent models with 100k+ token contexts use various optimisations: sparse attention patterns, linear attention approximations, or architectural changes like Mamba's state space approach.
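To put rough numbers on that growth, here is a back-of-the-envelope sketch assuming float32 scores for a single head in a single layer; real models multiply this by the number of heads and layers.

for n in [1_024, 8_192, 131_072]:
    elements = n * n                         # one score for every query-key pair
    megabytes = elements * 4 / 1e6           # 4 bytes per float32 score
    print(f"seq_len={n:>7,}  attention matrix: {elements:>14,} scores  ~{megabytes:>9,.1f} MB")
# 1,024 tokens -> ~1 million scores (~4 MB); 131,072 tokens -> ~17 billion scores (~69 GB)
# per head per layer, before counting gradients - hence the push towards cheaper alternatives.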
Understanding this quadratic cost helps you understand both the limitations of current models and the research directions that might overcome them.
The Deeper Meaning
We've covered the mathematics, but I want to end with something more philosophical. Attention isn't just a clever computational trick. It represents a fundamentally different way of processing information.
Traditional neural networks process information through fixed computational graphs. The same operations happen regardless of the input. Attention makes the computation dynamic: what gets computed depends on what's being processed. The model learns not just what to compute, but what to pay attention to.
This is closer to how we might imagine cognition working. When you read a sentence, you don't process each word in isolation. You're constantly relating words to each other, building a web of connections that gives meaning to the whole. Attention is a mathematical approximation of this relational processing.
Whether it's sufficient for true understanding is another question entirely. But as a mechanism for learning patterns in sequential data, attention has proven remarkably powerful. The mathematics are simple. The implications are profound.