Attention Is All You Need - Paper Analysis
Overview
The "Attention Is All You Need" paper introduced the Transformer architecture, fundamentally changing how we approach sequence modeling in machine learning. Before transformers, recurrent neural networks (RNNs) and LSTMs dominated NLP tasks.
The Problem: RNNs process sequences sequentially, which makes parallel processing impossible and creates difficulties in learning long-range dependencies. Information has to flow through many intermediate states, leading to vanishing gradients.
The Solution: Replace recurrence entirely with attention mechanisms. This allows the model to directly connect any two positions in a sequence, regardless of distance, and enables massive parallelization during training.
Self-Attention Mechanism
Self-attention is the core innovation. It allows each position in a sequence to attend to all positions in the same sequence, computing relationships between all pairs of words simultaneously.
Query, Key, Value (Q, K, V)
The attention mechanism uses three learned linear projections of the input:
- Query (Q): What am I looking for? Represents the current position asking questions.
- Key (K): What do I contain? Represents what each position offers.
- Value (V): What information do I hold? The actual content to be aggregated.
Think of it like a database: you query (Q) for information, match against keys (K) in the database, and retrieve the corresponding values (V). The difference here is that everything happens in a continuous, differentiable way.
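As a rough sketch of how these projections might look in code (the dimensions follow the paper's base configuration of d_model = 512 with d_k = 64 per head; the variable names are just for illustration):
import torch
import torch.nn as nn

d_model, d_k = 512, 64            # model width and per-head query/key width (paper's base config)
W_q = nn.Linear(d_model, d_k)     # learned query projection
W_k = nn.Linear(d_model, d_k)     # learned key projection
W_v = nn.Linear(d_model, d_k)     # learned value projection
x = torch.randn(2, 10, d_model)   # (batch, seq_len, d_model) token embeddings
Q, K, V = W_q(x), W_k(x), W_v(x)  # each: (batch, seq_len, d_k)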
Attention Formula
The scaled dot-product attention is computed as:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Where:
- Q * K^T computes similarity scores between all query-key pairs
- sqrt(d_k) scales the scores (d_k is the key dimension)
- softmax converts scores to probabilities
- Multiply by V to get weighted sum of values
Here's how it works step by step:
- Compute attention scores: Multiply queries by keys (Q * K^T) to get compatibility scores
- Scale: Divide by sqrt(d_k) to keep the dot products from growing with the key dimension, which would push the softmax into regions with extremely small gradients
- Apply softmax: Convert scores to probabilities that sum to 1
- Weighted sum: Multiply probabilities by values and sum
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)
    # Step 1 & 2: Compute scaled compatibility scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Optional: Apply mask (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Step 3: Apply softmax over the key dimension
    attention_weights = F.softmax(scores, dim=-1)
    # Step 4: Weighted sum of values
    output = torch.matmul(attention_weights, V)
    return output, attention_weights
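A quick usage check on random tensors (the shapes are illustrative only):
Q = torch.randn(2, 8, 10, 64)   # batch of 2, 8 heads, 10 tokens, d_k = 64
K = torch.randn(2, 8, 10, 64)
V = torch.randn(2, 8, 10, 64)
output, weights = scaled_dot_product_attention(Q, K, V)
print(output.shape)    # torch.Size([2, 8, 10, 64])
print(weights.shape)   # torch.Size([2, 8, 10, 10]), one row of probabilities per query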
Multi-Head Attention
Instead of performing a single attention function, the transformer uses multiple attention "heads" in parallel. Each head learns different aspects of the relationships between words.
Why multiple heads? Different heads can focus on different types of relationships:
- One head might focus on syntactic relationships (subject-verb agreement)
- Another might capture semantic similarity (synonyms, related concepts)
- Others might learn positional patterns or long-range dependencies
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
Where each head_i = Attention(Q*W_i^Q, K*W_i^K, V*W_i^V)
- h is the number of heads (typically 8 or 16)
- W_i^Q, W_i^K, W_i^V are learned projection matrices for each head
- W^O is the output projection matrix
The key insight is that by splitting the model dimension into multiple heads, we don't increase computational cost significantly (since each head operates on a smaller dimension), but we gain the ability to attend to information from different representation subspaces.
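A minimal multi-head attention sketch built on the scaled_dot_product_attention function above. It uses the common implementation trick of fusing the per-head projections into single d_model-to-d_model linear layers and then splitting into heads; the class and method names are just for this sketch.
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads           # each head works in a smaller subspace
        # Fused projections for all heads at once (equivalent to per-head W_i^Q, W_i^K, W_i^V)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)    # output projection W^O

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        out, _ = scaled_dot_product_attention(Q, K, V, mask)   # defined earlier
        # (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model), then apply W^O
        batch, _, seq_len, _ = out.shape
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.W_o(out)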
Positional Encoding
Here's a critical challenge: Unlike RNNs that naturally process sequences in order, the transformer processes all positions simultaneously. This means it has no inherent notion of position or order.
The Solution: Add positional encodings to the input embeddings. These are fixed functions (not learned) that inject information about the relative or absolute position of tokens.
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where:
- pos is the position in the sequence
- i is the dimension
- d_model is the model dimension
Why sine and cosine functions? They have several nice properties:
- Different frequencies for different dimensions
- Values are bounded between -1 and 1
- The model can learn to attend by relative positions, since for any fixed offset k, PE(pos + k) is a linear function of PE(pos)
- Can extrapolate to longer sequences than seen during training
If you visualize positional encodings as a heatmap (positions vs dimensions), you'll see beautiful wave patterns. Lower dimensions have longer wavelengths (slower oscillations), while higher dimensions oscillate more rapidly. This creates a unique "fingerprint" for each position.
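A minimal sketch of generating these encodings, assuming an even d_model; it follows the formulas above directly, and the function name is just for illustration:
import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)        # (max_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))                    # 1 / 10000^(2i/d_model)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe                                                                  # (max_len, d_model)

# The encodings are simply added to the token embeddings before the first layer:
# x = token_embeddings + sinusoidal_positional_encoding(seq_len, d_model).unsqueeze(0)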
The Transformer Architecture
The full transformer consists of an encoder and a decoder, each built from a stack of identical layers (six per stack in the original paper).
Encoder
Each encoder layer has two sub-layers:
- Multi-head self-attention: Each position attends to all positions in the input
- Position-wise feed-forward network: Two linear transformations with a ReLU activation in between
Both sub-layers use residual connections and layer normalization:
output = LayerNorm(x + Sublayer(x))
Where Sublayer is either attention or feed-forward
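Putting the pieces together, here is a sketch of one encoder layer using the multi-head attention module sketched earlier (d_ff = 2048 and dropout = 0.1 follow the paper's base configuration; the class name is just for illustration). It applies layer normalization after the residual addition, matching the formula above:
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)   # sketched above
        self.ffn = nn.Sequential(                                  # position-wise feed-forward network
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: self-attention, then residual connection and layer norm
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
        # Sub-layer 2: feed-forward, then residual connection and layer norm
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x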
Decoder
Each decoder layer has three sub-layers:
- Masked multi-head self-attention: Each position attends only to earlier positions, keeping generation autoregressive (see the mask sketch after this list)
- Multi-head cross-attention: Queries come from the decoder; keys and values come from the encoder output
- Position-wise feed-forward network: Same as in the encoder
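The masked self-attention is what keeps the decoder autoregressive. A minimal sketch of the causal mask it needs, compatible with the mask argument of the attention function defined earlier (the helper name is hypothetical):
import torch

def causal_mask(seq_len):
    # Lower-triangular matrix of ones: position i may attend to positions <= i only.
    # Shaped (1, 1, seq_len, seq_len) so it broadcasts over batch and heads;
    # zeros are replaced with -1e9 inside scaled_dot_product_attention.
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

print(causal_mask(4)[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])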
Impact and Legacy
It's hard to overstate the impact of this paper. The transformer architecture has become the foundation for virtually all modern NLP systems and has expanded far beyond its original domain.
Direct Descendants
- BERT (2018): Encoder-only transformer for understanding tasks
- GPT series (2018-2024): Decoder-only transformers for generation
- T5 (2019): Text-to-text framework using full encoder-decoder
- BART (2019): Denoising autoencoder using full transformer
Beyond NLP
The transformer architecture has been successfully adapted to:
- Computer Vision: Vision Transformers (ViT), CLIP, DALL-E
- Speech: Wav2Vec, Whisper
- Protein Folding: AlphaFold 2
- Reinforcement Learning: Decision Transformer
- Multi-modal AI: GPT-4, Gemini
"The transformer didn't just improve NLP—it provided a universal architecture for learning from sequences, images, and beyond. It showed that attention is indeed all you need."
Key Innovations That Lasted
- Parallelization: Training is massively faster than with RNNs, since all positions are processed in parallel
- Long-range dependencies: Direct connections between any positions
- Transfer learning: Pre-trained transformers transfer exceptionally well
- Scalability: Performance improves with model size (scaling laws)