Transformers

Self-attention from first principles, scaled dot-product, multi-head, positional encodings, the full block and its training math.

The Transformer (Vaswani et al., 2017, "Attention Is All You Need") replaced recurrence with attention, enabling full parallelism across positions and direct long-range connections. It is the foundation of BERT, GPT, T5, and virtually every modern LLM. We build it from the single most important idea: self-attention.


4.1 Tokenization & embeddings (the input)

Before any attention, text becomes vectors:

  1. Tokenize: split text into tokens (subwords) via Byte-Pair Encoding (BPE) / WordPiece / SentencePiece. E.g. "unhappiness" → ["un", "happiness"] or ["un", "happi", "ness"]. Each token maps to an integer id.
  2. Embed: an embedding matrix $\mathbf{E}\in\mathbb{R}^{V\times d}$ ($V$ = vocab size, $d$ = model dim) maps each id to a $d$-dim vector via lookup. These are learned.
  3. Add positional information (next section) because attention itself is order-agnostic.

Result: input sequence of $n$ tokens → matrix $\mathbf{X}\in\mathbb{R}^{n\times d}$.
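The three steps above can be sketched in plain Python (standard library only). Everything here is a toy assumption: the five-entry vocabulary is hypothetical, the greedy longest-match split only imitates what a trained BPE/WordPiece tokenizer would produce, and the embedding rows are random stand-ins for learned weights.

```python
import random

# Hypothetical toy vocabulary (real vocabularies are learned by BPE/WordPiece).
vocab = {"un": 0, "happi": 1, "ness": 2, "happy": 3, "<unk>": 4}

def tokenize(word, vocab):
    """Greedy longest-match split of one word into subword ids."""
    ids, start = [], 0
    while start < len(word):
        for end in range(len(word), start, -1):   # try the longest piece first
            piece = word[start:end]
            if piece in vocab:
                ids.append(vocab[piece])
                start = end
                break
        else:                                      # no piece matched this character
            ids.append(vocab["<unk>"])
            start += 1
    return ids

d = 4                                              # model dimension (tiny, for display)
rng = random.Random(0)
E = [[rng.gauss(0, 0.02) for _ in range(d)] for _ in vocab]   # V x d table

ids = tokenize("unhappiness", vocab)               # ["un", "happi", "ness"] as ids
X = [E[i] for i in ids]                            # n x d input matrix (row lookup)
```

Embedding is a pure row lookup: no matrix multiply is needed, which is why the cost of this step does not depend on $V$.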


4.2 Positional encoding

Attention treats the input as a set — it has no inherent notion of order. We must inject position.

Sinusoidal (original Transformer)

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

  • $pos$ = position index, $i$ = index of the sine/cosine dimension pair ($0 \le i < d/2$).
  • Different dimensions oscillate at different frequencies (wavelengths from $2\pi$ to $\sim 10000\cdot 2\pi$).
  • Key property: for a fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$ (a rotation of each sine/cosine pair), so the model can learn to attend by relative position. The original paper also hypothesized that this lets the model extrapolate to lengths unseen in training; in practice such extrapolation tends to be limited.

We add it: $\mathbf{X} \leftarrow \mathbf{X} + \mathbf{PE}$.
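A minimal sketch of the sinusoidal table and of the "linear function of position" property, in plain Python. The helper names `sinusoidal_pe` and `shift` are illustrative, not from any library, and $d$ is assumed even.

```python
import math

def sinusoidal_pe(n, d):
    """Return an n x d table of sinusoidal encodings (d assumed even)."""
    pe = [[0.0] * d for _ in range(n)]
    for pos in range(n):
        for i in range(d // 2):
            angle = pos / (10000 ** (2 * i / d))
            pe[pos][2 * i] = math.sin(angle)       # even dimension
            pe[pos][2 * i + 1] = math.cos(angle)   # odd dimension
    return pe

def shift(pe_row, k, d):
    """Rotate each (sin, cos) pair by k steps: maps PE[pos] to PE[pos + k]."""
    out = []
    for i in range(d // 2):
        w = 1.0 / (10000 ** (2 * i / d))           # angular frequency of pair i
        s, c = pe_row[2 * i], pe_row[2 * i + 1]
        out.append(s * math.cos(k * w) + c * math.sin(k * w))   # sin(a + b)
        out.append(c * math.cos(k * w) - s * math.sin(k * w))   # cos(a + b)
    return out

pe = sinusoidal_pe(n=16, d=8)
shifted = shift(pe[3], k=5, d=8)                   # should match pe[8] up to rounding
max_err = max(abs(a - b) for a, b in zip(shifted, pe[8]))
```

The rotation in `shift` depends only on the offset `k`, not on the starting position, which is the sense in which relative position is linearly recoverable.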

Modern variants

  • Learned absolute positions (BERT, GPT-2): a trainable position embedding table.
  • RoPE (Rotary Position Embedding): rotates query/key vectors by an angle proportional to position → encodes relative position directly in the dot product. Used in LLaMA, GPT-NeoX.
  • ALiBi: adds a distance-based linear bias to attention scores; great length extrapolation.
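To make the RoPE bullet concrete, here is a small sketch showing that after rotation the query-key dot product depends only on the position offset. It assumes the adjacent-pair convention (dimensions $2i$ and $2i+1$ rotated together) and base 10000; real implementations may pair dimensions differently, and the vectors `q` and `k` are arbitrary made-up values.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate consecutive pairs of vec by angles proportional to pos."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)         # angle for pair i at this position
        x, y = vec[2 * i], vec[2 * i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5]      # arbitrary query vector
k = [1.0, 0.4, -0.7, 0.9]      # arbitrary key vector

score_a = dot(rope(q, 5), rope(k, 3))      # query at 5, key at 3 (offset 2)
score_b = dot(rope(q, 12), rope(k, 10))    # same offset, both shifted by 7
score_c = dot(rope(q, 5), rope(k, 4))      # different offset
```

`score_a` and `score_b` agree up to floating-point rounding, while `score_c` differs: the absolute positions cancel and only the offset survives.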

4.3 Self-attention — the heart of everything

Intuition

For each token, ask: "which other tokens are relevant to me, and how much?" Then build that token's new representation as a weighted blend of all tokens' values. "Attention" = these learned relevance weights.

Analogy — a soft dictionary lookup (the mechanics of how a word "asks around"):

  • Query (Q): what I'm looking for. ("I'm the word 'it', I'm looking for a noun I might refer to.")
  • Key (K): what each token advertises about itself. ("I'm 'animal', a noun, a subject.")
  • Value (V): the actual content a token hands over if matched.

You match your query against all keys to get relevance weights (a good Query–Key match → high weight), then take a weighted blend of the values. It's like searching a library: your question (query) is compared to each book's title/index (key), and you walk away with a mix of the contents (values) of the best-matching books: not just one book, but a blend weighted by relevance.

Math — Scaled Dot-Product Attention

From input $\mathbf{X}\in\mathbb{R}^{n\times d}$, project into queries, keys, values with learned matrices $\mathbf{W}^Q,\mathbf{W}^K\in\mathbb{R}^{d\times d_k}$, $\mathbf{W}^V\in\mathbb{R}^{d\times d_v}$:

$$\mathbf{Q} = \mathbf{X}\mathbf{W}^Q, \qquad \mathbf{K} = \mathbf{X}\mathbf{W}^K, \qquad \mathbf{V} = \mathbf{X}\mathbf{W}^V$$

Then:

$$\boxed{\;\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}\;}$$

Step by step:

  1. Scores $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{n\times n}$. Entry $S_{ij} = \mathbf{q}_i\cdot\mathbf{k}_j$ = how much token $i$ attends to token $j$ (dot product = similarity).
  2. Scale: divide by $\sqrt{d_k}$. Why: if $\mathbf{q},\mathbf{k}$ have independent zero-mean, unit-variance entries, $\mathbf{q}\cdot\mathbf{k}$ has variance $d_k$. Large $d_k$ → large scores → softmax saturates → tiny gradients. Dividing by $\sqrt{d_k}$ restores unit variance.
  3. Softmax over each row → attention weights $\mathbf{A}\in\mathbb{R}^{n\times n}$, each row sums to 1.
  4. Weighted sum of values: output $\mathbf{Z} = \mathbf{A}\mathbf{V}\in\mathbb{R}^{n\times d_v}$. Row $i$ = token $i$'s new, context-aware representation.
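The variance argument in step 2 can be checked empirically. This is a rough Monte Carlo sketch under the same assumption as the text (independent standard-normal entries); the trial count and seed are arbitrary choices.

```python
import random

def dot_variance(d_k, scale, trials=2000, seed=0):
    """Sample variance of (q . k) / scale for standard-normal q and k."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        samples.append(sum(a * b for a, b in zip(q, k)) / scale)
    mean = sum(samples) / trials
    return sum((s - mean) ** 2 for s in samples) / trials

d_k = 64
raw = dot_variance(d_k, scale=1.0)              # expected to be close to d_k
scaled = dot_variance(d_k, scale=d_k ** 0.5)    # expected to be close to 1
```

With unscaled scores of standard deviation around 8, a softmax row is dominated by one or two entries; after scaling the scores sit in a range where the softmax still has useful gradients.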

Fully worked numeric example (do this once by hand!)

Two tokens, $d_k=2$. Suppose after projection:

$$\mathbf{Q}=\begin{bmatrix}1 & 0\\ 0 & 1\end{bmatrix},\quad \mathbf{K}=\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix},\quad \mathbf{V}=\begin{bmatrix}10 & 0\\ 0 & 10\end{bmatrix}$$

Scores $\mathbf{Q}\mathbf{K}^\top$:

$$\mathbf{S}=\begin{bmatrix}\mathbf{q}_1\cdot\mathbf{k}_1 & \mathbf{q}_1\cdot\mathbf{k}_2\\ \mathbf{q}_2\cdot\mathbf{k}_1 & \mathbf{q}_2\cdot\mathbf{k}_2\end{bmatrix} =\begin{bmatrix}1 & 1\\ 0 & 1\end{bmatrix}$$

Divide by $\sqrt{2}\approx 1.414$: $\begin{bmatrix}0.707 & 0.707\\ 0 & 0.707\end{bmatrix}$. Softmax rows:

  • Row 1: $\text{softmax}(0.707, 0.707) = (0.5, 0.5)$.
  • Row 2: $\text{softmax}(0, 0.707)$. $e^0=1,\ e^{0.707}=2.028$, sum $=3.028$ → $(0.330, 0.670)$.

$$\mathbf{A}=\begin{bmatrix}0.5 & 0.5\\ 0.33 & 0.67\end{bmatrix}$$

Output $\mathbf{Z}=\mathbf{A}\mathbf{V}$:

  • $\mathbf{z}_1 = 0.5[10,0] + 0.5[0,10] = [5, 5]$.
  • $\mathbf{z}_2 = 0.33[10,0] + 0.67[0,10] = [3.3, 6.7]$.

Token 1 blended both values equally; token 2 leaned toward value 2. This is attention doing its job: mixing information across positions based on learned similarity.
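The hand calculation can be reproduced with a few lines of plain Python. This is a teaching sketch with list-of-lists matrices; the helper names are illustrative and nothing here is optimized.

```python
import math

def matmul(A, B):
    """Matrix product of two list-of-lists matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def softmax(row):
    m = max(row)                              # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention; returns (output Z, weights A)."""
    d_k = len(K[0])
    S = matmul(Q, transpose(K))                                     # scores, n x n
    A = [softmax([s / math.sqrt(d_k) for s in row]) for row in S]   # row-wise softmax
    return matmul(A, V), A

Q = [[1, 0], [0, 1]]
K = [[1, 0], [1, 1]]
V = [[10, 0], [0, 10]]
Z, A = attention(Q, K, V)    # Z[0] is about [5, 5], Z[1] about [3.3, 6.7]
```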


4.4 Multi-Head Attention (MHA)

One attention "head" learns one kind of relationship. We want many in parallel (syntax, coreference, etc.). Split $d$ into $h$ heads of dim $d_k=d/h$:

$$\text{head}_i = \text{Attention}(\mathbf{X}\mathbf{W}^Q_i,\ \mathbf{X}\mathbf{W}^K_i,\ \mathbf{X}\mathbf{W}^V_i)$$

$$\text{MHA}(\mathbf{X}) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\,\mathbf{W}^O$$

where $\mathbf{W}^O\in\mathbb{R}^{d\times d}$ recombines them. Each head attends in a different learned subspace; concatenation + projection fuses their findings. Cost is the same as one full-dim attention because each head is $1/h$ as wide.

Example: $d=512, h=8 \Rightarrow d_k=d_v=64$ per head.
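The "same cost" claim can be checked by simple counting. The sketch below assumes $d_k = d_v = d/h$ and ignores biases; the function names are made up for this illustration.

```python
def attention_param_count(d, h):
    """Weights in the Q/K/V projections plus W^O (biases ignored)."""
    d_k = d // h
    per_head = 3 * d * d_k          # W^Q_i, W^K_i, W^V_i, each d x d_k
    return h * per_head + d * d     # all heads, plus W^O (d x d)

def score_flops(n, d, h):
    """Multiply-adds needed to form every head's n x n score matrix."""
    d_k = d // h
    return h * n * n * d_k

single = attention_param_count(512, 1)   # one full-width head
multi = attention_param_count(512, 8)    # eight heads of width 64
```

Both counts come out equal, and `score_flops` is likewise independent of `h`. What does grow with `h` is the number of $n\times n$ score matrices held in memory.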


4.5 The complete Transformer block

A block stacks attention + a feed-forward network, each wrapped with residual connections (from [[02_cnns]]) and LayerNorm (from [[01_deep_learning_foundations]]).

Position-wise Feed-Forward Network (FFN)

Applied independently to each position:

$$\text{FFN}(\mathbf{x}) = \mathbf{W}_2\,\phi(\mathbf{W}_1\mathbf{x} + \mathbf{b}_1) + \mathbf{b}_2$$

with $\mathbf{W}_1\in\mathbb{R}^{d_{ff}\times d}$, $\mathbf{W}_2\in\mathbb{R}^{d\times d_{ff}}$, typically $d_{ff}=4d$, $\phi=$ GELU/ReLU. This is where much of the model's "knowledge" and per-token nonlinear processing lives. Attention mixes across tokens; FFN processes each token.
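A tiny plain-Python sketch of "applied independently to each position". The weights are hand-picked hypothetical values (with $d=2$, $d_{ff}=4$, ReLU) chosen so the result is easy to verify: this particular FFN computes the element-wise absolute value.

```python
def relu(v):
    return [max(0.0, t) for t in v]

def linear(W, b, x):
    """Compute W x + b for a weight matrix W whose rows are output units."""
    return [sum(w * t for w, t in zip(row, x)) + bi for row, bi in zip(W, b)]

def ffn(x, W1, b1, W2, b2):
    return linear(W2, b2, relu(linear(W1, b1, x)))

# Hypothetical hand-picked weights; real models learn these.
W1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]   # d_ff x d
b1 = [0.0, 0.0, 0.0, 0.0]
W2 = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]         # d x d_ff
b2 = [0.0, 0.0]

X = [[1.0, -2.0], [3.0, 0.5]]                 # n = 2 token vectors
out = [ffn(x, W1, b1, W2, b2) for x in X]     # same weights at every position
```

Each row of `X` is processed on its own, so changing one token's vector cannot affect another token's FFN output; only attention moves information between positions.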

Post-LN (original) vs Pre-LN (modern)

Post-LN (original paper):

$$\mathbf{x}' = \text{LayerNorm}(\mathbf{x} + \text{MHA}(\mathbf{x}))$$

$$\mathbf{x}'' = \text{LayerNorm}(\mathbf{x}' + \text{FFN}(\mathbf{x}'))$$

Pre-LN (GPT-2 onward, more stable for deep nets):

$$\mathbf{x}' = \mathbf{x} + \text{MHA}(\text{LayerNorm}(\mathbf{x}))$$

$$\mathbf{x}'' = \mathbf{x}' + \text{FFN}(\text{LayerNorm}(\mathbf{x}'))$$

The residual $+\mathbf{x}$ gives the gradient highway; LayerNorm stabilizes scale. Stack $N$ such blocks (e.g. 12 in BERT-base, 96 in GPT-3).
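The difference between the two orderings is easiest to see on a single vector. In this sketch `toy_sublayer` is a hypothetical stand-in for MHA or FFN, and `layer_norm` omits the learned gain and bias.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((t - mean) ** 2 for t in x) / len(x)
    return [(t - mean) / math.sqrt(var + eps) for t in x]

def post_ln(x, sublayer):
    """Original ordering: normalize after adding the residual."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln(x, sublayer):
    """Modern ordering: normalize only the sublayer input."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def toy_sublayer(v):               # hypothetical stand-in for MHA or FFN
    return [0.5 * t for t in v]

x = [10.0, 20.0, 30.0, 40.0]
y_post = post_ln(x, toy_sublayer)  # renormalized: the scale of x is discarded
y_pre = pre_ln(x, toy_sublayer)    # x passes through the residual path untouched
```

In Pre-LN the identity path from input to output is never rescaled, which is one common explanation for why deep Pre-LN stacks train more stably.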

Block diagram

code
        ┌─────────────── + ◄──────────────┐  (residual)
x ──►LayerNorm──► Multi-Head Attention ────┘
        ┌─────────────── + ◄──────────────┐  (residual)
   └───►LayerNorm──► Feed-Forward (4d) ────┘ ──► output

4.6 Masking

Padding mask

Batched sequences are padded to equal length. We set attention scores for pad positions to $-\infty$ before softmax so they get weight 0.
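A minimal sketch of a padding mask on one row of scores, in plain Python. The scores and the `keep` flags are made-up values, and the helper assumes at least one position is kept.

```python
import math

def masked_softmax(scores, keep):
    """Softmax over scores; positions where keep is False get weight 0."""
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    m = max(masked)                              # assumes at least one kept position
    exps = [math.exp(s - m) for s in masked]     # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# One query row over 4 key positions; the last two keys are padding.
scores = [2.0, 1.0, 3.0, 0.5]
keep = [True, True, False, False]
weights = masked_softmax(scores, keep)
```

Note that the padded position with the largest raw score (3.0) still ends up with weight exactly 0, and the remaining weights sum to 1.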

Causal (look-ahead) mask — for generation

In a decoder, token $i$ must not see future tokens $j>i$ (that would be cheating during next-token prediction). Apply a mask $\mathbf{M}$ with $M_{ij}=-\infty$ for $j>i$, else $0$:

$$\mathbf{A} = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}} + \mathbf{M}\right)$$

Concretely the mask is $-\infty$ strictly above the diagonal:

$$\mathbf{M} = \begin{bmatrix}0 & -\infty & -\infty\\ 0 & 0 & -\infty\\ 0 & 0 & 0\end{bmatrix}$$

After softmax, each token attends only to itself and earlier tokens. This single trick is what makes GPT autoregressive. ([[05_architectures]] details decoder-only models.)
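The additive causal mask can be sketched in plain Python as follows. All scores are set to zero purely for illustration, so the resulting weights show the mask's effect alone.

```python
import math

def causal_mask(n):
    """M[i][j] = 0 if j <= i (past or self), else -inf (future)."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]

def softmax(row):
    top = max(row)
    exps = [math.exp(x - top) for x in row]      # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

n = 3
scores = [[0.0] * n for _ in range(n)]           # all-equal scores, for illustration
M = causal_mask(n)
A = [softmax([s + mv for s, mv in zip(s_row, m_row)])
     for s_row, m_row in zip(scores, M)]
```

Row $i$ of `A` spreads its weight uniformly over positions $0..i$ and puts exactly zero on later positions, so the first token can only attend to itself.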


4.7 Why $\sqrt{d_k}$, complexity, and properties

  • Complexity: self-attention is $O(n^2 d)$ in time and $O(n^2)$ memory (the $n\times n$ score matrix). This quadratic cost in sequence length $n$ is the main scaling limitation → motivates FlashAttention (IO-aware exact attention), sparse/linear attention, sliding windows.
  • Path length: any two tokens interact in one layer ($O(1)$ path) vs $O(n)$ for RNNs → far better long-range modeling.
  • Parallelism: all positions computed simultaneously (no time recurrence) → GPU-friendly, the reason Transformers scaled.
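A back-of-the-envelope sketch of the quadratic memory cost. It assumes the full score matrix is materialized for one layer, in fp16 (2 bytes per element), with a hypothetical 32 heads and batch size 1; fused kernels such as FlashAttention avoid storing this matrix.

```python
def score_matrix_bytes(n, heads, batch=1, bytes_per_elem=2):
    """Bytes needed to hold every head's n x n score matrix for one layer."""
    return batch * heads * n * n * bytes_per_elem

GIB = 1024 ** 3
for n in (1024, 8192, 65536):
    print(n, score_matrix_bytes(n, heads=32) / GIB, "GiB")
```

Going from 1,024 to 8,192 tokens multiplies the sequence length by 8 and the score-matrix memory by 64.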

4.8 Training a Transformer language model

Objective

Causal/autoregressive LM (GPT): predict the next token. With sequence $w_1\dots w_n$:

$$L = -\sum_{t=1}^{n} \log P(w_t \mid w_1,\dots,w_{t-1})$$

Each position's output goes through a linear "LM head" (often weight-tied to the embedding matrix, i.e. $\mathbf{E}^\top$) → softmax over vocab → cross-entropy against the actual next token (the clean $\hat y - y$ gradient from [[01_deep_learning_foundations]]). With causal masking, all $n$ next-token predictions are computed in one parallel forward pass ("teacher forcing").

Perplexity

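A common metric: $\text{PPL} = \exp(L_{\text{avg}})$, where $L_{\text{avg}}$ is the average per-token loss ($L$ divided by the number of predicted tokens). It can be read as the effective branching factor: the number of equally likely tokens the model is choosing among. Lower is better.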
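The next-token loss and perplexity can be sketched together in plain Python. The example assumes a "model" that outputs uniform logits, which makes the expected answer easy to reason about: the loss is $\log V$ and the perplexity is $V$. The token ids are arbitrary.

```python
import math

def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def lm_loss(logits_per_pos, token_ids):
    """Average next-token negative log-likelihood.

    logits_per_pos[t] are the logits produced at position t; they are scored
    against token_ids[t + 1], so the last position has no target.
    """
    nll = 0.0
    for t in range(len(token_ids) - 1):
        nll -= log_softmax(logits_per_pos[t])[token_ids[t + 1]]
    return nll / (len(token_ids) - 1)

V = 5                                           # toy vocabulary size
tokens = [0, 3, 1, 4]                           # arbitrary token ids
uniform_logits = [[0.0] * V for _ in tokens]    # a model that knows nothing
loss = lm_loss(uniform_logits, tokens)          # equals log(V)
ppl = math.exp(loss)                            # equals V
```

A trained model should land well below $V$; a perplexity near the vocabulary size means the model is doing no better than guessing.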

Learning-rate schedule

The original used warmup then inverse-sqrt decay:

$$\eta = d^{-0.5}\cdot\min\!\big(step^{-0.5},\ step\cdot warmup^{-1.5}\big)$$

Warmup avoids early instability when Adam's variance estimates are noisy; decay refines later. Modern LLMs use linear warmup + cosine decay with AdamW.
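The original schedule translates directly into code. The defaults below ($d=512$, 4000 warmup steps) match the base configuration reported in the original paper; the function name `noam_lr` is an informal label, not an official API.

```python
def noam_lr(step, d_model=512, warmup=4000):
    """Learning rate at a given step (step >= 1): linear warmup, then 1/sqrt(step)."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

peak = noam_lr(4000)                         # the two branches meet at step == warmup
early = [noam_lr(s) for s in (1000, 2000)]   # rising linearly
late = [noam_lr(s) for s in (8000, 16000)]   # decaying as 1/sqrt(step)
```

The rate peaks at roughly 7e-4 when `step == warmup`; quadrupling the step count after that halves the rate.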

Inference / decoding strategies

Given the next-token distribution, pick a token:

  • Greedy: argmax (deterministic, can be repetitive).
  • Beam search: keep top-bb partial sequences (good for translation).
  • Temperature $T$: divide logits by $T$ before softmax; $T<1$ sharpens, $T>1$ flattens.
  • Top-k: sample among the $k$ most likely tokens.
  • Top-p (nucleus): sample from the smallest set whose cumulative prob $\ge p$.
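A plain-Python sketch of these strategies on a single made-up logit vector (beam search is omitted because it needs a model to extend sequences). The filters return renormalized distributions; the helper names are illustrative.

```python
import math
import random

def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_filter(probs, k):
    """Keep the k most likely tokens and renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

def top_p_filter(probs, p):
    """Keep the smallest high-probability set with cumulative mass >= p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, mass = [], 0.0
    for i in order:
        keep.append(i)
        mass += probs[i]
        if mass >= p:
            break
    return [probs[i] / mass if i in keep else 0.0 for i in range(len(probs))]

def sample(probs, rng):
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]                  # made-up next-token logits
probs = softmax(logits)
greedy = max(range(len(probs)), key=lambda i: probs[i])   # argmax
sharp = softmax(logits, temperature=0.5)        # more peaked than probs
flat = softmax(logits, temperature=2.0)         # closer to uniform
rng = random.Random(0)
token = sample(top_k_filter(probs, k=2), rng)   # always one of the top two ids
```

In practice these are often combined, for example temperature scaling followed by top-p filtering, before a single sample is drawn.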

4.9 Code: a Transformer block from scratch (PyTorch)

python
import torch, torch.nn as nn, torch.nn.functional as F, math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.h, self.dk = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3*d_model)   # fused Q,K,V projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):               # x: (B, n, d)
        B, n, d = x.shape
        qkv = self.qkv(x).reshape(B, n, 3, self.h, self.dk).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]           # each (B, h, n, dk)
        scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.dk)   # (B,h,n,n)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = scores.softmax(-1)                  # attention weights
        z = attn @ v                               # (B,h,n,dk)
        z = z.transpose(1,2).reshape(B, n, d)      # concat heads
        return self.out(z)

class TransformerBlock(nn.Module):                 # Pre-LN variant
    def __init__(self, d_model, n_heads, d_ff, p=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model); self.attn = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff  = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                                 nn.Linear(d_ff, d_model))
        self.drop = nn.Dropout(p)
    def forward(self, x, mask=None):
        x = x + self.drop(self.attn(self.ln1(x), mask))   # residual + attention
        x = x + self.drop(self.ff(self.ln2(x)))           # residual + FFN
        return x

def causal_mask(n):
    return torch.tril(torch.ones(n, n)).bool()   # lower-triangular True

A minimal GPT

python
class MiniGPT(nn.Module):
    def __init__(self, vocab, d=256, n_heads=8, n_layers=6, max_len=512, d_ff=1024):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        self.blocks = nn.ModuleList([TransformerBlock(d, n_heads, d_ff) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)
        self.head.weight = self.tok.weight          # weight tying
    def forward(self, idx):                          # idx: (B, n)
        B, n = idx.shape
        pos = torch.arange(n, device=idx.device)
        x = self.tok(idx) + self.pos(pos)            # embed + positional
        mask = causal_mask(n).to(idx.device)
        for blk in self.blocks:
            x = blk(x, mask)
        return self.head(self.ln_f(x))               # logits (B, n, vocab)

# training (V = vocab size): position t predicts token t+1, so shift by one
# logits = model(idx)
# loss = F.cross_entropy(logits[:, :-1].reshape(-1, V), idx[:, 1:].reshape(-1))

4.10 Pitfalls & key intuitions

  • In self-attention, Q, K, V are three different learned projections of the same input; in cross-attention (decoder attending to encoder) Q comes from the decoder, K/V from the encoder ([[05_architectures]]).
  • Forgetting the mask in a decoder leaks future info → the model "cheats" and fails at generation.
  • Forgetting positional encoding → the model can't tell word order ("dog bites man" = "man bites dog").
  • Quadratic memory limits context length; that's an active research/engineering frontier.
  • Attention weights are somewhat interpretable but not a faithful explanation of the model's reasoning — treat attention maps cautiously.

Next: [[05_architectures]] — how encoder-only, decoder-only, and encoder-decoder models reuse these blocks for different jobs.