Attention Mechanism: The Eyes of Neural Networks
Attention allows neural networks to focus on the relevant parts of the input dynamically. From machine translation to LLMs and multimodal AI: a complete mathematical and practical reference for additive, multiplicative, self-, and multi-head attention.
- Bahdanau: Additive (2014)
- Luong: Multiplicative (2015)
- Self-Attention: Q, K, V (2017)
- Multi-Head: Parallel attention heads
What is Attention?
Attention is a neural component that dynamically computes a weighted sum of values, where weights depend on the similarity between a query and corresponding keys. It allows models to focus on specific parts of the input when producing each output element — mimicking visual attention.
Query · Keys → scores → softmax → Attention Weights → weighted sum over Values (V) → Context Vector
Core idea: Not all input elements are equally important. Learn to assign importance dynamically.
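A minimal numeric sketch of this weighted-sum view, with toy hand-picked vectors (not from any trained model): score each key against the query, softmax the scores, then average the values with those weights.

```python
import torch

# One query and three key/value pairs (toy 2-d vectors)
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],    # very similar to the query
                     [0.0, 1.0],    # orthogonal to the query
                     [0.5, 0.5]])
values = torch.tensor([[10.0], [20.0], [30.0]])

scores = keys @ query                    # dot-product similarity: [1.0, 0.0, 0.5]
weights = torch.softmax(scores, dim=0)   # importance of each value; sums to 1
context = weights @ values               # weighted sum -> context vector
print(weights)    # highest weight on the first key
print(context)    # pulled toward the first value
```

The context vector is dominated by the value whose key most resembles the query, which is exactly the "focus on specific parts of the input" behavior described above.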
The Alignment Problem: Why Attention?
Seq2Seq without Attention
Encoder compresses entire source into one fixed-size vector → information bottleneck. Long sentences degrade rapidly.
"I love cats" → fixed vector (5-dim) → "Je ___ ?"
Seq2Seq with Attention
Decoder looks at all encoder states, weights them dynamically. Solves bottleneck, improves long-range translation.
Alignment: "cat" ↔ "chat" at step 3
Bahdanau Attention (Additive)
Additive Attention Score
eᵢⱼ = vₐᵀ tanh(Wₐ[sᵢ₋₁; hⱼ])   (also called the "concat" score)
αᵢⱼ = exp(eᵢⱼ) / Σₖ exp(eᵢₖ)   (softmax over source positions)
Context vector: cᵢ = Σⱼ αᵢⱼ hⱼ
Historical Significance
First attention mechanism for NLP. Used in RNN encoder-decoders. Computationally expensive (fully connected layer per alignment).
Key ingredients: bidirectional RNN encoder, concatenation of decoder and encoder states, tanh nonlinearity.
import torch
import torch.nn as nn

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W_a = nn.Linear(hidden_dim * 2, hidden_dim)   # projects [s; h]
        self.v_a = nn.Linear(hidden_dim, 1, bias=False)    # implements v_a^T

    def forward(self, query, encoder_outputs):
        # query: decoder hidden state (batch, hidden)
        # encoder_outputs: (batch, seq_len, hidden)
        seq_len = encoder_outputs.size(1)
        query = query.unsqueeze(1).repeat(1, seq_len, 1)   # (batch, seq_len, hidden)
        # Additive score: v_a^T tanh(W_a [s; h])
        energy = torch.tanh(self.W_a(torch.cat((query, encoder_outputs), dim=2)))
        scores = self.v_a(energy).squeeze(2)               # (batch, seq_len)
        attn_weights = torch.softmax(scores, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attn_weights
Luong Attention (Multiplicative)
Scoring Functions
Dot: score = sᵀ h
General: score = sᵀ W h
Concat: score = vᵀ tanh(W[s; h])
Key Differences
Luong computes attention from the current decoder state, after the RNN step (Bahdanau uses the previous state, before it). Simpler and faster; uses only the top-layer state.
Types: global (all source steps) vs local (window).
import torch

def luong_dot_attention(query, encoder_outputs):
    # query: current decoder state (batch, 1, hidden)
    # encoder_outputs: (batch, seq_len, hidden)
    scores = torch.bmm(query, encoder_outputs.transpose(1, 2))  # (batch, 1, seq_len)
    attn_weights = torch.softmax(scores, dim=2)
    context = torch.bmm(attn_weights, encoder_outputs)          # (batch, 1, hidden)
    return context, attn_weights
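The "general" score from the list above differs from the dot score only by a learned matrix W between decoder and encoder states. A minimal sketch (the module name is ours):

```python
import torch
import torch.nn as nn

class LuongGeneralAttention(nn.Module):
    """score(s, h) = s^T W h  -- the 'general' multiplicative score."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, query, encoder_outputs):
        # query: (batch, 1, hidden); encoder_outputs: (batch, seq_len, hidden)
        scores = torch.bmm(query, self.W(encoder_outputs).transpose(1, 2))
        attn_weights = torch.softmax(scores, dim=2)          # (batch, 1, seq_len)
        context = torch.bmm(attn_weights, encoder_outputs)   # (batch, 1, hidden)
        return context, attn_weights
```

With W = I this reduces to the dot score, so "general" strictly generalizes it at the cost of one extra matrix.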
Scaled Dot-Product Attention
The Transformer Formula
Attention(Q, K, V) = softmax(QKᵀ / √dₖ) V
Q, K, V: queries, keys, values matrices.
√dₖ: scaling factor prevents softmax saturation.
Why Scaling?
If the components of q and k are independent with zero mean and unit variance, their dot product q·k has variance dₖ. For large dₖ the scores therefore grow large in magnitude, pushing softmax into regions of vanishing gradients. Dividing by √dₖ restores unit variance and fixes this.
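The variance claim is easy to verify empirically; a toy experiment with standard-normal queries and keys:

```python
import torch

torch.manual_seed(0)
for d_k in (4, 64, 1024):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=1)        # 10000 independent dot products
    scaled = dots / d_k ** 0.5
    # Unscaled variance grows like d_k; scaled variance stays near 1
    print(d_k, dots.var().item(), scaled.var().item())
```

The unscaled variance tracks dₖ almost exactly, which is why unscaled scores of dimension 512 or 1024 would saturate the softmax.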
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    # mask: bool tensor, True at positions to block (e.g. future tokens)
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    attention_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)
    return output, attention_weights
Multi-Head Attention
Instead of one attention function, project Q, K, V h times with different linear projections, perform attention in parallel, concatenate, and project.
Diverse Representations
Each head learns different relationships: syntactic, semantic, coreference, positional.
MultiHead(Q, K, V) = Concat(head₁, ..., headₕ) Wᴼ
where headᵢ = Attention(Q Wᵢ^Q, K Wᵢ^K, V Wᵢ^V)
Typical Values
h = 8, 12, 16, 32. dₖ = d_v = d_model / h.
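Putting the formula and the typical values together, a minimal sketch of a multi-head layer (no dropout or KV caching; bias choices are ours, not those of any specific library):

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads          # d_k = d_model / h
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)   # all heads' projections, fused
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)   # output projection W^O

    def forward(self, query, key, value, mask=None):
        batch = query.size(0)

        def split(x, proj):
            # Project, then split d_model into (num_heads, d_k)
            return proj(x).view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

        Q, K, V = split(query, self.W_q), split(key, self.W_k), split(value, self.W_v)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        out = weights @ V                        # (batch, heads, seq, d_k)
        # Concatenate heads, then apply W^O
        out = out.transpose(1, 2).contiguous().view(batch, -1, self.num_heads * self.d_k)
        return self.W_o(out)
```

The per-head projections are fused into single d_model × d_model matrices and split by reshaping, the standard trick that makes the h heads a single batched matmul.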
Attention Variants: Self, Cross, Causal
Self-Attention
Q, K, V from same sequence. Each token attends to all tokens in the same sequence. Captures intra-sequence dependencies.
Encoders BERT
Cross-Attention
Q from decoder, K, V from encoder. Decoder attends to input sequence. Essential for seq2seq.
T5, BART
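Mechanically, the only change from self-attention is that the queries can come from a sequence of a different length than the keys and values. A quick shape check with toy tensors:

```python
import torch

torch.manual_seed(0)
decoder_states = torch.randn(2, 3, 8)   # Q: 3 decoder positions
encoder_states = torch.randn(2, 7, 8)   # K, V: 7 encoder positions

scores = decoder_states @ encoder_states.transpose(1, 2) / 8 ** 0.5
weights = torch.softmax(scores, dim=-1)   # (2, 3, 7): each decoder step
context = weights @ encoder_states        # distributes over 7 source positions
print(weights.shape, context.shape)       # output keeps the decoder length
```

Each of the 3 decoder positions gets its own distribution over the 7 encoder positions, and the context keeps the decoder's sequence length.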
Causal (Masked) Attention
Prevents attending to future tokens. Upper triangular mask set to -∞. Used in autoregressive decoders.
GPT, Llama
import torch

def causal_mask(size):
    """Boolean mask: True strictly above the diagonal (future positions)."""
    mask = torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)
    return mask  # True where attention must be blocked
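Combined with scaled dot-product attention, the mask forces every weight above the diagonal to zero, so position i only averages over positions ≤ i. A self-contained demo (the helper repeats the logic of the functions above):

```python
import math
import torch

def causal_attention(x):
    """Single-head causal self-attention over x: (batch, seq, dim)."""
    size = x.size(1)
    mask = torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ x, weights

torch.manual_seed(0)
_, w = causal_attention(torch.randn(1, 4, 8))
print(w[0])  # lower-triangular: zero weight on future positions
```

Note that the first token can only attend to itself, so its row is exactly [1, 0, 0, 0]; this is why autoregressive decoding never leaks information from the future.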
Visualizing Attention Weights
Alignment Matrix
Plot attention weights as heatmap. Rows = decoder steps, Cols = encoder steps. Reveals word alignment.
[0.1, 0.8, 0.1]
[0.1, 0.1, 0.8]
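Even without a plotting library, the alignment can be read off the matrix: the argmax of each row is the source position the decoder focused on at that step (toy weights and a toy source sentence, for illustration only):

```python
import torch

# Rows = decoder steps, columns = encoder steps (toy alignment matrix)
attn = torch.tensor([[0.8, 0.1, 0.1],
                     [0.1, 0.8, 0.1],
                     [0.1, 0.1, 0.8]])
src = ["the", "black", "cat"]
for step, row in enumerate(attn):
    aligned = src[row.argmax().item()]
    print(f"decoder step {step}: attends mostly to '{aligned}'")
```

A near-diagonal matrix like this one indicates monotonic word order; off-diagonal mass reveals reordering between source and target languages.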
Probing Attention Heads
Certain heads specialize: positional heads attend to previous/next token, syntactic heads attend to dependent tokens, rare word heads.
Attention Beyond NLP
Vision
Spatial attention: Attend to relevant image regions. ViT uses self-attention on patches. Cross-attention in image captioning.
Audio
Speech recognition: Attend to acoustic frames. Listen, Attend and Spell (LAS).
Video
Temporal attention: Focus on relevant frames. Video transformers.
Multimodal
CLIP, Flamingo, LLaVA: cross-attention between image and text.
Graphs
Graph Attention Networks (GAT): attend to neighbor nodes.
Reinforcement Learning
Attend to relevant observations in memory.
Attention Types – Cheatsheet
Attention Mechanism Comparison
| Attention Type | Score Function | Complexity | Typical Use |
|---|---|---|---|
| Bahdanau (Additive) | vₐᵀ tanh(W[s; h]) | O(n·d²) | RNN seq2seq |
| Luong (Dot) | sᵀ h | O(n·d) | RNN, efficient |
| Scaled Dot-Product | QKᵀ/√d | O(n²·d) | Transformers |
| Multi-Head | h × scaled dot | O(n²·d·h) | BERT, GPT |
| Graph Attention | LeakyReLU(aᵀ[Whᵢ; Whⱼ]) | O(E·d) | Graph networks |