Module 12 · 11 min read

Advanced Deep Dives

FlashAttention (online softmax), RoPE (full derivation), a trainable LSTM from scratch with BPTT, and the KV-cache.

Three requested topics in full: FlashAttention, RoPE (rotary position embeddings), and a from-scratch, trainable LSTM with complete BPTT. Plus a bonus on the KV-cache, which they all interact with. Prereqs: [[04_transformers]] and [[03_rnn_lstm]].


12.1 FlashAttention — IO-aware exact attention

The problem it solves

Standard attention ([[04_transformers]] §4.3) computes

$$\mathbf{O} = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$

by materializing the full $n\times n$ score matrix $\mathbf{S}=\mathbf{Q}\mathbf{K}^\top$ in GPU high-bandwidth memory (HBM). For sequence length $n$, that's $O(n^2)$ memory and, crucially, $O(n^2)$ reads/writes to slow HBM. Attention is memory-bandwidth bound, not compute bound: the GPU spends most of its time moving the giant $\mathbf{S}$ and $\mathbf{A}=\text{softmax}(\mathbf{S})$ matrices to and from HBM, not doing math.

Key realization: the FLOPs are unavoidable ($O(n^2 d)$), but the memory traffic is not. FlashAttention computes the exact same result while never writing the $n\times n$ matrix to HBM: it keeps tiles in fast on-chip SRAM and fuses everything into one kernel.

The two ingredients

(1) Tiling. Split $\mathbf{Q},\mathbf{K},\mathbf{V}$ into blocks that fit in SRAM. Loop over key/value blocks, accumulating each query block's output incrementally.

(2) Online softmax. The obstacle to tiling is that softmax needs the global max and sum over the whole row before normalizing. Online (streaming) softmax computes them incrementally with a running correction, so you never need the full row at once.

Online softmax — the math (derive it once)

We want $\text{softmax}(x_1,\dots,x_N)$ but receive the $x_i$ in blocks. Use the numerically stable form (subtract the max, [[10_math_appendix]] §10.4). Maintain three running quantities as we stream blocks:

  • $m$ = running max seen so far,
  • $\ell$ = running sum of $\exp(x_i - m)$,
  • $\mathbf{o}$ = running weighted sum of values $\sum \exp(x_i-m)\,\mathbf{v}_i$.

When a new block arrives with local max $m^{\text{blk}}$ and local sums $\ell^{\text{blk}}$, $\mathbf{o}^{\text{blk}}$ (both computed relative to $m^{\text{blk}}$), update:

$$m^{\text{new}} = \max(m,\ m^{\text{blk}})$$

$$\ell^{\text{new}} = e^{\,m - m^{\text{new}}}\,\ell \;+\; e^{\,m^{\text{blk}} - m^{\text{new}}}\,\ell^{\text{blk}}$$

$$\mathbf{o}^{\text{new}} = e^{\,m - m^{\text{new}}}\,\mathbf{o} \;+\; e^{\,m^{\text{blk}} - m^{\text{new}}}\,\mathbf{o}^{\text{blk}}$$

The factor $e^{m - m^{\text{new}}}$ rescales the old accumulators to the new max; this correction is the whole trick. After the last block, the attention output for that query is $\mathbf{o}/\ell$. This is mathematically identical to computing softmax over the full row; FlashAttention is exact, not an approximation.

Why a tiny worked example convinces you

Stream $x = [1, 3]$ in two blocks, with values $v_1=10, v_2=20$.

  • Block 1 ($x_1=1$): $m=1,\ \ell=e^{0}=1,\ o=1\cdot10=10$.
  • Block 2 ($x_2=3$): $m^{\text{new}}=\max(1,3)=3$.
    • $\ell = e^{1-3}\cdot1 + e^{3-3}\cdot1 = e^{-2} + 1 = 1.135$.
    • $o = e^{-2}\cdot10 + e^{0}\cdot20 = 1.353 + 20 = 21.353$.
  • Output $= o/\ell = 21.353/1.135 = 18.81$.

Check against direct softmax: weights $\propto (e^1,e^3)=(2.718,20.09)$, normalized $(0.119,0.881)$; $0.119\cdot10+0.881\cdot20=18.81$. ✓ Identical.
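
The same streaming update can be checked in a few lines of plain Python. This is a minimal sketch for scalar values only; the function names `online_softmax_attend` and `direct_softmax_attend` are illustrative and not part of any library.

```python
import math

def online_softmax_attend(score_blocks, value_blocks):
    """Stream scores/values block by block; return the softmax-weighted average."""
    m = -math.inf   # running max
    l = 0.0         # running sum of exp(x_i - m)
    o = 0.0         # running sum of exp(x_i - m) * v_i
    for xs, vs in zip(score_blocks, value_blocks):
        # Block-local statistics, computed relative to the block's own max.
        m_blk = max(xs)
        l_blk = sum(math.exp(x - m_blk) for x in xs)
        o_blk = sum(math.exp(x - m_blk) * v for x, v in zip(xs, vs))
        # Merge: rescale both sides to the new shared max.
        m_new = max(m, m_blk)
        old_scale = math.exp(m - m_new)      # exp(-inf) = 0.0 on the first block
        blk_scale = math.exp(m_blk - m_new)
        l = old_scale * l + blk_scale * l_blk
        o = old_scale * o + blk_scale * o_blk
        m = m_new
    return o / l

def direct_softmax_attend(xs, vs):
    """Reference: softmax over the whole row at once."""
    mx = max(xs)
    w = [math.exp(x - mx) for x in xs]
    return sum(wi * vi for wi, vi in zip(w, vs)) / sum(w)

streamed = online_softmax_attend([[1.0], [3.0]], [[10.0], [20.0]])
direct = direct_softmax_attend([1.0, 3.0], [10.0, 20.0])
print(round(streamed, 2), round(direct, 2))   # 18.81 18.81
```

How the row is split into blocks should not matter: one block of two scores or two blocks of one give the same answer up to floating-point rounding.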

Algorithm sketch

code
for each block of queries Qi:                 # outer loop (rows)
    init  m = -inf,  l = 0,  O = 0            # in SRAM
    for each block of keys/values Kj, Vj:     # inner loop (cols)
        S = Qi @ Kj^T / sqrt(dk)              # small tile, stays in SRAM
        m_blk = rowmax(S)
        P = exp(S - m_blk)                    # unnormalized weights
        m_new = max(m, m_blk)
        scale = exp(m - m_new)
        l = scale*l + exp(m_blk - m_new)*rowsum(P)
        O = scale*O + exp(m_blk - m_new)*(P @ Vj)
        m = m_new
    write O / l to HBM                        # only the n×d output is written
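
The sketch above can be mirrored in NumPy to confirm that tiling plus online softmax reproduces ordinary attention. This is a CPU-only illustration under stated assumptions: non-causal attention, a single head, and arbitrary block sizes (`block_q`, `block_k` are hypothetical parameters). It shows the arithmetic only and says nothing about SRAM/HBM behaviour or speed.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full (n, n) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_q=4, block_k=4):
    """Same result, but only one (block_q, block_k) score tile exists at a time."""
    n, d = Q.shape
    out = np.empty((n, V.shape[1]))
    for qs in range(0, n, block_q):                  # outer loop over query blocks
        Qi = Q[qs:qs + block_q]
        m = np.full(Qi.shape[0], -np.inf)            # running row max
        l = np.zeros(Qi.shape[0])                    # running row sum
        O = np.zeros((Qi.shape[0], V.shape[1]))      # running unnormalized output
        for ks in range(0, K.shape[0], block_k):     # inner loop over key/value blocks
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qi @ Kj.T / np.sqrt(d)
            m_blk = S.max(axis=1)
            P = np.exp(S - m_blk[:, None])           # unnormalized weights for this tile
            m_new = np.maximum(m, m_blk)
            scale_old = np.exp(m - m_new)            # rescale old accumulators
            scale_blk = np.exp(m_blk - m_new)        # rescale this tile's contribution
            l = scale_old * l + scale_blk * P.sum(axis=1)
            O = scale_old[:, None] * O + scale_blk[:, None] * (P @ Vj)
            m = m_new
        out[qs:qs + block_q] = O / l[:, None]        # normalize once per query block
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 8)) for _ in range(3))
max_err = np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max()
print(max_err < 1e-12)
```

The sequence length (10) is deliberately not a multiple of the block size, so the last tile is smaller; slicing handles that without special cases.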

Results & properties

  • Memory: $O(n)$ instead of $O(n^2)$: you only ever store $O(\text{block}^2)$ tiles in SRAM and the $n\times d$ output. This is what lets models train with long contexts.
  • Speed: 2–4× faster wall-clock by slashing HBM traffic (despite recomputing $\mathbf{S}$ tiles in the backward pass instead of storing them; recompute is cheaper than HBM round-trips).
  • Exact: same numbers as vanilla attention (up to floating point).
  • Backward pass: recomputes the needed tiles on the fly using the stored $m,\ell$ statistics, keeping memory $O(n)$.
  • FlashAttention-2/3 add better work partitioning and use newer GPU features; the core idea is unchanged.

In practice

python
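import torch
import torch.nn.functional as F
# Example inputs (illustrative shapes): (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))
# PyTorch 2.x can dispatch to a FlashAttention kernel when eligible (e.g. CUDA with fp16/bf16 inputs):
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # fused, memory-efficient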

12.2 RoPE — Rotary Position Embeddings

The goal

Sinusoidal/learned absolute position embeddings ([[04_transformers]] §4.2) are added to token embeddings. RoPE instead rotates the query and key vectors by an angle proportional to their position, so that the attention dot product $\mathbf{q}_m\cdot\mathbf{k}_n$ depends on position only through the relative offset $m-n$.

The 2D core

Take a query vector and view its dimensions in pairs. For one pair $(q^{(1)}, q^{(2)})$ at position $m$, rotate by angle $m\theta$:

$$\tilde{\mathbf{q}}_m = \mathbf{R}(m\theta)\,\mathbf{q}, \qquad \mathbf{R}(m\theta) = \begin{bmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta\end{bmatrix}$$

Do the same for keys at position $n$. The dot product of two rotated 2D vectors:

$$\tilde{\mathbf{q}}_m \cdot \tilde{\mathbf{k}}_n = (\mathbf{R}(m\theta)\mathbf{q})^\top(\mathbf{R}(n\theta)\mathbf{k}) = \mathbf{q}^\top \mathbf{R}(m\theta)^\top \mathbf{R}(n\theta)\,\mathbf{k} = \mathbf{q}^\top \mathbf{R}((n-m)\theta)\,\mathbf{k}$$

using $\mathbf{R}(a)^\top\mathbf{R}(b)=\mathbf{R}(b-a)$ (rotations compose additively). The result depends only on the relative position $n-m$, which is exactly what we want: the absolute positions cancel.

Full $d$-dimensional version

Split the $d$-dim vector into $d/2$ pairs. Pair $i$ rotates at its own frequency:

$$\theta_i = b^{-2i/d}, \qquad i = 0,1,\dots,\tfrac{d}{2}-1, \quad (\text{base } b = 10000)$$

Low-index pairs rotate fast (capture local/relative position), high-index pairs rotate slowly (capture long-range) — analogous to the multi-frequency idea in sinusoidal encodings, but applied multiplicatively to Q and K rather than added to the embedding.

The complex-number view (elegant): identify each pair with a complex number $q^{(1)}+iq^{(2)}$; rotation by $m\theta_i$ is multiplication by $e^{im\theta_i}$. Then $\langle \tilde q_m, \tilde k_n\rangle = \text{Re}\sum_i (q_i \bar k_i)\,e^{i(m-n)\theta_i}$, manifestly a function of $m-n$. (In the exponent, the leading $i$ is the imaginary unit; the subscript $i$ is the pair index.)
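
The complex-number view translates almost directly into code. The sketch below uses only the standard library and assumes adjacent dimensions form a pair, i.e. (0,1), (2,3), and so on; `rope_complex` and `score` are illustrative names, and the vectors are made-up numbers.

```python
import cmath

def rope_complex(vec, pos, base=10000.0):
    """Treat (vec[0], vec[1]), (vec[2], vec[3]), ... as complex numbers and rotate each."""
    d = len(vec)
    rotated = []
    for i in range(d // 2):
        theta_i = base ** (-2 * i / d)               # frequency of pair i
        z = complex(vec[2 * i], vec[2 * i + 1])
        rotated.append(z * cmath.exp(1j * pos * theta_i))
    return rotated

def score(q, k, m, n):
    """Real dot product of the rotated vectors: Re sum_i q_i * conj(k_i)."""
    qm, kn = rope_complex(q, m), rope_complex(k, n)
    return sum((a * b.conjugate()).real for a, b in zip(qm, kn))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
s_a = score(q, k, m=5, n=2)        # offset 3
s_b = score(q, k, m=105, n=102)    # same offset, shifted by 100 positions
print(abs(s_a - s_b) < 1e-9)       # True
```

Shifting both positions by the same amount leaves the score unchanged, while changing the offset (say `n=3` instead of `n=2`) changes it.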

Worked micro-example

$d=2$, one pair, $\theta=1$, query $\mathbf{q}=[1,0]$ at position $m=2$:

$$\tilde{\mathbf{q}} = \mathbf{R}(2)\,[1,0]^\top = [\cos 2,\ \sin 2] = [-0.416,\ 0.909]$$

A key $\mathbf{k}=[1,0]$ at position $n=2$ rotates the same way, so their dot product equals $\mathbf{q}^\top\mathbf{R}(0)\mathbf{k}=1$ (offset 0). At $n=3$ the offset is 1, giving $\cos(1)=0.540$: the score now reflects a distance of one position, regardless of where the pair sits absolutely.
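
The numbers in this micro-example are easy to reproduce. A minimal standard-library sketch (the helpers `rotate` and `dot` are illustrative names):

```python
import math

def rotate(vec, angle):
    """Apply the 2D rotation matrix R(angle) to a 2-vector."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

theta = 1.0
q = k = (1.0, 0.0)

q2 = rotate(q, 2 * theta)                               # query at position m = 2
same_pos = dot(q2, rotate(k, 2 * theta))                # key at n = 2, offset 0
next_pos = dot(q2, rotate(k, 3 * theta))                # key at n = 3, offset 1
shifted = dot(rotate(q, 12 * theta), rotate(k, 13 * theta))  # offset 1, ten positions later

print([round(c, 3) for c in q2])   # [-0.416, 0.909]
print(round(same_pos, 3))          # 1.0
print(round(next_pos, 3))          # 0.54
print(round(shifted, 3))           # 0.54
```

The last two lines agree because only the offset enters the score; moving the pair from positions (2, 3) to (12, 13) changes nothing.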

Implementation (the standard "rotate-half" trick)

python
import torch

def build_rope_cache(seq_len, dim, base=10000.0, device="cpu"):
    # frequencies θ_i for each pair
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)            # (seq_len, dim/2): m * θ_i
    emb = torch.cat([freqs, freqs], dim=-1)     # duplicate for the rotate-half layout
    return emb.cos(), emb.sin()                 # each (seq_len, dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)         # implements the 2D rotation per pair

def apply_rope(q, k, cos, sin):                 # q,k: (B, n, d); cos,sin: (n, d)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot

# usage inside attention, BEFORE computing Q·K^T:
# cos, sin = build_rope_cache(n, d_head)
# q, k = apply_rope(q, k, cos, sin)
# scores = q @ k.transpose(-2,-1) / sqrt(d_head)

Notice RoPE is applied to Q and K only (not V), right before the score computation — it shapes where tokens attend, not what content flows.
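
One detail worth checking is that the rotate-half layout pairs dimension $i$ with dimension $i + d/2$ rather than with its neighbour, yet the relative-position property still holds. The sketch below is a NumPy port of the functions above, written only for this check; the `_np` names are illustrative, not a library API.

```python
import numpy as np

def rope_cache_np(seq_len, dim, base=10000.0):
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    freqs = np.outer(np.arange(seq_len), inv_freq)       # (seq_len, dim/2): m * theta_i
    emb = np.concatenate([freqs, freqs], axis=-1)        # duplicate for rotate-half layout
    return np.cos(emb), np.sin(emb)

def rotate_half_np(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope_np(x, cos, sin):
    return x * cos + rotate_half_np(x) * sin

rng = np.random.default_rng(0)
dim, seq_len = 8, 32
q_vec, k_vec = rng.normal(size=dim), rng.normal(size=dim)
cos, sin = rope_cache_np(seq_len, dim)

# Place the same query/key content at every position, then rotate by position.
q = apply_rope_np(np.tile(q_vec, (seq_len, 1)), cos, sin)
k = apply_rope_np(np.tile(k_vec, (seq_len, 1)), cos, sin)
scores = q @ k.T                                         # scores[m, n]

# Same offset m - n = 3 at different absolute positions gives the same score.
print(np.allclose(scores[5, 2], scores[20, 17]))
# Rotation preserves vector norms, so content magnitude is untouched.
print(np.allclose(np.linalg.norm(q, axis=1), np.linalg.norm(q_vec)))
```

With identical content at every position, each diagonal of `scores` is constant, which is the relative-position property in matrix form.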

Why it generalizes to longer contexts

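Because positions enter only as relative rotations, the rotation angles can be rescaled after training to extend the context window. Without such rescaling, quality typically degrades at offsets beyond the training length $L$; techniques like linear position interpolation, NTK-aware scaling, and YaRN rescale the positions or the $\theta_i$ to extend context (e.g. 4k → 32k) with little or no fine-tuning.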


12.3 A trainable LSTM from scratch — full BPTT in NumPy

We now derive every gradient of an LSTM cell and implement a working trainer in pure NumPy. This is the capstone of [[03_rnn_lstm]] + [[01_deep_learning_foundations]].

Forward (single cell, recap)

At step $t$, let $\mathbf{z}_t = [\mathbf{h}_{t-1};\,\mathbf{x}_t]$ (concatenation). With weight blocks $\mathbf{W}_f,\mathbf{W}_i,\mathbf{W}_c,\mathbf{W}_o$ and biases:

$$\begin{aligned} \mathbf{f}_t &= \sigma(\mathbf{W}_f \mathbf{z}_t + \mathbf{b}_f) & \mathbf{i}_t &= \sigma(\mathbf{W}_i \mathbf{z}_t + \mathbf{b}_i) \\ \tilde{\mathbf{c}}_t &= \tanh(\mathbf{W}_c \mathbf{z}_t + \mathbf{b}_c) & \mathbf{o}_t &= \sigma(\mathbf{W}_o \mathbf{z}_t + \mathbf{b}_o) \\ \mathbf{c}_t &= \mathbf{f}_t\odot\mathbf{c}_{t-1} + \mathbf{i}_t\odot\tilde{\mathbf{c}}_t & \mathbf{h}_t &= \mathbf{o}_t\odot\tanh(\mathbf{c}_t) \end{aligned}$$

Backward — derive each gradient

Let the total gradient arriving at the hidden state (from the layer above plus, for $t<T$, the recurrent path from step $t+1$) be $\mathrm{d}\mathbf{h}_t$, and the gradient flowing back through time into the cell state be $\mathrm{d}\mathbf{c}_t^{\text{next}}$ (from step $t+1$). Define helpers $\sigma'(\text{gate}) = \text{gate}\odot(1-\text{gate})$ and $\tanh'(u)=1-\tanh^2(u)$.

Through the output equation $\mathbf{h}_t=\mathbf{o}_t\odot\tanh(\mathbf{c}_t)$:

$$\mathrm{d}\mathbf{o}_t = \mathrm{d}\mathbf{h}_t \odot \tanh(\mathbf{c}_t)$$

$$\mathrm{d}\mathbf{c}_t = \mathrm{d}\mathbf{h}_t \odot \mathbf{o}_t \odot \big(1-\tanh^2(\mathbf{c}_t)\big) \;+\; \mathrm{d}\mathbf{c}_t^{\text{next}}$$

(the cell-state gradient gets contributions both from $\mathbf{h}_t$ and from the future step; this additive path is why LSTM gradients vanish far less readily than a vanilla RNN's, [[03_rnn_lstm]] §3.3).

Through the cell update $\mathbf{c}_t=\mathbf{f}_t\odot\mathbf{c}_{t-1}+\mathbf{i}_t\odot\tilde{\mathbf{c}}_t$:

$$\mathrm{d}\mathbf{f}_t = \mathrm{d}\mathbf{c}_t\odot\mathbf{c}_{t-1},\quad \mathrm{d}\mathbf{i}_t = \mathrm{d}\mathbf{c}_t\odot\tilde{\mathbf{c}}_t,\quad \mathrm{d}\tilde{\mathbf{c}}_t = \mathrm{d}\mathbf{c}_t\odot\mathbf{i}_t$$

$$\mathrm{d}\mathbf{c}_{t-1}^{\text{next}} = \mathrm{d}\mathbf{c}_t\odot\mathbf{f}_t \quad(\text{passed to step } t-1)$$

Through the gate nonlinearities (pre-activation gradients):

$$\mathrm{d}\mathbf{a}_f = \mathrm{d}\mathbf{f}_t\odot\mathbf{f}_t\odot(1-\mathbf{f}_t),\; \mathrm{d}\mathbf{a}_i = \mathrm{d}\mathbf{i}_t\odot\mathbf{i}_t\odot(1-\mathbf{i}_t),$$

$$\mathrm{d}\mathbf{a}_o = \mathrm{d}\mathbf{o}_t\odot\mathbf{o}_t\odot(1-\mathbf{o}_t),\; \mathrm{d}\mathbf{a}_c = \mathrm{d}\tilde{\mathbf{c}}_t\odot(1-\tilde{\mathbf{c}}_t^2)$$

Weight gradients (accumulate over all time steps, like the shared-weight rule in [[03_rnn_lstm]] §3.2). With $\mathbf{z}_t=[\mathbf{h}_{t-1};\mathbf{x}_t]$ and gate pre-activation gradient $\mathrm{d}\mathbf{a}_\bullet$:

$$\frac{\partial L}{\partial \mathbf{W}_\bullet} \mathrel{+}= \mathrm{d}\mathbf{a}_\bullet\,\mathbf{z}_t^\top,\qquad \frac{\partial L}{\partial \mathbf{b}_\bullet}\mathrel{+}= \mathrm{d}\mathbf{a}_\bullet$$

Gradient to the previous hidden state (the first $H$ entries of $\mathbf{W}_\bullet^\top \mathrm{d}\mathbf{a}_\bullet$, summed over gates; these are the entries that correspond to $\mathbf{h}_{t-1}$ inside $\mathbf{z}_t$):

$$\mathrm{d}\mathbf{h}_{t-1} = \Big[\sum_{\bullet\in\{f,i,c,o\}} \mathbf{W}_\bullet^\top\, \mathrm{d}\mathbf{a}_\bullet\Big]_{1:H}$$
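
Hand-derived gradients are easy to get subtly wrong, so it can help to compare them against finite differences before trusting a trainer. The sketch below checks a single LSTM step. The surrogate loss $L = \mathrm{d}\mathbf{h}_t\cdot\mathbf{h}_t + \mathrm{d}\mathbf{c}_t^{\text{next}}\cdot\mathbf{c}_t$ is an assumption chosen so that the upstream gradients are exactly the fixed vectors `dh` and `dc_next`. The function names and the dict-of-gates layout are illustrative only.

```python
import numpy as np

def sig(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(W, b, h_prev, c_prev, x):
    """W: dict of (H, H+D) matrices keyed 'f','i','c','o'; b: dict of (H,) biases."""
    z = np.concatenate([h_prev, x])
    f = sig(W['f'] @ z + b['f'])
    i = sig(W['i'] @ z + b['i'])
    g = np.tanh(W['c'] @ z + b['c'])        # candidate cell state
    o = sig(W['o'] @ z + b['o'])
    c = f * c_prev + i * g
    h = o * np.tanh(c)
    return h, c, (z, f, i, g, o, c)

def lstm_step_backward(W, cache, c_prev, dh, dc_next):
    """Analytic gradients, following the derivation line by line."""
    z, f, i, g, o, c = cache
    H = c_prev.shape[0]
    do = dh * np.tanh(c)
    dc = dh * o * (1 - np.tanh(c) ** 2) + dc_next
    da = {                                   # pre-activation gradients per gate
        'f': dc * c_prev * f * (1 - f),
        'i': dc * g * i * (1 - i),
        'c': dc * i * (1 - g ** 2),
        'o': do * o * (1 - o),
    }
    dW = {key: np.outer(da[key], z) for key in da}
    dz = sum(W[key].T @ da[key] for key in da)
    return dW, dz[:H], dc * f                # weight grads, dh_prev, dc_prev

rng = np.random.default_rng(0)
H, D = 3, 2
W = {key: rng.normal(0, 0.5, (H, H + D)) for key in 'fico'}
b = {key: rng.normal(0, 0.5, H) for key in 'fico'}
h_prev, c_prev, x = rng.normal(size=H), rng.normal(size=H), rng.normal(size=D)
dh, dc_next = rng.normal(size=H), rng.normal(size=H)

def loss():
    h, c, _ = lstm_step(W, b, h_prev, c_prev, x)
    return dh @ h + dc_next @ c

_, _, cache = lstm_step(W, b, h_prev, c_prev, x)
dW, dh_prev, dc_prev = lstm_step_backward(W, cache, c_prev, dh, dc_next)

eps = 1e-6
max_err = 0.0
for key in 'fico':                           # central differences on every weight
    for idx in np.ndindex(W[key].shape):
        old = W[key][idx]
        W[key][idx] = old + eps
        up = loss()
        W[key][idx] = old - eps
        down = loss()
        W[key][idx] = old
        max_err = max(max_err, abs((up - down) / (2 * eps) - dW[key][idx]))

num_dh_prev = np.zeros(H)
for j in range(H):                           # same check for the previous hidden state
    old = h_prev[j]
    h_prev[j] = old + eps
    up = loss()
    h_prev[j] = old - eps
    down = loss()
    h_prev[j] = old
    num_dh_prev[j] = (up - down) / (2 * eps)

print(max_err < 1e-7, np.allclose(num_dh_prev, dh_prev, atol=1e-7))
```

If a formula were wrong (for example, using $\mathbf{c}_t$ where $\mathbf{c}_{t-1}$ belongs in $\mathrm{d}\mathbf{f}_t$), the mismatch would show up as an error many orders of magnitude above the tolerance.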

Full runnable implementation (learns to remember the first bit)

Task: read a length-$T$ binary sequence; output the first input bit at the last step. Solving it requires long-term memory across the whole sequence, which makes it a clean demo of the cell state working.

python
import numpy as np
rng = np.random.default_rng(0)
def sig(x):  return 1/(1+np.exp(-x))

H, D, T = 16, 1, 8          # hidden size, input dim, sequence length
Z = H + D
# one weight matrix per gate: (H, H+D)
def init(): return rng.normal(0, 0.1, (H, Z))
Wf, Wi, Wc, Wo = init(), init(), init(), init()
bf = np.ones(H); bi = np.zeros(H); bc = np.zeros(H); bo = np.zeros(H)  # forget bias=1 → remember by default
Wy = rng.normal(0, 0.1, (1, H)); by = np.zeros(1)
params = ['Wf','Wi','Wc','Wo','bf','bi','bc','bo','Wy','by']
lr = 0.1

def forward(xs):
    h, c = np.zeros(H), np.zeros(H); cache = []
    for x in xs:                       # x: (D,)
        z = np.concatenate([h, x])     # (H+D,)
        f, i = sig(Wf@z+bf), sig(Wi@z+bi)
        g, o = np.tanh(Wc@z+bc), sig(Wo@z+bo)
        c = f*c + i*g
        h = o*np.tanh(c)
        cache.append((z, f, i, g, o, c, h))
    y = Wy@h + by                      # final-step readout
    return sig(y), cache               # prediction in (0,1)

def train_step(xs, target):
    p, cache = forward(xs)
    loss = -(target*np.log(p+1e-9) + (1-target)*np.log(1-p+1e-9))  # BCE
    grads = {k: np.zeros_like(globals()[k]) for k in params}

    # output layer (BCE+sigmoid → clean p-target, see ch.1)
    dy = (p - target)                  # (1,)
    h_last = cache[-1][6]
    grads['Wy'] += np.outer(dy, h_last); grads['by'] += dy
    dh = Wy.T @ dy                     # (H,)
    dc_next = np.zeros(H)

    for t in reversed(range(T)):       # BPTT
        z, f, i, g, o, c, h = cache[t]
        c_prev = cache[t-1][5] if t > 0 else np.zeros(H)
        do = dh * np.tanh(c)
        dc = dh * o * (1-np.tanh(c)**2) + dc_next
        df, di, dg = dc*c_prev, dc*g, dc*i
        da_f = df * f*(1-f); da_i = di * i*(1-i)
        da_o = do * o*(1-o); da_c = dg * (1-g**2)
        for da, Wk, bk in [(da_f,'Wf','bf'),(da_i,'Wi','bi'),
                           (da_c,'Wc','bc'),(da_o,'Wo','bo')]:
            grads[Wk] += np.outer(da, z); grads[bk] += da
        dz = (Wf.T@da_f + Wi.T@da_i + Wc.T@da_c + Wo.T@da_o)
        dh = dz[:H]                    # gradient to previous hidden state
        dc_next = dc * f               # gradient to previous cell state
    # clip + SGD update
    for k in params:
        g_ = np.clip(grads[k], -5, 5)
        globals()[k] -= lr * g_
    return float(loss[0]), float(p[0])   # index first: float() on a 1-element array is deprecated in recent NumPy

# ---- train ----
for step in range(3000):
    xs = rng.integers(0, 2, size=(T, D)).astype(float)
    target = xs[0, 0]                  # remember the FIRST bit
    loss, p = train_step(xs, target)
    if step % 500 == 0:
        print(f"step {step:4d}  loss {loss:.4f}")

# ---- test ----
correct = 0
for _ in range(200):
    xs = rng.integers(0, 2, size=(T, D)).astype(float)
    p, _ = forward(xs)
    correct += int((p[0] > 0.5) == bool(xs[0,0]))
print("accuracy:", correct/200)        # → ~1.0; the LSTM learned to carry bit 0 across 8 steps

Every line of the backward pass corresponds to a derivative we derived above. Swap the LSTM equations for the vanilla RNN recurrence and the same task fails as $T$ grows: a hands-on demonstration of vanishing gradients vs the gated fix.


12.4 Bonus: the KV-cache (why generation is fast)

In autoregressive decoding ([[05_architectures]] §5.5), generating token $t$ recomputes attention over all previous tokens. Naively that re-projects every past token's $\mathbf{K},\mathbf{V}$ each step, which adds up to $O(n^2)$ redundant token projections over a length-$n$ generation.

KV-cache: store each layer's keys and values for past tokens; at step $t$ only compute the new token's $\mathbf{q}_t,\mathbf{k}_t,\mathbf{v}_t$, append $\mathbf{k}_t,\mathbf{v}_t$ to the cache, and attend $\mathbf{q}_t$ against the cached $\mathbf{K}_{1:t},\mathbf{V}_{1:t}$. Per-step projection cost drops from $O(t\cdot d^2)$ (re-projecting all $t$ tokens) to $O(d^2)$ (the new token only); what remains is the $O(t\cdot d)$ attention read over the cache.

  • Memory cost: cache size $= 2 \cdot n_{\text{layers}} \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot \text{seq\_len} \cdot \text{batch}$ elements (the factor 2 covers keys and values; multiply by bytes per element to get bytes). It grows linearly with context and often dominates LLM serving memory at long contexts or large batches.
  • Optimizations: Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) share K/V across query heads to shrink the cache, replacing $n_{\text{heads}}$ in the formula with the smaller number of K/V heads (GQA is used in LLaMA-2 70B, LLaMA-3, and Mistral); PagedAttention (vLLM) manages it like virtual memory to avoid fragmentation.
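
To get a feel for the numbers, the memory formula can be wrapped in a small helper. The configuration below is a hypothetical 32-layer model chosen for round numbers, not the specification of any particular released model; `kv_cache_bytes` is an illustrative name, and 2 bytes per element assumes fp16/bf16 storage.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch, bytes_per_elem=2):
    """Cache size in bytes; the leading 2 counts one K and one V tensor per layer."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem

GIB = 1024 ** 3

# Hypothetical config: every query head has its own K/V head.
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096, batch=1)
# Same model with grouped-query attention: 8 K/V heads shared by 32 query heads.
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=4096, batch=1)

print(mha / GIB, gqa / GIB)   # 2.0 0.5
```

Sharing K/V heads four ways cuts the cache by the same factor, and doubling either `seq_len` or `batch` doubles it.
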
python
# conceptual decode loop with cache
k_cache, v_cache = [], []
for step in range(max_new):
    q, k, v = project(current_token)        # only the new token
    k_cache.append(k); v_cache.append(v)    # grow the cache
    attn = softmax(q @ stack(k_cache).T / sqrt(d)) @ stack(v_cache)
    next_token = sample(head(attn))
    current_token = next_token
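
The conceptual loop leaves `project`, `softmax`, `stack`, `sample` and `head` undefined. A runnable single-head, single-layer NumPy sketch follows, with random stand-in embeddings instead of sampled tokens (an assumption made to keep the check deterministic). It confirms that attending against the cache gives the same output as re-projecting the whole prefix at every step.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
d, steps = 8, 6
Wq, Wk, Wv = (rng.normal(0, 0.3, (d, d)) for _ in range(3))
tokens = rng.normal(size=(steps, d))         # stand-in token embeddings

# Cached decoding: project only the newest token at each step.
k_cache, v_cache, cached_out = [], [], []
for x in tokens:
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    k_cache.append(k)
    v_cache.append(v)
    K, V = np.stack(k_cache), np.stack(v_cache)          # (t, d)
    cached_out.append(softmax(K @ q / np.sqrt(d)) @ V)

# Reference: re-project the entire prefix at every step (the wasteful way).
full_out = []
for t in range(1, steps + 1):
    prefix = tokens[:t]
    K, V = prefix @ Wk, prefix @ Wv
    q = tokens[t - 1] @ Wq
    full_out.append(softmax(K @ q / np.sqrt(d)) @ V)

print(np.allclose(cached_out, full_out))
```

No causal mask is needed here because the newest query only ever sees keys up to its own position.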

Return to the index. Practice problems for these and all modules are in [[11_exercises]].