GPU Workshop — Activation Checkpointing Deep Dive

Activation Recomputation

The fundamental memory-compute trade-off in transformer training. Save massive GPU memory by recomputing activations instead of storing them.

34·SBH: activation elements per layer
5·n_heads·S²·B: quadratic attention term
~33%: max compute overhead
90%+: activation memory saved

What Are Activations?

Every intermediate tensor computed during the forward pass is an "activation". They're the hidden state of your network at each layer.

Input Embeddings (shape [S, B, H])
Repeated L times:
  LayerNorm 1: stores 4 · S·B·H elements
  Multi-Head Attention: stores 11 · S·B·H + 5 · n_heads · S² · B elements, + residual
  LayerNorm 2: stores (included above)
  Feed-Forward Network (FFN/MLP): stores 19 · S·B·H elements, + residual
Output / Loss: logits + cross-entropy
🔵

Attention Activations

The largest component at long sequences. Contains 11 * S*B*H linear terms plus 5 * n_heads * S² * B — the infamous quadratic scaling from attention scores and softmax outputs.

Per Layer: 11 · S·B·H + 5 · n_heads · S² · B
🟢

FFN Activations

The feed-forward network stores intermediate results from its two large linear projections and the activation function (GeLU/SwiGLU). Scales linearly with sequence length.

Per Layer: 19 · S · B · H
🟠

LayerNorm Activations

Small but always stored. LayerNorm needs to save input statistics (mean, variance) and the normalized output for the backward pass. Rarely worth recomputing.

Per Layer: 4 · S · B · H
Total activations per layer: (11 + 19 + 4) * S*B*H + 5 * n_heads * S² * B = 34*S*B*H + 5*n_heads*S²*B. For bf16 (2 bytes per element), multiply by 2 to get bytes. The S² term is what makes long-context training so memory-hungry.
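As a sanity check, the per-layer formula can be wrapped in a small helper (a sketch; the function and parameter names are mine, bf16 assumed):

```python
def activation_bytes_per_layer(S, B, H, n_heads, bytes_per_elem=2):
    """Per-layer activation memory without checkpointing:
    34*S*B*H linear terms + 5*n_heads*S^2*B quadratic attention term,
    times 2 bytes per element for bf16."""
    elements = 34 * S * B * H + 5 * n_heads * S**2 * B
    return elements * bytes_per_elem

# Llama-3.1-8B-like layer: H=4096, 32 heads, batch 1, sequence 4096
gib = activation_bytes_per_layer(S=4096, B=1, H=4096, n_heads=32) / 2**30
print(f"{gib:.2f} GiB per layer")  # → 6.06 GiB per layer
```

Note how the quadratic term already contributes ~2.7 of those ~6 GiB at S = H = 4096, and it only gets worse as S grows.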

Training Memory Breakdown

GPU memory during training is split between model state (fixed) and activations (scales with sequence length).

Total GPU Memory: M_total = M_params + M_gradients + M_optimizer + M_activations

Parameters

4N bytes (fp32)

Fixed cost. E.g., 8B model = 32 GB

Gradients

4N bytes (fp32)

Same size as parameters

Optimizer

8N bytes (Adam m,v)

Adam stores 2 states per param

Activations

Varies with S, B

This is what we can recompute!
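The breakdown above can be sketched in code (a minimal sketch; names are mine, fp32 model state with Adam assumed):

```python
def training_memory_gib(n_params, activation_bytes):
    """Fixed model state (fp32 params + grads + Adam m,v) plus activations."""
    m_params    = 4 * n_params   # fp32 weights
    m_grads     = 4 * n_params   # fp32 gradients, same size as params
    m_optimizer = 8 * n_params   # Adam: first and second moments, fp32
    total = m_params + m_grads + m_optimizer + activation_bytes
    return total / 2**30

# 8B parameters: 16N bytes of model state = 128 GB (≈119.2 GiB)
# before storing a single activation
print(f"{training_memory_gib(8e9, 0):.1f} GiB model state")
```

The 16N-byte model state is fixed; only the `activation_bytes` term is negotiable, which is why recomputation targets it.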

Llama 3.1 8B — Sequence length: 4096
[Interactive chart: total training memory (GiB), stacked as Params / Gradients / Optimizer / Attention / FFN / LayerNorm, sequence length selectable from 1024 to 16384]

Activation Recomputation

Don't store activations — recompute them during the backward pass. Trade compute time for memory.

💾
Memory Saved

Don't store intermediates

⇌
⚡
Compute Overhead

Re-run forward during backward

Baseline (No Recomputation)

All activations stored in GPU memory during forward pass, used during backward pass.

Forward Pass
Layer 1
Compute
Store ✓
Layer 2
Compute
Store ✓
Backward Pass
Layer 2
Use stored ✓
Layer 1
Use stored ✓
Problem: Memory grows linearly with number of layers × sequence length.

With Recomputation (Checkpointing)

Discard activations during forward, recompute them on-the-fly during backward.

Forward Pass
Layer 1
Compute
Discard ✗
Checkpoint
Layer 2
Compute
Discard ✗
Checkpoint
Backward Pass
Layer 2
Recompute
Use ✓
Layer 1
Recompute
Use ✓
Result: Only store small "checkpoint tensors" at layer boundaries. Massive memory savings!

What Are Checkpoint Tensors?

When you discard activations, you can't throw away everything — you still need some "save points" so the backward pass knows where to start recomputing from. These save points are called checkpoint tensors.

Think of it like a video game
💾
Checkpoint = Save Point

You save your game at the start of each level. If you need to replay, you load from the save point.

→
🔄
Recompute = Replay Level

In backward pass, load the save point and re-play the forward computation for that layer to regenerate all intermediates.

Concretely, for a transformer layer:
💾 Checkpoint
input to layer
Shape: [S, B, H]
→
LN output ✗
Q, K, V ✗
Attn scores ✗
FFN hidden ✗
SAVED (tiny) vs. DISCARDED (huge): will recompute from checkpoint during backward
Checkpoint memory per layer: ~1 · S · B · H · 2 bytes (just one hidden-state tensor in bf16)
Why ~1*S*B*H? The checkpoint is typically the input hidden state to each transformer layer — a single tensor of shape [S, B, H]. That's S × B × H elements, stored in bf16 (2 bytes each). Compare this to the 34*S*B*H + 5*n*S²*B elements you'd store without checkpointing — the checkpoint is ~34× smaller (ignoring the S² term which makes it even more dramatic).

So checkpointing is not "free" in memory — you do store one small tensor per layer. But it's a tiny fraction of what you'd store without checkpointing. That's why the code adds CHECKPOINT_SBH_PER_LAYER = 1 to the memory model: even when discarding everything, there's still this small residual cost.
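The save-point/replay idea can be sketched with a toy two-layer scalar "network" (purely illustrative, not a real autograd engine; every name here is mine):

```python
# Toy "layers": each is (forward fn, its derivative).
layers = [(lambda x: 3.0 * x, lambda x: 3.0),      # layer 1: y = 3x
          (lambda x: x * x,   lambda x: 2.0 * x)]  # layer 2: y = x^2

def forward(x):
    ckpts = []                 # save only each layer's *input* (~1·SBH each)
    for f, _ in layers:
        ckpts.append(x)
        x = f(x)               # intermediates inside f are discarded
    return x, ckpts

def backward(grad, ckpts):
    for (f, dfdx), x_in in zip(reversed(layers), reversed(ckpts)):
        f(x_in)                # "replay the level": re-run forward from checkpoint
        grad = grad * dfdx(x_in)   # then apply this layer's chain-rule factor
    return grad

y, ckpts = forward(2.0)        # (3·2)^2 = 36
g = backward(1.0, ckpts)       # d/dx (3x)^2 = 18x = 36 at x = 2
```

In PyTorch this bookkeeping is done for you by `torch.utils.checkpoint.checkpoint`, which saves the inputs and replays the wrapped function during backward.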

Three Checkpointing Strategies

🔵 Strategy A

Recompute Attention

Discard all attention activations (both 11*SBH and 5*n*S²*B terms). Keeps FFN activations stored.

Eliminates the S² term! Memory growth becomes linear.
🟢 Strategy B

Recompute FFN

Discard FFN activations (19*SBH). Keeps attention activations stored — the S² term remains.

S² term still present. Less memory savings at long sequences.
🟣 Strategy C

Recompute Both

Discard both attention and FFN activations. Only LayerNorm + checkpoints remain in memory.

Maximum savings. Only ~5*SBH per layer stored.
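The three strategies' stored memory can be compared with a small calculator (a sketch; the function name and `strategy` labels are mine):

```python
def stored_elements_per_layer(S, B, H, n_heads, strategy="baseline"):
    """Activation elements kept per layer under each checkpointing strategy.
    A 1*S*B*H checkpoint tensor is kept whenever anything is recomputed."""
    attn = 11 * S * B * H + 5 * n_heads * S**2 * B
    ffn  = 19 * S * B * H
    ln   = 4 * S * B * H
    ckpt = S * B * H if strategy != "baseline" else 0
    if strategy in ("attention", "both"):
        attn = 0               # Strategy A / C: discard attention activations
    if strategy in ("ffn", "both"):
        ffn = 0                # Strategy B / C: discard FFN activations
    return attn + ffn + ln + ckpt

S, B, H, n = 4096, 1, 4096, 32
for s in ("baseline", "attention", "ffn", "both"):
    print(s, stored_elements_per_layer(S, B, H, n, s) / (S * B * H), "× SBH")
```

At S = H = 4096 the quadratic term alone is worth 160·SBH, so Strategy A already cuts per-layer storage from 194·SBH to 24·SBH; Strategy C reaches the ~5·SBH floor quoted above.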

Memory vs. Sequence Length

See how different checkpointing strategies affect memory across Llama 3.1 model sizes. Click the buttons to switch strategies.

Memory Breakdown — Llama 3.1 8B — Baseline

All activations stored. Memory grows quadratically with sequence length due to attention S² term.
[Interactive chart: memory breakdown vs. sequence length, stacked as Params / Gradients / Optimizer / Checkpoints / LayerNorm / FFN / Attention]

How FLOPs Are Computed

Understanding the compute cost of each transformer component is key to understanding recomputation overhead.

Training Step FLOPs: FLOPs_step = FLOPs_forward + FLOPs_backward ≈ F + 2F = 3F
Why backward ≈ 2× forward? For each matmul Y = XW in the forward pass, the backward pass computes two matmuls of similar size: dW = Xᵀ·dY (gradient w.r.t. weights) and dX = dY·Wᵀ (gradient w.r.t. inputs).
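The 2× factor can be checked by counting matmul FLOPs (toy sizes; names are mine):

```python
def matmul_flops(m, k, n):
    # Y[m,n] = X[m,k] @ W[k,n]: m*n dot products of length k → 2*m*k*n FLOPs
    return 2 * m * k * n

m, k, n = 512, 1024, 2048
fwd = matmul_flops(m, k, n)                           # forward: Y = X W
bwd = matmul_flops(k, m, n) + matmul_flops(m, n, k)   # dW = Xᵀ dY, dX = dY Wᵀ
print(bwd / fwd)  # → 2.0
```

Both backward matmuls touch the same three dimensions (m, k, n), so each costs exactly one forward's worth of FLOPs, giving 3F per training step.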

Attention FLOPs (per layer)

Projections (Q, K, V, Output) 8 · B · S · H²
Attention Matmuls (QKᵀ, Score·V) 4 · B · S² · H

The S² term in attention compute means attention FLOPs grow quadratically with sequence length, but for typical model sizes, the 8BSH² projection term often still dominates.

FFN FLOPs (per layer)

Two Linear Projections (expansion=4) 16 · B · S · H²

The FFN is purely linear in S — no quadratic terms. But with expansion factor 4, its FLOPs are 2× the attention projections. This makes FFN the compute-dominant block in most configurations.

Key insight: FFN has more FLOPs than attention, but attention has the S² memory term.

Compute Overhead Formula

With Recomputation: Overhead = F_recomputed / (3 · F_total)

Where Do These Percentages Come From?

Let's derive them from the FLOPs formulas. Per layer, the forward FLOPs are:

Attention forward: F_attn = 8·B·S·H² (projections) + 4·B·S²·H (score matmuls)
FFN forward (expansion=4): F_ffn = 16·B·S·H²

Case 1: short sequences (S ≪ H), where the S² term is tiny:

F_total ≈ 8BSH² + 16BSH² = 24·B·S·H²
Attn overhead
8 / (3×24) = 11.1%
FFN overhead
16 / (3×24) = 22.2%
Both overhead
24 / (3×24) = 33.3%

Case 2: When S = H (e.g., S=4096, H=4096):

Now the 4·B·S²·H term is no longer negligible. Since S=H, this equals 4BSH², so:

F_attn = 8BSH² + 4BSH² = 12·B·S·H²
F_total = 12BSH² + 16BSH² = 28·B·S·H²
Attn overhead
12 / (3×28) = 14.3%
FFN overhead
16 / (3×28) = 19.0%
Both overhead
28 / (3×28) = 33.3%
Key takeaway: As sequence length grows, attention's compute share increases (because of the S² FLOPs term), so attention recomputation overhead rises from ~11% toward ~14%+. FFN overhead decreases slightly. "Recompute Both" is always exactly 33.3% because you're replaying the entire forward pass once: F / (3F) = 1/3.
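The derivation above can be reproduced numerically (a sketch; names are mine, per-layer FLOPs only since batch size cancels in the ratios):

```python
def overheads(S, H):
    """Recomputation overhead per strategy, as a fraction of the 3F step."""
    f_attn  = 8 * S * H**2 + 4 * S**2 * H   # projections + score matmuls
    f_ffn   = 16 * S * H**2                 # two expansion-4 projections
    f_total = f_attn + f_ffn
    return {"attention": f_attn  / (3 * f_total),
            "ffn":       f_ffn   / (3 * f_total),
            "both":      f_total / (3 * f_total)}   # always exactly 1/3

print({k: f"{v:.1%}" for k, v in overheads(S=512,  H=4096).items()})
print({k: f"{v:.1%}" for k, v in overheads(S=4096, H=4096).items()})
```

Running it at S ≪ H recovers the ~11% / ~22% / 33.3% split, and at S = H the ~14.3% / ~19.0% / 33.3% split.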

Recompute Attention

~11–14%

F_attn / (3 * F_total)

Lower overhead because attention FLOPs are a minority. Rises slightly at long sequences as the S² compute term grows.

Recompute FFN

~19–22%

F_ffn / (3 * F_total)

Larger overhead because FFN's two big matmuls (16BSH²) dominate the forward compute budget.

Recompute Both

33.3%

F / (3 * F) = 1/3

Always exactly one-third — you replay the full forward once out of the 3F total step compute.

Worked Example

Suppose per-layer forward FLOPs are: Attention = 100, FFN = 300

Strategy            | Step FLOPs        | Overhead
Baseline            | 3 × 400 = 1200    | —
Recompute Attention | 1200 + 100 = 1300 | +8.3%
Recompute FFN       | 1200 + 300 = 1500 | +25%
Recompute Both      | 1200 + 400 = 1600 | +33.3%

Compute Overhead vs. Sequence Length

How recomputation overhead varies with sequence length. Note: attention overhead increases at longer sequences due to S² FLOPs.
[Interactive chart: compute overhead vs. sequence length for Recompute Attention, Recompute FFN, Recompute Both, and No Recompute]

Why Checkpoint Attention?

Attention checkpointing gives you the best bang for your buck: maximum memory savings with minimal compute overhead.

Attention: Memory Hog, Compute Lightweight

Memory share:
Dominates at long S
Compute share:
~33-43% of forward

Attention activations scale as O(S²) in memory but its compute is only a fraction of total FLOPs. Discarding attention saves the most memory per FLOP of recomputation.

FFN: Compute Heavy, Memory Linear

Memory share:
Linear in S
Compute share:
~57-67% of forward

FFN dominates compute (2× more FLOPs than attention projections) but its memory only scales O(S). Recomputing FFN is expensive for modest memory gains.

Memory Saved per % Compute Overhead

Strategy | Memory Saved (8B, S=16K) | Compute Overhead | Efficiency Ratio
[Interactive table: values computed per strategy]
Verdict: Recomputing attention saves ~10× more memory per unit of extra compute compared to recomputing FFN. This is why most training frameworks default to attention checkpointing.

The Advantage Grows with Sequence Length

As sequence length increases, attention's memory share grows quadratically while its compute share grows less steeply. This means the efficiency of attention checkpointing improves at longer sequences — exactly where you need it most.

Flash Attention

Flash Attention is a fused, tiled attention kernel that's both faster and more memory-efficient — and it has deep connections to activation checkpointing.

First: What is SRAM? The GPU Memory Hierarchy

A GPU doesn't have one uniform pool of memory. It has a hierarchy, just like a CPU has L1/L2 cache and main RAM. Understanding this hierarchy is the key to understanding why Flash Attention works.

GPU Chip (e.g., NVIDIA A100)
SM 0
SRAM
192 KB
SM 1
SRAM
192 KB
SM 2
SRAM
192 KB
... ×108
Total SRAM
~20 MB
on-chip bus (very fast, ~19 TB/s)
CUDA Cores
Tensor Cores
Off-chip memory bus (~2 TB/s — 10× slower!)
HBM (High Bandwidth Memory)
80 GB
A100
Model weights
Activations
Optimizer states
Property | SRAM (On-Chip) | HBM (Off-Chip)
What it is | Static RAM cache inside each Streaming Multiprocessor (SM); also called "shared memory" or "L1 cache" | High Bandwidth Memory, the main GPU memory: stacks of DRAM chips on the GPU package
Size | ~192 KB per SM, ~20 MB total | 40-80 GB
Bandwidth | ~19 TB/s | ~2 TB/s (A100)
Latency | ~28 cycles | ~200-400 cycles
Programmer controls? | Yes: CUDA shared memory, explicitly managed | Yes: all torch.tensor allocations live here
The core problem: SRAM is ~10× faster than HBM, but ~4000× smaller. Standard attention writes the entire S × S attention matrix to HBM because it's too big for SRAM. With S=4096 and 32 heads, that's 32 × 4096 × 4096 × 2 bytes = 1 GB per layer — way too large for 20 MB of SRAM. So the GPU spends most of its time waiting for data to move between HBM and SRAM, rather than actually doing math.
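The "too big for SRAM" arithmetic takes only a few lines (A100 figures from above):

```python
S, n_heads, bytes_bf16 = 4096, 32, 2

# Full S×S attention score matrix, all heads, bf16
score_matrix = n_heads * S * S * bytes_bf16

sram_total = 20 * 2**20   # ~20 MB of on-chip SRAM on an A100
print(score_matrix / 2**30, "GiB of scores vs", sram_total / 2**20, "MiB of SRAM")
# → 1.0 GiB of scores vs 20.0 MiB of SRAM
```

A ~50× mismatch even before softmax outputs, which is why standard attention has no choice but to spill the score matrix to HBM.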

Standard Attention vs Flash Attention

Standard Attention

Each step reads/writes the full matrix to HBM:

1.
Load Q, K from HBM → compute S = QKᵀ → write S to HBM
2.
Load S from HBM → compute P = softmax(S) → write P to HBM
3.
Load P, V from HBM → compute O = PV → write O to HBM
3 round-trips to HBM! S×S matrix written and read repeatedly. Memory-bandwidth bound.
HBM reads + writes:
O(S²)

Flash Attention

Tiles Q, K, V into blocks that fit in SRAM:

1.
Load one tile of Q, K, V from HBM → SRAM
2.
Compute QKᵀ, softmax, × V — all in SRAM
3.
Accumulate partial output → repeat for next tile → write only final O to HBM
S×S matrix never leaves SRAM! Only Q, K, V, O touch HBM. Compute-bound.
Q tiles
Q1
Q2
Q3
Q4
×
K tiles
K1
K2
K3
K4
→
In SRAM
QᵢKⱼᵀ → softmax → ×Vⱼ
One tile at a time
HBM reads + writes:
O(S² / M) where M = SRAM size

But wait — how can you do softmax on tiles?

Softmax needs the entire row to compute the denominator sum(exp(x_i)). If you only have one tile, you don't see the full row. Flash Attention solves this with the online softmax trick:

1

Process tile 1: compute local softmax, keep running max and running sum(exp)

2

Process tile 2: if new max is larger, rescale the previous partial sum and output using correction factor exp(old_max - new_max)

3

Repeat for all tiles. At the end, the accumulated output is mathematically identical to full softmax — just computed incrementally.

This is why Flash Attention is exact (not an approximation). It gives the same result as standard attention, just computed in a more memory-efficient order.
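The online-softmax recurrence can be sketched in plain Python (illustrative only, not the CUDA kernel; tile size and names are mine):

```python
import math

def online_softmax(row, tile=4):
    """Tiled softmax over one attention row: keep a running max m and
    running sum s, rescaling earlier partial results by exp(m_old - m_new)
    whenever a later tile reveals a larger max."""
    m, s, out = float("-inf"), 0.0, []
    for i in range(0, len(row), tile):
        block = row[i:i + tile]
        m_new = max(m, max(block))
        scale = math.exp(m - m_new)          # exp(-inf) == 0.0 on the first tile
        s = s * scale + sum(math.exp(x - m_new) for x in block)
        out = [o * scale for o in out] + [math.exp(x - m_new) for x in block]
        m = m_new
    return [o / s for o in out]              # identical to full-row softmax
```

After the last tile, every entry holds exp(x - m_final) and s holds the full denominator, so the result is bit-for-bit the same computation as a whole-row softmax, just accumulated incrementally.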

Flash Attention = Built-in Activation Checkpointing

Flash Attention doesn't store the full S×S attention matrix during the forward pass. In the backward pass, it recomputes the attention scores tile-by-tile from the stored Q, K, V matrices. This is exactly the same idea as activation checkpointing — but baked into the CUDA kernel itself.

Aspect | Standard Attention | Flash Attention
Forward memory | O(S²): full attention matrix in HBM | O(S): only Q, K, V, O + small stats
What's stored | Full S×S score matrix + softmax output | Just Q, K, V, output, and per-row (max, sum) for softmax
Backward pass | Reads stored S×S matrix from HBM | Recomputes S×S tile-by-tile in SRAM from Q, K, V
HBM reads/writes | O(S²): memory-bandwidth bound | O(S²/M) where M = SRAM size
Wall-clock speed | Slower (waiting for HBM) | 2-4× faster (compute-bound, not memory-bound)
Activation checkpointing | Needs explicit torch.utils.checkpoint | Built-in: recomputation is part of the kernel
The paradox explained: Flash Attention does more total FLOPs (it recomputes attention in backward) but is faster in wall-clock time. Why? Because the bottleneck was never compute — it was memory bandwidth. By keeping data in fast SRAM and reducing slow HBM transfers, the extra recomputation is "free" — the CUDA cores were idle waiting for data anyway.
⚡

Faster + Less Memory

Flash Attention breaks the usual trade-off. By exploiting the SRAM/HBM hierarchy, it's both faster AND uses less memory than standard attention.

🧩

Composable with Checkpointing

Flash Attention handles the attention block's memory. You can still use layer-level checkpointing for FFN activations on top. They stack well.

📏

Enables Long Context

By eliminating the O(S²) HBM memory cost, Flash Attention makes training with 32K-128K+ sequence lengths feasible. The S×S matrix stays in SRAM tiles.

Key Takeaways

1

Activations dominate memory at scale

For long sequences, activation memory can dwarf model parameters + optimizer. The quadratic S² term from attention is the main culprit.

2

Checkpoint attention first

Attention recomputation gives the best memory/compute ratio: ~11-14% compute overhead to eliminate the S² memory term. Always checkpoint attention before FFN.

3

FFN recomputation is expensive

FFN dominates compute (~57-67% of forward FLOPs depending on sequence length), so recomputing it adds ~19-22% overhead. Only use when you've exhausted other memory savings.

4

Total overhead caps at ~33%

Even recomputing everything adds at most ~33% compute overhead (one extra forward pass out of the 3F total). This is a known, bounded cost.

5

Flash Attention is checkpointing

Flash Attention applies the recomputation principle at the hardware level. It eliminates the S×S memory footprint while being faster, not slower.

6

They compose together

Use Flash Attention (eliminates attention S² memory) + selective FFN checkpointing for the ultimate memory-efficient training setup.