
Backpropagation Through Time

Recurrent networks reuse the same parameters at every time step. This makes them compact, but it also changes how gradients are computed. A parameter such as W_{hh} affects not one layer, but every recurrent transition in the sequence.

Backpropagation through time, usually abbreviated BPTT, is the method used to train recurrent neural networks. It applies ordinary backpropagation to the unrolled recurrent graph.

The Unrolled Graph

A recurrent network is defined recursively:

h_t = f(h_{t-1}, x_t; \theta).

For a sequence of length T, this recurrence expands into

h_1 = f(h_0, x_1; \theta), \quad h_2 = f(h_1, x_2; \theta), \quad \ldots, \quad h_T = f(h_{T-1}, x_T; \theta).

The same parameter set \theta appears at every step. After unrolling, the recurrent model becomes a deep feedforward graph with shared weights.
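To make the unrolling concrete, here is a minimal sketch of the recurrence written as an explicit loop (the sizes and the tanh update are illustrative choices, not taken from any particular model):

```python
import torch

# Minimal unrolled recurrence: h_t = tanh(W_xh x_t + W_hh h_{t-1}).
# The same W_xh and W_hh are reused at every time step.
torch.manual_seed(0)
T, D, H = 4, 3, 5                            # sequence length, input size, hidden size
W_xh = torch.randn(H, D, requires_grad=True)
W_hh = torch.randn(H, H, requires_grad=True)

x = torch.randn(T, D)
h = torch.zeros(H)
states = []
for t in range(T):
    h = torch.tanh(W_xh @ x[t] + W_hh @ h)   # same parameters at each step
    states.append(h)

print(len(states))                           # T hidden states, one per unrolled step
```

Each iteration of the loop corresponds to one layer of the unrolled feedforward graph, all sharing `W_xh` and `W_hh`.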

If the model produces a loss at every step, the total loss may be written as

L = \sum_{t=1}^{T} L_t.

The training problem is to compute

\nabla_\theta L.

Gradients with Shared Parameters

Because the same parameters are used repeatedly, the total gradient is the sum of contributions from all time steps:

\frac{\partial L}{\partial \theta} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial \theta}.

However, this expression hides an important detail. The loss at time t depends on earlier hidden states through the recurrence. Therefore, L_t depends on \theta both directly and indirectly.

For example, L_5 depends on:

h_5, h_4, h_3, h_2, h_1.

Since each hidden state was computed using \theta, the gradient must flow backward through the entire chain.
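The shared-parameter sum can be checked numerically on a toy recurrence (sizes and the tanh update are illustrative assumptions): by linearity of the gradient, differentiating the total loss gives the same result as summing the per-step gradients.

```python
import torch

torch.manual_seed(0)
T, H = 3, 4
W_hh = torch.randn(H, H, requires_grad=True)   # shared across all steps
x = torch.randn(T, H)

h = torch.zeros(H)
losses = []
for t in range(T):
    h = torch.tanh(x[t] + W_hh @ h)
    losses.append(h.sum())                     # a per-step loss L_t

total = sum(losses)                            # L = sum_t L_t
g_total, = torch.autograd.grad(total, W_hh, retain_graph=True)
g_steps = sum(torch.autograd.grad(l, W_hh, retain_graph=True)[0]
              for l in losses)

print(torch.allclose(g_total, g_steps, atol=1e-6))  # True
```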

Gradient Flow Through Hidden States

Suppose the loss depends only on the final hidden state:

L = \ell(h_T).

Then the gradient with respect to an earlier hidden state h_t is computed by repeated application of the chain rule:

\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial h_T} \frac{\partial h_T}{\partial h_{T-1}} \frac{\partial h_{T-1}}{\partial h_{T-2}} \cdots \frac{\partial h_{t+1}}{\partial h_t}.

This product of Jacobian matrices determines how strongly information from step t influences the final loss.

If the product grows, gradients explode. If the product shrinks, gradients vanish. This is the central training difficulty of basic recurrent networks.
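A scalar caricature makes the point (a deliberate simplification, not a real RNN): for a linear recurrence h_t = w h_{t-1}, the gradient from step T back to step t carries a factor of w^(T-t).

```python
# Scalar caricature of gradient flow: the backward path multiplies
# the same factor once per step, so it decays or blows up geometrically.
T = 50
print(0.9 ** T)   # ~0.005 -> vanishing
print(1.1 ** T)   # ~117   -> exploding
```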

A Simple RNN Example

Consider the simple recurrent update:

h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h).

Let

a_t = W_{xh} x_t + W_{hh} h_{t-1} + b_h.

Then

h_t = \tanh(a_t).

The derivative of h_t with respect to h_{t-1} is

\frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(1 - h_t^2) W_{hh}.

The gradient over many steps contains products of terms like this:

\prod_{k=t+1}^{T} \operatorname{diag}(1 - h_k^2) W_{hh}.

The repeated multiplication explains why recurrent networks can be unstable over long sequences.
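This can be observed numerically. The sketch below accumulates the Jacobian product for a tanh recurrence with a deliberately small recurrent matrix (0.5 times the identity, an illustrative choice), so the product norm collapses toward zero:

```python
import torch

torch.manual_seed(0)
T, H = 30, 8
W_hh = 0.5 * torch.eye(H)                 # illustrative small recurrent matrix
x = torch.randn(T, H)

h = torch.zeros(H)
J = torch.eye(H)                          # running product of step Jacobians
norms = []
for t in range(T):
    h = torch.tanh(x[t] + W_hh @ h)
    J = torch.diag(1 - h**2) @ W_hh @ J   # d h_t / d h_{t-1}, times the product so far
    norms.append(J.norm().item())

print(norms[0], norms[-1])                # the norm shrinks geometrically
```

With a large recurrent matrix the product can instead grow, though tanh saturation (the diag(1 - h^2) factor) also damps it, which is why the behavior in practice depends on both the weights and the activations.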

Loss at Every Time Step

Many sequence tasks produce an output at every step. Examples include language modeling and sequence labeling.

The model may compute

y_t = g(h_t)

and a loss

L_t = \ell(y_t, \hat{y}_t).

The total loss is usually the mean or sum over time:

L = \frac{1}{T} \sum_{t=1}^{T} L_t.

In PyTorch, this often looks like:

logits, h_n = rnn_model(x)

# logits: [B, T, V]
# targets: [B, T]

loss = criterion(
    logits.reshape(B * T, V),
    targets.reshape(B * T),
)

The reshape converts the sequence problem into a batch of B \times T classification problems.

BPTT in PyTorch

PyTorch builds the computation graph dynamically as the forward pass executes. When a recurrent layer processes a sequence, PyTorch records the operations needed for backpropagation.

A minimal training step:

optimizer.zero_grad()

output, h_n = rnn(x)
logits = classifier(output)

loss = criterion(
    logits.reshape(B * T, num_classes),
    targets.reshape(B * T),
)

loss.backward()
optimizer.step()

The call to loss.backward() performs backpropagation through the unrolled recurrent computation.

The user does not manually write BPTT. PyTorch’s autograd system computes it from the graph.
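A tiny self-contained check (sizes are arbitrary) shows the shared recurrent weight receiving a gradient after a single backward call through the whole sequence:

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=3, hidden_size=5, batch_first=True)
x = torch.randn(2, 7, 3)             # [batch, time, features]

output, h_n = rnn(x)                 # output: [2, 7, 5]
loss = output.sum()
loss.backward()                      # BPTT through all 7 time steps

# weight_hh_l0 is the shared recurrent matrix; its gradient
# accumulates contributions from every time step.
print(rnn.weight_hh_l0.grad.shape)   # torch.Size([5, 5])
```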

Truncated Backpropagation Through Time

For long sequences, full BPTT can be expensive. It requires storing hidden states and intermediate activations for every time step.

If T is large, memory usage grows with sequence length.

Truncated backpropagation through time limits how far gradients flow backward. Instead of backpropagating through an entire sequence, we split the sequence into shorter chunks.

For example, a sequence of length 1000 may be processed in chunks of 100:

steps 1-100
steps 101-200
steps 201-300
...

The hidden state is carried forward between chunks, but the gradient graph is detached.

In PyTorch:

h = None

for x_chunk, y_chunk in chunks:
    optimizer.zero_grad()

    output, h = rnn(x_chunk, h)
    logits = classifier(output)

    loss = criterion(
        logits.reshape(-1, num_classes),
        y_chunk.reshape(-1),
    )

    loss.backward()
    optimizer.step()

    h = h.detach()

The line

h = h.detach()

breaks the computation graph. The numerical hidden state is reused, but gradients do not flow into earlier chunks.

This reduces memory cost and makes long sequence training practical.
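One convenient way to form the chunks, assuming a batch-first tensor of shape [B, T, ...], is tensor.split along the time dimension (the chunk size of 4 here is arbitrary):

```python
import torch

B, T, D, k = 2, 10, 3, 4
x = torch.randn(B, T, D)
y = torch.randint(0, 5, (B, T))

# split returns views of length k along dim=1; the last chunk may be shorter.
chunks = list(zip(x.split(k, dim=1), y.split(k, dim=1)))
print([xc.shape[1] for xc, yc in chunks])   # [4, 4, 2]
```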

Detaching Hidden States

Detaching a tensor keeps its value but removes its connection to earlier computation.

This matters in recurrent training because hidden states connect chunks over time.

Without detaching, PyTorch would attempt to backpropagate through every previous chunk. That may cause high memory usage or an error when the graph has already been freed.

Example:

h = h.detach()

For LSTMs, the hidden state has two tensors, so both must be detached:

h = tuple(state.detach() for state in h)

Detaching is a mathematical choice. It says: use the past state as context, but do not update parameters based on dependencies before this boundary.
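A minimal scalar example of what detach changes (values chosen purely for illustration):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = w * 3.0            # h = 6.0, depends on w
h_det = h.detach()     # same value 6.0, but no history

out = h_det * w        # autograd treats h_det as a constant
out.backward()

# Only the direct use of w is tracked: d(out)/dw = h_det = 6.0.
# Without the detach, out = 3*w**2 and the gradient would be 12.0.
print(w.grad.item())   # 6.0
```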

Exploding Gradients

Because BPTT multiplies many Jacobians, gradients can become very large. This is called exploding gradients.

Symptoms include:

Symptom                     Description
Loss becomes nan            Numerical overflow occurred
Training becomes unstable   Loss jumps suddenly
Parameters grow rapidly     Weight norms become very large
Gradients spike             Gradient norm increases sharply

A standard remedy is gradient clipping.

torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
)

Gradient clipping rescales gradients when their norm exceeds a threshold. It does not solve every optimization problem, but it is widely used in recurrent training.
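In a training loop, clipping goes between backward() and step(), once gradients exist. A self-contained sketch with an arbitrary linear model standing in for the RNN:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 2)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# Rescales gradients in place so their total norm is at most 1.0;
# returns the norm measured before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

print(float(norm_before))
```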

Vanishing Gradients

Vanishing gradients occur when the product of Jacobians becomes very small. Then early time steps receive little learning signal.

The model may learn short-range patterns but fail to capture long-range dependencies.

For example, in a sentence:

The keys to the cabinet near the door are missing.

The verb “are” depends on “keys,” which appears many words earlier. A basic RNN may struggle to preserve this dependency.

Gated recurrent architectures were designed to reduce this problem. LSTMs and GRUs introduce additive memory paths and gates that regulate information flow.

Computational Cost

For a sequence of length T, hidden size H, and batch size B, a simple RNN requires approximately:

O(B T H^2)

computation when the recurrent matrix multiplication dominates.

Memory for full BPTT grows roughly with

O(B T H)

for stored hidden states, plus additional memory for intermediate activations.

This linear dependence on T is one reason truncated BPTT is important for long sequences.
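Back-of-the-envelope arithmetic for the hidden-state storage alone (float32, with illustrative sizes) shows the effect of truncation:

```python
# Memory for stored hidden states, float32 (4 bytes per value).
B, T, H = 32, 1000, 512
full_mib  = B * T * H * 4 / 2**20     # full BPTT keeps every h_t
trunc_mib = B * 100 * H * 4 / 2**20   # truncated BPTT, chunks of 100

print(full_mib, trunc_mib)            # 62.5 MiB vs 6.25 MiB
```

Intermediate activations inside each cell add a constant factor on top of this, but the linear scaling in the backpropagation window is the same.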

BPTT Versus Transformer Training

Recurrent networks process time steps sequentially. The hidden state at step t depends on the hidden state at step t-1. This limits parallelism.

Transformers replace recurrence with attention. During training, all positions can be processed in parallel, although causal masks may restrict which positions can attend to which other positions.

This difference is one reason transformers became dominant in large-scale language modeling.

However, recurrence still has useful properties:

Property                            Recurrent networks     Transformers
Training parallelism over time      Limited                High
Memory per token during inference   Compact hidden state   Growing key-value cache
Natural streaming                   Strong                 Requires careful caching
Long sequence training              Hard with full BPTT    Expensive due to attention

RNNs remain useful for streaming signals, embedded systems, online prediction, and cases where compact state matters.

Summary

Backpropagation through time trains recurrent neural networks by applying the chain rule to the unrolled recurrent graph.

Because recurrent parameters are shared across time, their gradients collect contributions from many sequence positions. Gradients must pass through repeated hidden-state transitions, which creates the possibility of exploding and vanishing gradients.

Full BPTT is exact for the unrolled sequence but can be expensive for long inputs. Truncated BPTT reduces memory and computation by limiting the length of the gradient path. In PyTorch, this is usually implemented by detaching hidden states between chunks.