Three requested topics in full: FlashAttention, RoPE (rotary position embeddings), and a from-scratch, trainable LSTM with complete backpropagation through time (BPTT). Plus a bonus on the KV-cache, which interacts with both attention topics. Prereqs: [[04_transformers]] and [[03_rnn_lstm]].
12.1 FlashAttention — IO-aware exact attention
The problem it solves
Standard attention ([[04_transformers]] §4.3) computes
$$\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
by materializing the full score matrix $S = QK^\top / \sqrt{d_k}$ and the probability matrix $P = \operatorname{softmax}(S)$, both $n \times n$, in GPU high-bandwidth memory (HBM). For sequence length $n$, that's $O(n^2)$ memory and, crucially, $O(n^2)$ reads and writes to slow HBM. Attention is memory-bandwidth bound, not compute bound: the GPU spends most of its time moving the giant $S$ and $P$ matrices to and from HBM, not doing math.
Key realization: the FLOPs are unavoidable ($O(n^2 d)$), but the memory traffic is not. FlashAttention computes the exact same result while never writing the $n \times n$ matrix to HBM: it keeps tiles in fast on-chip SRAM and fuses everything into one kernel.
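A back-of-the-envelope sketch makes the gap concrete. It assumes fp16 scores (2 bytes each), a single head of a single sequence, and head dimension $d = 128$; the helper names are illustrative only. At $n = 8192$ the score matrix alone is 128 MiB, while the $n \times d$ output that FlashAttention actually writes is 2 MiB.

```python
MiB = 2 ** 20

def score_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one n x n score matrix (one head, one sequence)."""
    return n * n * bytes_per_elem

def output_bytes(n: int, d: int, bytes_per_elem: int = 2) -> int:
    """Bytes for the n x d attention output of the same head."""
    return n * d * bytes_per_elem

for n in (1024, 8192, 32768):
    print(f"n={n:>6}: S = {score_matrix_bytes(n) / MiB:8.2f} MiB, "
          f"output = {output_bytes(n, d=128) / MiB:5.2f} MiB")
```

The score matrix grows quadratically while the output grows linearly, so the ratio between them is $n/d$ and keeps widening with context length.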
The two ingredients
(1) Tiling. Split $Q$, $K$, $V$ into blocks that fit in SRAM. Loop over key/value blocks, accumulating each query block's output incrementally.
(2) Online softmax. The obstacle to tiling is that softmax needs the global max and sum over the whole row before normalizing. Online (streaming) softmax computes them incrementally with a running correction, so you never need the full row at once.
Online softmax — the math (derive it once)
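For one query row with scores $s_j$ and values $v_j$ we want
$$y = \sum_j \frac{e^{s_j}}{\sum_k e^{s_k}}\, v_j,$$
but we receive the $(s_j, v_j)$ in blocks. Use the numerically stable form (subtract the max, [[10_math_appendix]] §10.4). Maintain three running quantities as we stream blocks:
- $m$ = running max of the scores seen so far,
- $\ell$ = running sum of $e^{s_j - m}$,
- $o$ = running weighted sum of values $\sum_j e^{s_j - m}\, v_j$.
When a new block $B$ arrives with local max $m_B = \max_{j \in B} s_j$ and local sums $\ell_B = \sum_{j \in B} e^{s_j - m_B}$ and $o_B = \sum_{j \in B} e^{s_j - m_B}\, v_j$, update:
$$m_{\text{new}} = \max(m, m_B), \qquad \ell \leftarrow e^{m - m_{\text{new}}}\,\ell + e^{m_B - m_{\text{new}}}\,\ell_B, \qquad o \leftarrow e^{m - m_{\text{new}}}\,o + e^{m_B - m_{\text{new}}}\,o_B, \qquad m \leftarrow m_{\text{new}}$$
The factor $e^{m - m_{\text{new}}}$ rescales the old accumulators to the new max; this correction is the whole trick. After the last block, the attention output for that query is $y = o / \ell$. This is mathematically identical to computing softmax over the full row: FlashAttention is exact, not an approximation.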
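The update rule can be checked numerically. Below is a minimal NumPy sketch for a single query row, assuming scores and values arrive in fixed-size chunks; `online_softmax_weighted_sum` and `direct_softmax_weighted_sum` are illustrative names, not part of any library. For any block size it should agree with the direct formula to floating-point precision.

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, block_size):
    """Stream (scores, values) in blocks; return sum_j softmax(scores)_j * values_j."""
    m = -np.inf                                  # running max
    l = 0.0                                      # running sum of exp(s_j - m)
    o = np.zeros(values.shape[1])                # running sum of exp(s_j - m) * v_j
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_blk = s_blk.max()                      # local max of this block
        p = np.exp(s_blk - m_blk)                # local unnormalized weights
        m_new = max(m, m_blk)
        scale_old = np.exp(m - m_new)            # rescales the old accumulators
        scale_blk = np.exp(m_blk - m_new)        # rescales this block's local sums
        l = scale_old * l + scale_blk * p.sum()
        o = scale_old * o + scale_blk * (p @ v_blk)
        m = m_new
    return o / l

def direct_softmax_weighted_sum(scores, values):
    """Reference: full-row stable softmax, then weighted sum."""
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ values

rng = np.random.default_rng(0)
scores = rng.normal(size=10) * 5
values = rng.normal(size=(10, 3))
print(np.allclose(online_softmax_weighted_sum(scores, values, 3),
                  direct_softmax_weighted_sum(scores, values)))   # True
```

Starting from $m = -\infty$ works because $e^{-\infty} = 0$ wipes out the empty accumulators on the first block.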
Why a tiny worked example convinces you
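Stream the scores $s = (2, 4)$ in two blocks of one score each, with values $v = (10, 20)$.
- Block 1 ($s_1 = 2$): $m = 2$, $\ell = e^{0} = 1$, $o = 1 \cdot 10 = 10$.
- Block 2 ($s_2 = 4$): $m_B = 4$, $\ell_B = 1$, $o_B = 20$, so $m_{\text{new}} = 4$.
- $\ell = e^{2-4} \cdot 1 + e^{0} \cdot 1 = e^{-2} + 1 \approx 1.135$.
- $o = e^{-2} \cdot 10 + 1 \cdot 20 \approx 21.353$.
- Output $o / \ell = \dfrac{10e^{-2} + 20}{e^{-2} + 1} \approx 18.808$.
Check against direct softmax: weights $(e^2, e^4)$, normalized $\approx (0.1192, 0.8808)$; $0.1192 \cdot 10 + 0.8808 \cdot 20 \approx 18.808$. ✓ Identical.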
Algorithm sketch
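```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=64, block_k=64):
    """Exact softmax(Q K^T / sqrt(dk)) V, computed tile by tile with online softmax."""
    n, dk = Q.shape
    out = np.empty((n, V.shape[1]))
    for r in range(0, n, block_q):                     # outer loop over query blocks (rows)
        Qi = Q[r:r + block_q]
        m = np.full(len(Qi), -np.inf)                  # running row max (SRAM on a GPU)
        l = np.zeros(len(Qi))                          # running row sum
        O = np.zeros((len(Qi), V.shape[1]))            # running unnormalized output
        for c in range(0, K.shape[0], block_k):        # inner loop over key/value blocks (cols)
            Kj, Vj = K[c:c + block_k], V[c:c + block_k]
            S = Qi @ Kj.T / np.sqrt(dk)                # small tile, stays in SRAM
            m_blk = S.max(axis=1)
            P = np.exp(S - m_blk[:, None])             # unnormalized weights
            m_new = np.maximum(m, m_blk)
            scale = np.exp(m - m_new)                  # rescale the old accumulators
            blk = np.exp(m_blk - m_new)                # rescale this tile's sums
            l = scale * l + blk * P.sum(axis=1)
            O = scale[:, None] * O + blk[:, None] * (P @ Vj)
            m = m_new
        out[r:r + block_q] = O / l[:, None]            # only the n x d output is written to HBM
    return out
```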
Results & properties
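- Memory: $O(n)$ extra memory instead of $O(n^2)$. Only tiles ever live in SRAM, and HBM holds just the $n \times d$ output plus two softmax statistics ($m$, $\ell$) per query row. This is what lets models train with long contexts.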
- Speed: 2–4× faster wall-clock by slashing HBM traffic (despite recomputing tiles in the backward pass instead of storing them — recompute is cheaper than HBM round-trips).
- Exact: same numbers as vanilla attention (up to floating point).
- Backward pass: recomputes the needed tiles on the fly from $Q$, $K$, $V$ and the stored per-row softmax statistics, keeping extra memory at $O(n)$.
- FlashAttention-2/3 add better work partitioning and use newer GPU features; the core idea is unchanged.
In practice
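```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 1024, 64) for _ in range(3))  # (batch, heads, seq_len, head_dim)
# PyTorch 2.x dispatches to a FlashAttention kernel automatically when eligible (e.g. fp16/bf16 on a supported CUDA GPU):
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # fused, memory-efficient
```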
12.2 RoPE — Rotary Position Embeddings
The goal
Sinusoidal/learned absolute position embeddings ([[04_transformers]] §4.2) are added to token embeddings. RoPE instead rotates the query and key vectors by an angle proportional to their position, so that the attention dot product between positions $m$ and $n$ depends only on the relative offset $n - m$.
The 2D core
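Take a query vector $q$ and view its dimensions in pairs. For one pair $q = (q_1, q_2)$ at position $m$, rotate by angle $m\theta$:
$$R(m\theta)\,q = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_1 \\ q_2 \end{pmatrix}$$
Do the same for a key $k$ at position $n$. The dot product of the two rotated 2D vectors is
$$\big(R(m\theta)\,q\big)^\top \big(R(n\theta)\,k\big) = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R\big((n-m)\theta\big)\,k,$$
using $R(\alpha)^\top R(\beta) = R(\beta - \alpha)$ (rotations compose additively). The result depends only on the relative position $n - m$: the absolute positions cancel, which is exactly what we want.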
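A quick NumPy check of the 2D identity, with an arbitrary frequency $\theta = 0.3$ and random $q, k$ (both chosen only for illustration; `rot` and `rope_score_2d` are hypothetical helpers). Shifting both positions by the same amount leaves the score unchanged, and the score equals $q^\top R((n-m)\theta)\,k$.

```python
import numpy as np

def rot(angle):
    """2x2 rotation matrix R(angle)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s], [s, c]])

def rope_score_2d(q, k, m, n, theta):
    """Dot product of q rotated to position m and k rotated to position n."""
    return (rot(m * theta) @ q) @ (rot(n * theta) @ k)

rng = np.random.default_rng(0)
q, k = rng.normal(size=2), rng.normal(size=2)
theta = 0.3
# Same offset n - m = 2 at very different absolute positions:
print(rope_score_2d(q, k, 0, 2, theta), rope_score_2d(q, k, 100, 102, theta))
# Both match q^T R((n - m) theta) k:
print(q @ rot(2 * theta) @ k)
```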
Full -dimensional version
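Split the $d$-dim vector into $d/2$ pairs. Pair $i$ rotates at its own frequency:
$$\theta_i = \text{base}^{-2i/d}, \qquad i = 0, 1, \dots, d/2 - 1, \qquad \text{base} = 10000.$$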
Low-index pairs rotate fast (capture local/relative position), high-index pairs rotate slowly (capture long-range) — analogous to the multi-frequency idea in sinusoidal encodings, but applied multiplicatively to Q and K rather than added to the embedding.
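The complex-number view (elegant): identify each pair $(x_1, x_2)$ with the complex number $x_1 + \mathrm{i}\,x_2$; rotation by $m\theta$ is multiplication by $e^{\mathrm{i}m\theta}$. The real dot product of two pairs is $\operatorname{Re}(q\,\bar{k})$, so
$$\operatorname{Re}\!\big[(q\,e^{\mathrm{i}m\theta})\,\overline{(k\,e^{\mathrm{i}n\theta})}\big] = \operatorname{Re}\!\big[q\,\bar{k}\,e^{\mathrm{i}(m-n)\theta}\big],$$
manifestly a function of $m - n$ only.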
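The complex view also yields a compact full $d$-dimensional sketch. This one pairs adjacent dimensions $(x_{2i}, x_{2i+1})$ and uses $\theta_i = 10000^{-2i/d}$; the function names `rope_complex` and `score` are illustrative. The score at positions $(3, 10)$ should match the score at $(503, 510)$, and at equal positions the rotations cancel back to the plain dot product.

```python
import numpy as np

def rope_complex(x, pos, base=10000.0):
    """Rotate a d-dim vector (d even) at position `pos`, pairing (x[2i], x[2i+1])."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)        # theta_i = base^(-2i/d)
    z = x[0::2] + 1j * x[1::2]                       # each pair as one complex number
    return z * np.exp(1j * pos * theta)              # rotation = multiply by e^{i pos theta_i}

def score(q, k, m, n):
    """Real dot product of rotated q and k, i.e. Re(sum q_rot * conj(k_rot))."""
    return np.real(np.sum(rope_complex(q, m) * np.conj(rope_complex(k, n))))

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)
print(score(q, k, 3, 10), score(q, k, 503, 510))   # same offset, same score
```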
Worked micro-example
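$d = 2$, one pair, $\theta = \pi/3$, query $q = (1, 0)$ at position $m = 1$:
$$R(\pi/3)\,q = \big(\cos\tfrac{\pi}{3},\ \sin\tfrac{\pi}{3}\big) = \big(\tfrac12,\ \tfrac{\sqrt3}{2}\big).$$
A key $k = (1, 0)$ at position $n = 1$ rotates the same way, so their dot product equals the unrotated $q^\top k = 1$ (offset 0). At $n = 2$ the offset is 1, giving $q^\top R(\pi/3)\,k = \cos\tfrac{\pi}{3} = \tfrac12$. The score now reflects a distance of one position, regardless of where the pair sits absolutely: $m = 7$, $n = 8$ gives the same $\tfrac12$.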
Implementation (the standard "rotate-half" trick)
```python
import torch

def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
    # frequencies θ_i = base^(-2i/dim), one per pair
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)            # (seq_len, dim/2): m * θ_i
    emb = torch.cat([freqs, freqs], dim=-1)     # duplicate for the rotate-half layout
    return emb.cos(), emb.sin()                 # each (seq_len, dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)         # 2D rotation per pair (dim i, dim i + d/2)

def apply_rope(q, k, cos, sin):
    # q, k: (B, n, d); cos, sin: (n, d), broadcast over the batch dimension
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot

# usage inside attention, BEFORE computing Q·K^T:
# cos, sin = build_rope_cache(n, d_head)
# q, k = apply_rope(q, k, cos, sin)
# scores = q @ k.transpose(-2, -1) / d_head ** 0.5
```
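One detail worth checking: the rotate-half layout pairs dimension $i$ with $i + d/2$, not with its neighbour $i + 1$. The NumPy sketch below ports the rotate-half formula and compares it with a hypothetical adjacent-pair variant `rope_interleaved`; under a fixed permutation of dimensions the two apply exactly the same rotations. A checkpoint trained with one layout would therefore need its Q/K projection weights permuted to work with the other.

```python
import numpy as np

def rope_interleaved(x, pos, base=10000.0):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    c, s = np.cos(pos * theta), np.sin(pos * theta)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * c - x2 * s
    out[1::2] = x1 * s + x2 * c
    return out

def rope_rotate_half(x, pos, base=10000.0):
    """NumPy port of the rotate-half formula: pairs x[i] with x[i + d/2]."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    emb = np.concatenate([pos * theta, pos * theta])          # duplicated, as in build_rope_cache
    x1, x2 = x[: d // 2], x[d // 2:]
    rotated_half = np.concatenate([-x2, x1])
    return x * np.cos(emb) + rotated_half * np.sin(emb)

d = 8
perm = np.concatenate([np.arange(0, d, 2), np.arange(1, d, 2)])   # [0, 2, 4, 6, 1, 3, 5, 7]
x = np.random.default_rng(0).normal(size=d)
print(np.allclose(rope_rotate_half(x[perm], 5), rope_interleaved(x, 5)[perm]))   # True
```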
Notice RoPE is applied to Q and K only (not V), right before the score computation — it shapes where tokens attend, not what content flows.
Why it generalizes to longer contexts
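Because positions enter only as relative rotations, RoPE is far easier to extend than learned absolute embeddings. Used as-is beyond its training length $L$, a model typically still degrades, since rotation angles larger than any seen in training are out of distribution. Tricks like linear position interpolation, NTK-aware scaling and YaRN rescale positions or frequencies so the angles stay close to the trained range, extending context (e.g. 4k → 32k) with little or no retraining.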
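A minimal sketch of the two simplest rescalings, for head dimension 128 and an 8× extension (4k → 32k). Linear position interpolation divides every position by the scale factor $s$; NTK-aware scaling instead enlarges the base, here using the commonly quoted form $\text{base}' = \text{base} \cdot s^{d/(d-2)}$ (treat the exact formula as an assumption). YaRN mixes strategies per frequency and is not shown; all function names are illustrative.

```python
import numpy as np

def rope_freqs(dim, base=10000.0):
    """theta_i = base^(-2i/dim) for i = 0 .. dim/2 - 1."""
    return base ** (-np.arange(0, dim, 2) / dim)

def linear_interpolation_angles(pos, dim, scale, base=10000.0):
    """Position interpolation: squeeze positions by 1/scale into the trained range."""
    return (pos / scale) * rope_freqs(dim, base)

def ntk_aware_angles(pos, dim, scale, base=10000.0):
    """NTK-aware scaling (assumed form): enlarge the base instead of the positions."""
    new_base = base * scale ** (dim / (dim - 2))
    return pos * rope_freqs(dim, new_base)

dim, scale = 128, 8.0          # 4k -> 32k context is a scale of 8
pos = 32_000
orig = pos * rope_freqs(dim)
pi = linear_interpolation_angles(pos, dim, scale)
ntk = ntk_aware_angles(pos, dim, scale)
print(pi[0] / orig[0], pi[-1] / orig[-1])     # every pair slowed by 1/scale
print(ntk[0] / orig[0], ntk[-1] / orig[-1])   # fastest pair untouched, slowest slowed by 1/scale
```

The printout shows the design difference: interpolation slows every pair uniformly, which also blurs the fast, local pairs, whereas the NTK variant leaves the fastest pair untouched and slows only the slowest pair by the full factor.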
12.3 A trainable LSTM from scratch — full BPTT in NumPy
We now derive every gradient of an LSTM cell and implement a working trainer in pure NumPy. This is the capstone of [[03_rnn_lstm]] + [[01_deep_learning_foundations]].
Forward (single cell, recap)
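At step $t$, let $z_t = [h_{t-1}; x_t] \in \mathbb{R}^{H+D}$ (concatenation). With weight blocks $W_f, W_i, W_c, W_o \in \mathbb{R}^{H \times (H+D)}$ and biases $b_f, b_i, b_c, b_o \in \mathbb{R}^H$:
$$\begin{aligned}
f_t &= \sigma(W_f z_t + b_f), & i_t &= \sigma(W_i z_t + b_i), \\
g_t &= \tanh(W_c z_t + b_c), & o_t &= \sigma(W_o z_t + b_o), \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t, & h_t &= o_t \odot \tanh(c_t).
\end{aligned}$$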
Backward — derive each gradient
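Let $\delta h_t = \partial \mathcal{L}/\partial h_t$ be the total gradient arriving at the hidden state (from the layer above, plus from step $t+1$ through $z_{t+1}$), and let $\delta c^{\text{next}}_t$ be the gradient flowing back through time into the cell state from step $t+1$ (zero at the last step $T$). Define the helpers $\sigma'(a) = \sigma(a)\,\big(1 - \sigma(a)\big)$ and $\tanh'(a) = 1 - \tanh^2(a)$.
Through the output equation $h_t = o_t \odot \tanh(c_t)$:
$$\delta o_t = \delta h_t \odot \tanh(c_t), \qquad \delta c_t = \delta h_t \odot o_t \odot \big(1 - \tanh^2(c_t)\big) + \delta c^{\text{next}}_t$$
(the cell-state gradient gets contributions both from $h_t$ and from the future step; this additive path, scaled only by the forget gates, is why LSTM gradients vanish far less than those of a vanilla RNN, [[03_rnn_lstm]] §3.3).
Through the cell update $c_t = f_t \odot c_{t-1} + i_t \odot g_t$:
$$\delta f_t = \delta c_t \odot c_{t-1}, \qquad \delta i_t = \delta c_t \odot g_t, \qquad \delta g_t = \delta c_t \odot i_t, \qquad \delta c^{\text{next}}_{t-1} = \delta c_t \odot f_t$$
Through the gate nonlinearities (gradients with respect to the pre-activations $a_{k,t}$, the arguments of $\sigma$ or $\tanh$ for gate $k$):
$$\delta a_{f,t} = \delta f_t \odot f_t \odot (1 - f_t), \quad \delta a_{i,t} = \delta i_t \odot i_t \odot (1 - i_t), \quad \delta a_{o,t} = \delta o_t \odot o_t \odot (1 - o_t), \quad \delta a_{c,t} = \delta g_t \odot (1 - g_t^2)$$
Weight gradients (accumulate over all time steps, like the shared-weight rule in [[03_rnn_lstm]] §3.2). With $z_t$ and gate pre-activation gradient $\delta a_{k,t}$ for $k \in \{f, i, c, o\}$:
$$\frac{\partial \mathcal{L}}{\partial W_k} = \sum_{t=1}^{T} \delta a_{k,t}\, z_t^\top, \qquad \frac{\partial \mathcal{L}}{\partial b_k} = \sum_{t=1}^{T} \delta a_{k,t}$$
Gradient to the previous hidden state (the first $H$ entries of $\delta z_t$, summed over gates):
$$\delta z_t = \sum_{k \in \{f,i,c,o\}} W_k^\top\, \delta a_{k,t}, \qquad \delta h_{t-1} = (\delta z_t)_{1:H} + \text{(any gradient from the layer above at step } t-1\text{)}$$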
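Derivations like this are easy to get subtly wrong, so a finite-difference gradient check is worth running. The sketch below assumes a toy loss $\mathcal{L} = w^\top h_T$ (so $\delta h_T = w$), tiny sizes, and a parameter dict `P`; all names are illustrative. It implements the backward equations above and compares each analytic gradient with a central difference; in float64 the maximum absolute error should be far below $10^{-6}$.

```python
import numpy as np

rng = np.random.default_rng(1)
H, D, T = 4, 3, 5                      # tiny sizes so the numeric check is fast
def sig(a): return 1 / (1 + np.exp(-a))
P = {k: rng.normal(0, 0.5, (H, H + D)) for k in ("Wf", "Wi", "Wc", "Wo")}
P.update({k: rng.normal(0, 0.5, H) for k in ("bf", "bi", "bc", "bo")})
xs = rng.normal(size=(T, D))           # fixed input sequence
w = rng.normal(size=H)                 # toy loss L = w . h_T (assumption for the check)

def forward(P):
    h, c, cache = np.zeros(H), np.zeros(H), []
    for x in xs:
        z = np.concatenate([h, x])
        f, i = sig(P["Wf"] @ z + P["bf"]), sig(P["Wi"] @ z + P["bi"])
        g, o = np.tanh(P["Wc"] @ z + P["bc"]), sig(P["Wo"] @ z + P["bo"])
        c_prev, c = c, f * c + i * g
        h = o * np.tanh(c)
        cache.append((z, f, i, g, o, c_prev, c))
    return w @ h, cache

def backward(cache):
    grads = {k: np.zeros_like(v) for k, v in P.items()}
    dh, dc_next = w.copy(), np.zeros(H)                   # dL/dh_T = w; no cell gradient after step T
    for z, f, i, g, o, c_prev, c in reversed(cache):
        do = dh * np.tanh(c)                              # through h_t = o_t * tanh(c_t)
        dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next     # plus the future-step cell path
        df, di, dg = dc * c_prev, dc * g, dc * i          # through c_t = f_t c_{t-1} + i_t g_t
        da = {"f": df * f * (1 - f), "i": di * i * (1 - i),
              "c": dg * (1 - g ** 2), "o": do * o * (1 - o)}   # pre-activation gradients
        dz = np.zeros(H + D)
        for k, a in da.items():
            grads["W" + k] += np.outer(a, z)              # dL/dW_k += da_k z_t^T
            grads["b" + k] += a
            dz += P["W" + k].T @ a                        # summed over gates
        dh, dc_next = dz[:H], dc * f                      # to h_{t-1} and c_{t-1}
    return grads

def numeric_grad(name, eps=1e-6):
    """Central differences, perturbing one entry of P[name] at a time."""
    num = np.zeros_like(P[name])
    for idx in np.ndindex(P[name].shape):
        old = P[name][idx]
        P[name][idx] = old + eps; loss_plus, _ = forward(P)
        P[name][idx] = old - eps; loss_minus, _ = forward(P)
        P[name][idx] = old
        num[idx] = (loss_plus - loss_minus) / (2 * eps)
    return num

analytic = backward(forward(P)[1])
errors = {name: np.max(np.abs(analytic[name] - numeric_grad(name))) for name in P}
for name, err in errors.items():
    print(f"{name}: max abs error {err:.1e}")
```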
Full runnable implementation (learns to remember the first bit)
Task: read a length-$T$ binary sequence; output the first input bit at the last step. Solving it requires memory across the whole sequence, which makes it a clean demo of the cell state at work.
```python
import numpy as np

rng = np.random.default_rng(0)
def sig(x): return 1 / (1 + np.exp(-x))

H, D, T = 16, 1, 8          # hidden size, input dim, sequence length
Z = H + D                   # one weight matrix per gate: (H, H+D)
def init(): return rng.normal(0, 0.1, (H, Z))
Wf, Wi, Wc, Wo = init(), init(), init(), init()
bf = np.ones(H); bi = np.zeros(H); bc = np.zeros(H); bo = np.zeros(H)   # forget bias=1 → remember by default
Wy = rng.normal(0, 0.1, (1, H)); by = np.zeros(1)
params = ['Wf', 'Wi', 'Wc', 'Wo', 'bf', 'bi', 'bc', 'bo', 'Wy', 'by']
lr = 0.1

def forward(xs):
    h, c = np.zeros(H), np.zeros(H); cache = []
    for x in xs:                          # x: (D,)
        z = np.concatenate([h, x])        # (H+D,)
        f, i = sig(Wf @ z + bf), sig(Wi @ z + bi)
        g, o = np.tanh(Wc @ z + bc), sig(Wo @ z + bo)
        c = f * c + i * g
        h = o * np.tanh(c)
        cache.append((z, f, i, g, o, c, h))
    y = Wy @ h + by                       # final-step readout
    return sig(y), cache                  # prediction in (0,1), shape (1,)

def train_step(xs, target):
    p, cache = forward(xs)
    loss = -(target * np.log(p + 1e-9) + (1 - target) * np.log(1 - p + 1e-9))   # BCE
    grads = {k: np.zeros_like(globals()[k]) for k in params}
    # output layer (BCE+sigmoid → clean p - target, see ch.1)
    dy = p - target                                       # (1,)
    h_last = cache[-1][6]
    grads['Wy'] += np.outer(dy, h_last); grads['by'] += dy
    dh = Wy.T @ dy                                        # (H,)
    dc_next = np.zeros(H)
    for t in reversed(range(T)):                          # BPTT
        z, f, i, g, o, c, h = cache[t]
        c_prev = cache[t - 1][5] if t > 0 else np.zeros(H)
        do = dh * np.tanh(c)
        dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next
        df, di, dg = dc * c_prev, dc * g, dc * i
        da_f = df * f * (1 - f); da_i = di * i * (1 - i)
        da_o = do * o * (1 - o); da_c = dg * (1 - g ** 2)
        for da, Wk, bk in [(da_f, 'Wf', 'bf'), (da_i, 'Wi', 'bi'),
                           (da_c, 'Wc', 'bc'), (da_o, 'Wo', 'bo')]:
            grads[Wk] += np.outer(da, z); grads[bk] += da
        dz = Wf.T @ da_f + Wi.T @ da_i + Wc.T @ da_c + Wo.T @ da_o
        dh = dz[:H]                                       # gradient to previous hidden state
        dc_next = dc * f                                  # gradient to previous cell state
    # clip + SGD update (in place, so forward() sees the new weights)
    for k in params:
        globals()[k] -= lr * np.clip(grads[k], -5, 5)
    return loss.item(), p.item()

# ---- train ----
for step in range(3000):
    xs = rng.integers(0, 2, size=(T, D)).astype(float)
    target = xs[0, 0]                                     # remember the FIRST bit
    loss, p = train_step(xs, target)
    if step % 500 == 0:
        print(f"step {step:4d} loss {loss:.4f}")

# ---- test ----
correct = 0
for _ in range(200):
    xs = rng.integers(0, 2, size=(T, D)).astype(float)
    p, _ = forward(xs)
    correct += int((p[0] > 0.5) == bool(xs[0, 0]))
print("accuracy:", correct / 200)   # should approach 1.0 once the LSTM carries bit 0 across all 8 steps
```
Every line of the backward pass corresponds to a derivative derived above. Swap the LSTM equations for the vanilla RNN recurrence $h_t = \tanh(W z_t + b)$ and the same task fails as $T$ grows: a hands-on demonstration of vanishing gradients versus the gated fix.
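The vanishing-gradient half of that claim can be seen without any training. For a vanilla RNN $h_t = \tanh(W h_{t-1} + U x_t)$, $\partial h_T / \partial h_0$ is a product of $T$ Jacobians $\operatorname{diag}(1 - h_t^2)\,W$. The sketch below assumes the same small initialization scale as the trainer (std 0.1, $H = 16$) and random binary inputs; its Frobenius norm collapses roughly geometrically with $T$. For rough contrast it also prints $\sigma(1)^T$, an approximation of the per-unit LSTM cell-path factor at initialization, assuming every forget gate sits near $\sigma(1)$ because of the forget bias of 1.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 16
W = rng.normal(0, 0.1, (H, H))   # recurrent weights (assumed init scale, matching the trainer)
U = rng.normal(0, 0.1, (H, 1))   # input weights for a 1-dim binary input

def grad_norm_through_time(T):
    """Frobenius norm of d h_T / d h_0 for h_t = tanh(W h_{t-1} + U x_t)."""
    h = np.zeros(H)
    J = np.eye(H)                                   # running Jacobian d h_t / d h_0
    for _ in range(T):
        x = rng.integers(0, 2, size=1).astype(float)
        h = np.tanh(W @ h + U @ x)
        J = ((1 - h ** 2)[:, None] * W) @ J         # chain rule: diag(tanh') W times previous J
    return np.linalg.norm(J)

forget_at_init = 1 / (1 + np.exp(-1.0))             # sigma(1), from a forget bias of 1
norms = {T: grad_norm_through_time(T) for T in (1, 2, 4, 8, 16)}
for T, g in norms.items():
    print(f"T={T:2d}  RNN |dh_T/dh_0| = {g:.2e}   LSTM cell path ~ sigma(1)^T = {forget_at_init ** T:.2e}")
```

The comparison is deliberately loose (a matrix norm against a per-unit factor), but the orders of magnitude make the point: the RNN signal shrinks by roughly the spectral radius of $W$ per step, while the gated path shrinks only by the forget gate.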
12.4 Bonus: the KV-cache (why generation is fast)
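In autoregressive decoding ([[05_architectures]] §5.5), generating token $t$ requires attention over all previous tokens. Naively, every step re-projects each past token's $K$ and $V$, which is wasted work: under causal attention those vectors never change.
KV-cache: store each layer's keys and values for past tokens; at step $t$, compute only the new token's $q_t, k_t, v_t$, append $k_t, v_t$ to the cache, and attend against the cached $K_{1:t}, V_{1:t}$. Per layer, the per-step cost drops from re-projecting all $t$ tokens ($O(t\,d^2)$) to projecting one token ($O(d^2)$) plus an $O(t\,d)$ attention read over the cache.
- Memory cost: cache size $= 2 \times n_{\text{layers}} \times n_{\text{kv heads}} \times d_{\text{head}} \times n_{\text{ctx}} \times \text{batch} \times \text{bytes per element}$ (the 2 counts keys and values). This dominates LLM serving memory and grows linearly with context.
- Optimizations: Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) share K/V across heads to shrink the cache (GQA is used in e.g. LLaMA-2-70B, LLaMA-3 and Mistral 7B); PagedAttention (vLLM) manages the cache like virtual memory to avoid fragmentation.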
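To make the memory cost concrete, the sketch below multiplies out keys and values × layers × KV heads × head dimension × cached positions × bytes. The configuration is an assumption for illustration: a 7B-class model with 32 layers and heads of dimension 128, fp16 values, one sequence of 4096 tokens, and, for the GQA line, 8 shared K/V heads instead of 32. Under those assumptions the cache takes 2 GiB with full multi-head K/V and 0.5 GiB with GQA.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Bytes for keys + values (factor 2) across layers, KV heads and cached positions."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 2 ** 30
# Assumed 7B-class shape: 32 layers, head_dim 128, fp16 (2 bytes), 4096-token context
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)  # each head keeps its own K/V
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)   # GQA: 8 shared K/V heads
print(f"MHA: {mha / GiB:.2f} GiB   GQA: {gqa / GiB:.2f} GiB")   # MHA: 2.00 GiB   GQA: 0.50 GiB
```

Doubling the context or the batch doubles the cache, which is why serving systems budget memory per cached token.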
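```python
# conceptual decode loop with a KV-cache (one layer, one head);
# project, head, sample, softmax, stack and sqrt are placeholders, not a specific library's API
k_cache, v_cache = [], []
for step in range(max_new):
    q, k, v = project(current_token)           # project only the new token
    k_cache.append(k); v_cache.append(v)       # grow the cache
    K, V = stack(k_cache), stack(v_cache)      # (t, d) each
    attn = softmax(q @ K.T / sqrt(d)) @ V      # attend over every cached position
    next_token = sample(head(attn))
    current_token = next_token
```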
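To confirm that caching changes the cost but not the result, the sketch below compares a cached, token-by-token decode against full causal recomputation for one toy attention layer. Everything here is an assumption for illustration: a single head, random projection matrices `Wq`, `Wk`, `Wv`, and a fixed sequence of input embeddings fed one at a time (no sampling). The two outputs should agree to floating-point precision.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 16, 6
Wq, Wk, Wv = (rng.normal(0, d ** -0.5, (d, d)) for _ in range(3))   # toy single-head projections
X = rng.normal(size=(n, d))                                          # embeddings of n tokens

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(X):
    """Recompute Q, K, V for the whole prefix and apply a causal mask."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    S = Q @ K.T / np.sqrt(d)
    S = np.where(np.tril(np.ones((len(X), len(X)), dtype=bool)), S, -np.inf)
    return softmax(S) @ V

def cached_decode(X):
    """Process tokens one at a time, projecting only the newest token."""
    k_cache, v_cache, outs = [], [], []
    for x in X:
        q, k, v = x @ Wq, x @ Wk, x @ Wv
        k_cache.append(k); v_cache.append(v)                 # grow the cache
        K, V = np.stack(k_cache), np.stack(v_cache)
        outs.append(softmax(q @ K.T / np.sqrt(d)) @ V)       # attend over cached positions
    return np.stack(outs)

print(np.allclose(full_causal_attention(X), cached_decode(X)))   # True
```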
Practice problems for these and all modules are in [[11_exercises]].