Module 03 · 6 min read

RNNs & LSTMs

Sequence modeling, BPTT, vanishing/exploding gradients, LSTM & GRU gate math, worked example.

Sequences (text, audio, time series) have order and variable length. RNNs process them one step at a time, carrying a memory (hidden state) forward. This chapter derives the recurrence, the famous vanishing-gradient problem, and how gates (LSTM/GRU) fix it.


3.1 The vanilla RNN

Intuition

Read tokens one at a time. Keep a running summary $\mathbf{h}_t$ ("what I've seen so far"). Update it at each step from the previous summary and the new input.

Math

At time step $t$, with input $\mathbf{x}_t\in\mathbb{R}^{d}$ and hidden state $\mathbf{h}_t\in\mathbb{R}^{H}$:

$$\mathbf{h}_t = \tanh(\mathbf{W}_{xh}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}_h)$$

$$\hat{\mathbf{y}}_t = \mathbf{W}_{hy}\mathbf{h}_t + \mathbf{b}_y \quad(\text{+ softmax for classification})$$

Crucially, the same weights $\mathbf{W}_{xh}, \mathbf{W}_{hh}, \mathbf{W}_{hy}$ are reused at every time step (weight sharing across time, analogous to CNN weight sharing across space). $\mathbf{h}_0$ is usually initialized to zeros.

Unrolling

An RNN over $T$ steps is equivalent to a $T$-layer feedforward net where every layer shares weights:

code
h0 → [RNN] → h1 → [RNN] → h2 → ... → [RNN] → hT
       ↑            ↑                  ↑
       x1           x2                 xT

Tiny numeric example

$H=1$, $\mathbf{W}_{xh}=0.5,\ \mathbf{W}_{hh}=0.9,\ b_h=0$, $\tanh$ activation, inputs $x_1=1, x_2=1$, $h_0=0$.

  • $h_1 = \tanh(0.5\cdot1 + 0.9\cdot0) = \tanh(0.5) = 0.4621$
  • $h_2 = \tanh(0.5\cdot1 + 0.9\cdot0.4621) = \tanh(0.9159) = 0.7240$

The state accumulates history: $h_2$ depends on both $x_1$ (through $h_1$) and $x_2$.
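The two-step example can be replayed in a few lines of plain Python. This is only a sketch of the scalar case; the helper name `rnn_step` and its default arguments are chosen here for illustration.

```python
import math

def rnn_step(x, h_prev, w_xh=0.5, w_hh=0.9, b_h=0.0):
    """One scalar vanilla-RNN step: h_t = tanh(w_xh*x_t + w_hh*h_{t-1} + b_h)."""
    return math.tanh(w_xh * x + w_hh * h_prev + b_h)

h = 0.0                      # h_0
states = []
for x in [1.0, 1.0]:         # x_1, x_2
    h = rnn_step(x, h)       # the new state depends on the old state and the new input
    states.append(h)

print([round(s, 4) for s in states])   # h_1, h_2
```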


3.2 Backpropagation Through Time (BPTT)

To train, unroll and apply backprop across all time steps. The total loss is the sum over steps:

$$L = \sum_{t=1}^{T} L_t(\hat{\mathbf{y}}_t, \mathbf{y}_t)$$

The gradient w.r.t. the shared recurrent weight $\mathbf{W}_{hh}$ accumulates contributions from every time step:

$$\frac{\partial L}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^{T}\sum_{k=1}^{t} \frac{\partial L_t}{\partial \mathbf{h}_t}\left(\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}\right)\frac{\partial \mathbf{h}_k}{\partial \mathbf{W}_{hh}}$$

The dangerous term is the product of Jacobians:

$$\frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}} = \text{diag}\!\big(\tanh'(\cdot)\big)\,\mathbf{W}_{hh}$$
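To make the double sum concrete, here is a minimal sketch of BPTT for a scalar RNN, checked against a finite-difference gradient. The squared-error loss, the input/target values, and the function names are assumptions made for this illustration; the chapter does not fix a particular loss.

```python
import math

def forward(w_hh, xs, w_xh=0.5, h0=0.0):
    """Unroll the scalar RNN; hs[t] is h_t and hs[0] is h_0."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w_xh * x + w_hh * hs[-1]))
    return hs

def loss(hs, ys):
    """Assumed per-step loss: L_t = 0.5 * (h_t - y_t)^2, summed over time."""
    return sum(0.5 * (h - y) ** 2 for h, y in zip(hs[1:], ys))

def bptt_grad_w_hh(w_hh, xs, ys):
    hs = forward(w_hh, xs)
    grad, dh_next = 0.0, 0.0
    for t in range(len(xs), 0, -1):           # walk backward through time
        dh = (hs[t] - ys[t - 1]) + dh_next    # dL_t/dh_t plus gradient from later steps
        da = dh * (1.0 - hs[t] ** 2)          # through tanh: tanh'(a_t) = 1 - h_t^2
        grad += da * hs[t - 1]                # immediate dh_t/dw_hh contribution
        dh_next = da * w_hh                   # one Jacobian factor dh_t/dh_{t-1}
    return grad

xs = [1.0, -0.5, 0.3, 0.8]
ys = [0.2, 0.0, -0.1, 0.5]
w = 0.9
analytic = bptt_grad_w_hh(w, xs, ys)
eps = 1e-6
numeric = (loss(forward(w + eps, xs), ys) - loss(forward(w - eps, xs), ys)) / (2 * eps)
print(analytic, numeric)   # the two values should agree closely
```

The backward loop carries `dh_next` from step to step, which is exactly where the repeated Jacobian factor enters.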

Vanishing & exploding gradients

That product has up to $\sim T$ factors. Roughly, its magnitude scales like $\|\mathbf{W}_{hh}\|^{\,t-k} \cdot \prod \tanh'$:

  • If the dominant eigenvalue (more precisely, the largest singular value) of $\mathbf{W}_{hh}$, times $\tanh'\le1$, is < 1 → product $\to 0$ → vanishing gradient: the network can't learn long-range dependencies (late losses send almost no gradient back to early inputs).
  • If it is > 1 and the units are not saturated → the product can blow up → exploding gradient → NaNs.
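The two regimes can be seen numerically. The sketch below makes two simplifying assumptions that are not from the text: $\mathbf{W}_{hh}$ is built as a scaled orthogonal matrix, so all its singular values equal `scale`, and $\tanh'$ is taken as 1 (unsaturated units), which is the most favourable case for gradient flow.

```python
import numpy as np

def jacobian_product_norm(scale, steps=50, hidden=8, seed=0):
    """Spectral norm of a product of `steps` identical Jacobians W_hh."""
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((hidden, hidden)))  # orthogonal matrix
    w_hh = scale * q                    # every singular value equals `scale`
    prod = np.eye(hidden)
    for _ in range(steps):
        prod = w_hh @ prod              # diag(tanh') is taken as the identity here
    return np.linalg.norm(prod, 2)

small = jacobian_product_norm(0.9)      # shrinks like 0.9**50
large = jacobian_product_norm(1.1)      # grows like 1.1**50
print(small, large)
```

With real $\tanh$ units the factors $\tanh'\le 1$ only shrink the product further, so vanishing is the more common failure in practice.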

Mitigations:

  • Gradient clipping for explosion: if $\|\mathbf{g}\| > \tau$, rescale $\mathbf{g} \leftarrow \tau\,\mathbf{g}/\|\mathbf{g}\|$.
  • Truncated BPTT: only backprop $k$ steps back (e.g. 35), which is cheaper and limits the product length.
  • Gated architectures (LSTM/GRU) for vanishing — the real fix, below.
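As an illustration of the clipping rule in the first bullet, here is a small NumPy sketch that rescales a list of gradient arrays by their joint norm. The function name `clip_by_global_norm` and the toy gradients are made up for this example.

```python
import numpy as np

def clip_by_global_norm(grads, tau):
    """Rescale a list of gradient arrays so their joint L2 norm is at most tau."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total > tau:
        grads = [g * (tau / total) for g in grads]   # same direction, shorter length
    return grads, total

grads = [np.array([3.0, 4.0]), np.array([12.0])]     # joint norm = sqrt(9+16+144) = 13
clipped, norm_before = clip_by_global_norm(grads, tau=5.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
print(norm_before, norm_after)
```

Clipping changes only the length of the update, not its direction, which is why it tames explosions without otherwise biasing the step.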

3.3 LSTM (Long Short-Term Memory)

Intuition

Add a separate cell state $\mathbf{c}_t$ that acts like a conveyor belt of memory, modified only by gates that decide what to forget, what to add, and what to output. Because the cell state is updated additively (not by repeated matrix multiply), gradients can flow through it across many steps without vanishing, as long as the forget gate stays close to 1.

The gate equations (memorize the structure, not the symbols)

At each step, with input $\mathbf{x}_t$ and previous $\mathbf{h}_{t-1}$:

$$\begin{aligned} \mathbf{f}_t &= \sigma(\mathbf{W}_f[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) && \text{forget gate (what to erase from } \mathbf{c}) \\ \mathbf{i}_t &= \sigma(\mathbf{W}_i[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) && \text{input gate (how much new info to write)} \\ \tilde{\mathbf{c}}_t &= \tanh(\mathbf{W}_c[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c) && \text{candidate cell (the new info)} \\ \mathbf{c}_t &= \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t && \text{update cell state (additive!)} \\ \mathbf{o}_t &= \sigma(\mathbf{W}_o[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o) && \text{output gate} \\ \mathbf{h}_t &= \mathbf{o}_t \odot \tanh(\mathbf{c}_t) && \text{hidden state / output} \end{aligned}$$

$[\mathbf{h}_{t-1}, \mathbf{x}_t]$ denotes concatenation. Each gate is a sigmoid output in $(0,1)$ acting as a soft switch (0 = block, 1 = pass).
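The six equations map almost line for line onto NumPy. The following is a minimal sketch, not a tuned implementation; the dictionary layout for the weights (`W["f"]`, `W["i"]`, `W["c"]`, `W["o"]`), the sizes, and the random initialisation are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step. W[k] has shape (H, H+d) and b[k] shape (H,) for k in 'f','i','c','o'."""
    hx = np.concatenate([h_prev, x])            # [h_{t-1}, x_t]
    f = sigmoid(W["f"] @ hx + b["f"])           # forget gate
    i = sigmoid(W["i"] @ hx + b["i"])           # input gate
    c_tilde = np.tanh(W["c"] @ hx + b["c"])     # candidate cell
    c = f * c_prev + i * c_tilde                # additive cell update
    o = sigmoid(W["o"] @ hx + b["o"])           # output gate
    h = o * np.tanh(c)                          # hidden state / output
    return h, c

H, d = 4, 3
rng = np.random.default_rng(0)
W = {k: 0.1 * rng.standard_normal((H, H + d)) for k in "fico"}
b = {k: np.zeros(H) for k in "fico"}
h, c = np.zeros(H), np.zeros(H)
for x in rng.standard_normal((5, d)):           # a 5-step random input sequence
    h, c = lstm_step(x, h, c, W, b)
print(h.shape, c.shape)
```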

Why gradients survive

The cell-state recurrence is $\mathbf{c}_t = \mathbf{f}_t\odot\mathbf{c}_{t-1} + \dots$, so along the direct cell-to-cell path (treating the gate values as constants)

$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \text{diag}(\mathbf{f}_t)$$

When the forget gate $\mathbf{f}_t \approx 1$, this Jacobian is $\approx$ identity → the product over time stays $\approx 1$ → no vanishing. The network learns when to remember (keep $\mathbf{f}\to1$) vs forget (drive $\mathbf{f}\to0$). This is the "constant error carousel."
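A quick numeric check of this claim for a one-dimensional cell: the sensitivity of the final cell state to the initial one equals the product of the forget gates. The sketch assumes the gate values are fixed constants (in a real LSTM they depend on the inputs and hidden state), and the specific numbers are illustrative.

```python
import math

def run_cell(c0, f_gates, writes):
    """Scalar cell recurrence c_t = f_t * c_{t-1} + w_t, where w_t stands for i_t * c~_t."""
    c = c0
    for f, w in zip(f_gates, writes):
        c = f * c + w
    return c

steps = 100
writes = [0.01] * steps
eps = 1e-3
for f in (0.5, 0.99):
    f_gates = [f] * steps
    # finite-difference estimate of dc_T / dc_0
    numeric = (run_cell(1.0 + eps, f_gates, writes) - run_cell(1.0, f_gates, writes)) / eps
    analytic = math.prod(f_gates)               # product of diag(f_t) entries
    print(f, numeric, analytic)
```

With $f=0.99$ roughly a third of the gradient still arrives after 100 steps, while with $f=0.5$ it is numerically zero.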

Worked single-step (1-dim, conceptual numbers)

Say $c_{t-1}=0.8$. Gates compute $f_t=0.9$ (mostly remember), $i_t=0.5$, $\tilde c_t=0.6$, $o_t=0.7$.

  • $c_t = 0.9\cdot0.8 + 0.5\cdot0.6 = 0.72 + 0.30 = 1.02$
  • $h_t = 0.7\cdot\tanh(1.02) = 0.7\cdot0.7699 = 0.539$

The cell kept 90% of old memory and wrote half of the new candidate.
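The same arithmetic in plain Python, as a sanity check on the numbers above (variable names are chosen here for readability):

```python
import math

c_prev, f, i, c_tilde, o = 0.8, 0.9, 0.5, 0.6, 0.7

c = f * c_prev + i * c_tilde      # keep 90% of the old memory, write half the candidate
h = o * math.tanh(c)              # the output gate scales the squashed cell state

print(round(c, 2), round(h, 3))
```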

Parameter count

Four dense layers (three gates plus the candidate), each mapping $[\mathbf{h};\mathbf{x}]$ of size $(H+d)\to H$:

$$\#\text{params} = 4\big[(H+d)\cdot H + H\big]$$
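A small helper makes the count easy to evaluate. One caveat when comparing against a framework: PyTorch's `nn.LSTM` stores two bias vectors per layer (`bias_ih` and `bias_hh`), so its per-layer count has $2H$ in place of $H$. The function name and the `biases_per_gate` argument are illustrative, not part of any library.

```python
def lstm_param_count(hidden, input_dim, biases_per_gate=1):
    """Parameters of one unidirectional LSTM layer: 4 * [(H + d) * H + biases * H]."""
    return 4 * ((hidden + input_dim) * hidden + biases_per_gate * hidden)

single_bias = lstm_param_count(256, 128)                     # formula above
double_bias = lstm_param_count(256, 128, biases_per_gate=2)  # two-bias layout
print(single_bias)   # 394240
print(double_bias)   # 395264
```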

3.4 GRU (Gated Recurrent Unit) — a lighter LSTM

Merges cell and hidden state, uses 2 gates instead of 3. Fewer parameters, often comparable performance.

$$\begin{aligned} \mathbf{z}_t &= \sigma(\mathbf{W}_z[\mathbf{h}_{t-1}, \mathbf{x}_t]) && \text{update gate} \\ \mathbf{r}_t &= \sigma(\mathbf{W}_r[\mathbf{h}_{t-1}, \mathbf{x}_t]) && \text{reset gate} \\ \tilde{\mathbf{h}}_t &= \tanh(\mathbf{W}_h[\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t]) && \text{candidate} \\ \mathbf{h}_t &= (1-\mathbf{z}_t)\odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t && \text{interpolate old/new} \end{aligned}$$

The update gate $\mathbf{z}_t$ interpolates between keeping the old state and adopting the new candidate, giving the same additive-memory benefit as the LSTM.
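For comparison with the LSTM, a GRU step following the four equations above can be sketched in NumPy as below. Biases are omitted, as in the equations; the sizes and random weights are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x, h_prev, Wz, Wr, Wh):
    """One GRU step; each weight matrix has shape (H, H+d)."""
    hx = np.concatenate([h_prev, x])                          # [h_{t-1}, x_t]
    z = sigmoid(Wz @ hx)                                      # update gate
    r = sigmoid(Wr @ hx)                                      # reset gate
    h_tilde = np.tanh(Wh @ np.concatenate([r * h_prev, x]))   # candidate
    return (1.0 - z) * h_prev + z * h_tilde                   # interpolate old/new

H, d = 4, 3
rng = np.random.default_rng(0)
Wz, Wr, Wh = (0.1 * rng.standard_normal((H, H + d)) for _ in range(3))
h = np.zeros(H)
for x in rng.standard_normal((5, d)):
    h = gru_step(x, h, Wz, Wr, Wh)
print(h.round(3))
```

Note there is no separate cell state: the single vector `h` plays both roles.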


3.5 Architectural patterns for sequences

| Pattern | Shape | Example task |
|---|---|---|
| one-to-one | 1 in → 1 out | image classification (not really RNN) |
| one-to-many | 1 in → seq out | image captioning |
| many-to-one | seq in → 1 out | sentiment classification |
| many-to-many (aligned) | seq in → seq out, same length | POS tagging |
| many-to-many (seq2seq) | seq in → seq out, different length | translation |

  • Bidirectional RNN: run one RNN forward and one backward, concatenate hidden states → each position sees both past and future context. Great for tagging; can't be used for autoregressive generation.
  • Stacked/deep RNN: feed one RNN's outputs as inputs to another.
  • Encoder–decoder (seq2seq): an encoder RNN compresses the input into a context vector; a decoder RNN generates the output from it. This is the direct ancestor of the Transformer encoder-decoder in [[05_architectures]]. Its bottleneck (one fixed context vector) motivated attention, which [[04_transformers]] takes to its full conclusion.
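The bidirectional pattern from the first bullet can be sketched with two independent vanilla RNNs in NumPy. The helper names, sizes, and random weights are assumptions for illustration only.

```python
import numpy as np

def rnn_states(xs, Wxh, Whh, bh):
    """Run a vanilla RNN over xs of shape (T, d); return all hidden states, shape (T, H)."""
    h = np.zeros(Whh.shape[0])
    out = []
    for x in xs:
        h = np.tanh(Wxh @ x + Whh @ h + bh)
        out.append(h)
    return np.stack(out)

def bidirectional_states(xs, fwd, bwd):
    h_f = rnn_states(xs, *fwd)                  # left-to-right pass
    h_b = rnn_states(xs[::-1], *bwd)[::-1]      # right-to-left pass, re-aligned to positions
    return np.concatenate([h_f, h_b], axis=1)   # (T, 2H): past and future context per position

T, d, H = 6, 3, 4
rng = np.random.default_rng(0)

def make_params():
    return (0.5 * rng.standard_normal((H, d)),
            0.5 * rng.standard_normal((H, H)),
            np.zeros(H))

fwd, bwd = make_params(), make_params()         # separate weights per direction
xs = rng.standard_normal((T, d))
states = bidirectional_states(xs, fwd, bwd)
print(states.shape)   # (6, 8)
```

Because the backward half at position $t$ already depends on inputs after $t$, this layout suits tagging but not left-to-right generation.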

3.6 Code: LSTM for sentiment in PyTorch

python
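import torch
import torch.nn as nn

class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, emb=128, hidden=256, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb, padding_idx=0)
        self.lstm  = nn.LSTM(emb, hidden, num_layers,
                             batch_first=True, bidirectional=True, dropout=0.3)
        self.fc    = nn.Linear(hidden*2, 1)   # *2 for bidirectional

    def forward(self, x):                      # x: (B, T) token ids
        e = self.embed(x)                      # (B, T, emb)
        out, (h_n, c_n) = self.lstm(e)         # out: (B, T, 2*hidden)
        # h_n: (num_layers*2, B, hidden); the last two entries are the top layer's
        # final forward (h_n[-2]) and backward (h_n[-1]) hidden states.
        # Note: with right-padded batches the forward state has also read the pad
        # tokens; torch.nn.utils.rnn.pack_padded_sequence avoids that.
        final = torch.cat([h_n[-2], h_n[-1]], dim=1)   # (B, 2*hidden)
        return self.fc(final).squeeze(1)       # (B,) logits → BCEWithLogitsLoss

model = SentimentLSTM(vocab_size=20000)
loss_fn = nn.BCEWithLogitsLoss()               # sigmoid+BCE fused (see ch.1)
opt = torch.optim.Adam(model.parameters(), 1e-3)

# One training step on a dummy batch (random ids and labels, for illustration only)
x = torch.randint(1, 20000, (8, 40))           # (B=8, T=40) token ids, no padding
y = torch.randint(0, 2, (8,)).float()          # binary sentiment labels
opt.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
# gradient clipping (essential for RNNs), applied between backward() and step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
opt.step()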

A character-level RNN forward, from scratch (NumPy; random vectors stand in for one-hot characters)

python
import numpy as np
def rnn_step(x, h, Wxh, Whh, bh):
    return np.tanh(Wxh @ x + Whh @ h + bh)     # one time step

H, D = 4, 3
Wxh = np.random.randn(H, D)*0.1
Whh = np.random.randn(H, H)*0.1
bh  = np.zeros(H)
h = np.zeros(H)
sequence = [np.random.randn(D) for _ in range(5)]
for x in sequence:
    h = rnn_step(x, h, Wxh, Whh, bh)           # state carries forward
print("final hidden state:", h.round(3))

3.7 Limitations that motivated Transformers

  1. Sequential computation: step $t$ needs step $t-1$, so you can't parallelize across time → slow training on long sequences.
  2. Long-range dependencies: even LSTMs struggle past a few hundred tokens; the path between distant tokens is $O(\text{distance})$.
  3. Fixed-size context vector in seq2seq bottlenecks information.

The fix: attention lets every token directly look at every other token in $O(1)$ path length and fully in parallel. That is the subject of [[04_transformers]].

Next: [[04_transformers]] — attention is all you need.