The fundamental memory-compute trade-off in transformer training. Save massive GPU memory by recomputing activations instead of storing them.
Every intermediate tensor computed during the forward pass is an "activation". They're the hidden state of your network at each layer.
The largest component at long sequences. Contains 11 * S*B*H linear terms plus 5 * n_heads * S² * B, the infamous quadratic scaling from attention scores and softmax outputs.
The feed-forward network stores intermediate results from its two large linear projections and the activation function (GeLU/SwiGLU). Scales linearly with sequence length.
Small but always stored. LayerNorm needs to save input statistics (mean, variance) and the normalized output for the backward pass. Rarely worth recomputing.
(11 + 19 + 4) * S*B*H + 5 * n_heads * S² * B = 34*S*B*H + 5*n_heads*S²*B. For bf16 (2 bytes per element), multiply by 2 to get bytes. The S² term is what makes long-context training so memory-hungry.
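Plugging concrete shapes into this formula makes the scale tangible. A minimal sketch, assuming Llama-3.1-8B-like shapes (H=4096, 32 heads, 32 layers) and bf16 activations:

```python
def activation_bytes_per_layer(S, B, H, n_heads, dtype_bytes=2):
    """Per-layer activation memory from the formula above:
    34*S*B*H linear terms plus the 5*n_heads*S^2*B quadratic term,
    times 2 bytes per element for bf16."""
    linear = 34 * S * B * H
    quadratic = 5 * n_heads * S * S * B
    return (linear + quadratic) * dtype_bytes

# Assumed shapes: S=4096, batch 1, H=4096, 32 heads, 32 layers
per_layer = activation_bytes_per_layer(S=4096, B=1, H=4096, n_heads=32)
print(f"{per_layer / 1e9:.1f} GB/layer")       # ~6.5 GB, mostly the S² term
print(f"{32 * per_layer / 1e9:.0f} GB total")  # ~208 GB across 32 layers
```

Even at a 4K context the activations dwarf the model's 32 GB of fp32 weights, and the S² term quadruples every time you double S.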
GPU memory during training is split between model state (fixed) and activations (scales with sequence length).
4N bytes (fp32)
Fixed cost. E.g., 8B model = 32 GB
4N bytes (fp32)
Same size as parameters
8N bytes (Adam m,v)
Adam stores 2 states per param
Varies with S, B
This is what we can recompute!
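The fixed model-state side of this split is easy to tally. A sketch of the pure-fp32 accounting used in the table (mixed-precision recipes shift these numbers around, but the total is similar):

```python
def model_state_bytes(n_params):
    """Fixed training state: params (4N) + grads (4N) + Adam m,v (8N)."""
    params = 4 * n_params     # fp32 weights
    grads = 4 * n_params      # fp32 gradients
    optimizer = 8 * n_params  # Adam first and second moments
    return params + grads + optimizer

# 8B model: 32 GB of weights alone, 128 GB of total fixed state
print(model_state_bytes(8_000_000_000) / 1e9)  # 128.0
```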
Don't store activations; recompute them during the backward pass. Trade compute time for memory.
Don't store intermediates
Re-run forward during backward
All activations stored in GPU memory during forward pass, used during backward pass.
Discard activations during forward, recompute them on-the-fly during backward.
When you discard activations, you can't throw away everything: you still need some "save points" so the backward pass knows where to start recomputing from. These save points are called checkpoint tensors.
You save your game at the start of each level. If you need to replay, you load from the save point.
In the backward pass, load the save point and replay the forward computation for that layer to regenerate all intermediates.
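The replay loop can be sketched without any framework. A toy example (entirely hypothetical, with scalar "layers" y = w·x): the forward pass keeps only each layer's input as the save point, and the backward pass reloads it before computing gradients:

```python
def forward_checkpointed(x, weights):
    """Forward through scalar layers y = w * x, keeping only inputs."""
    checkpoints = []
    for w in weights:
        checkpoints.append(x)  # save point: this layer's input
        x = w * x              # the intermediate itself is NOT stored
    return x, checkpoints

def backward(weights, checkpoints, grad_out=1.0):
    """Walk layers in reverse, replaying each one from its save point."""
    grad_w = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        x_in = checkpoints[i]             # load the save point
        grad_w[i] = grad_out * x_in       # dL/dw_i (chain rule)
        grad_out = grad_out * weights[i]  # dL/dx_i flows to earlier layer
    return grad_w, grad_out

y, ckpts = forward_checkpointed(1.0, [2.0, 3.0, 4.0])
grad_w, grad_x = backward([2.0, 3.0, 4.0], ckpts)
# y = 24.0, grad_w = [12.0, 8.0, 6.0], grad_x = 24.0: the chain rule
```

In a real network each replay regenerates a whole layer's intermediates (attention scores, FFN activations); here the replay is trivial because each scalar layer has a single input.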
[S, B, H]. That's S × B × H elements, stored in bf16 (2 bytes each). Compare this to the 34*S*B*H + 5*n*S²*B elements you'd store without checkpointing: the checkpoint is ~34× smaller even before counting the S² term, which widens the gap further.
So checkpointing is not "free" in memory: you do store one small tensor per layer. But it's a tiny fraction of what you'd store without checkpointing. That's why the code adds CHECKPOINT_SBH_PER_LAYER = 1 to the memory model: even when discarding everything, there's still this small residual cost.
Discard all attention activations (both 11*SBH and 5*n*S²*B terms). Keeps FFN activations stored.
Discard FFN activations (19*SBH). Keeps attention activations stored, so the S² term remains.
Discard both attention and FFN activations. Only LayerNorm + checkpoints remain in memory.
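The three strategies map directly onto which terms of the per-layer formula survive, plus the one [S, B, H] checkpoint tensor. A sketch using the 11/19/4 split from above (the function name and strategy labels are made up):

```python
CHECKPOINT_SBH_PER_LAYER = 1  # the saved layer input, shape [S, B, H]

def activation_elements(S, B, H, n_heads, strategy="none"):
    """Per-layer activation elements kept under each checkpointing strategy."""
    attn = 11 * S * B * H + 5 * n_heads * S * S * B  # incl. S² scores/softmax
    ffn = 19 * S * B * H
    norm = 4 * S * B * H  # LayerNorm stats + outputs, always kept
    kept = norm
    if strategy not in ("attention", "full"):
        kept += attn
    if strategy not in ("ffn", "full"):
        kept += ffn
    if strategy != "none":  # any checkpointing stores the layer input
        kept += CHECKPOINT_SBH_PER_LAYER * S * B * H
    return kept
```

With S=H=4096 and 32 heads, "attention" drops memory from 34·SBH plus the dominant S² term down to 24·SBH, while "full" leaves only 5·SBH.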
See how different checkpointing strategies affect memory across Llama 3.1 model sizes.
Understanding the compute cost of each transformer component is key to understanding recomputation overhead.
For every matmul Y = XW in the forward pass, the backward pass computes two matmuls of similar size: dW = Xᵀ dY (gradient w.r.t. weights) and dX = dY Wᵀ (gradient w.r.t. inputs).
The S² term in attention compute means attention FLOPs grow quadratically with sequence length, but for typical model sizes, the 8BSH² projection term often still dominates.
The FFN is purely linear in S, with no quadratic terms. But with expansion factor 4, its FLOPs are 2× the attention projections. This makes FFN the compute-dominant block in most configurations.
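These compute shares follow from the standard 2·m·n·k FLOPs-per-matmul rule. A sketch (hypothetical helper; embeddings and the LM head are ignored):

```python
def forward_flops_per_layer(S, B, H):
    """Per-layer forward FLOPs, matmuls only (2*m*n*k each)."""
    attn_proj = 8 * B * S * H * H    # Q, K, V, O projections: 4 × 2BSH²
    attn_scores = 4 * B * S * S * H  # QKᵀ plus scores·V: 2 × 2BS²H
    ffn = 16 * B * S * H * H         # two matmuls with 4H hidden: 2 × 2BSH·4H
    return attn_proj + attn_scores, ffn

attn, ffn = forward_flops_per_layer(S=4096, B=1, H=4096)
# at S == H: attention is 12/28 = 3/7 of forward FLOPs, FFN is 16/28 = 4/7
```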
Let's derive them from the FLOPs formulas. Per layer, the forward FLOPs are approximately F = 8BSH² (attention projections) + 4BS²H (attention scores and values) + 16BSH² (FFN).
Case 1: Short sequences (S ≪ H), where the S² term is tiny: F ≈ 24BSH², so attention is 8/24 = 1/3 of the forward FLOPs and FFN is 16/24 = 2/3.
Case 2: When S = H (e.g., S=4096, H=4096):
Now the 4·B·S²·H term is no longer negligible. Since S=H, this equals 4BSH², so F = 28BSH²: attention's share rises to 12/28 = 3/7 and FFN's falls to 16/28 = 4/7.
Recomputing the entire forward pass always costs F / (3F) = 1/3 of the step, whatever the sequence length.
F_attn / (3 * F_total)
Lower overhead because attention FLOPs are a minority. Rises slightly at long sequences as the S² compute term grows.
F_ffn / (3 * F_total)
Larger overhead because FFN's two big matmuls (16BSH²) dominate the forward compute budget.
F / (3 * F) = 1/3
Always exactly one-third: you replay the full forward once out of the 3F total step compute.
Suppose per-layer forward FLOPs are: Attention = 100, FFN = 300
| Strategy | Step FLOPs | Overhead |
|---|---|---|
| Baseline | 3 × 400 = 1200 | 0% (baseline) |
| Recompute Attention | 1200 + 100 = 1300 | +8.3% |
| Recompute FFN | 1200 + 300 = 1500 | +25% |
| Recompute Both | 1200 + 400 = 1600 | +33.3% |
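The table is pure arithmetic: the extra FLOPs are the recomputed forward pieces, divided by the baseline 3F step (forward F plus backward 2F). A quick sketch reproducing it:

```python
def recompute_overhead(f_attn, f_ffn, strategy):
    """Extra step FLOPs as a fraction of the baseline 3F (fwd + 2× bwd)."""
    f_total = f_attn + f_ffn
    extra = {"none": 0, "attention": f_attn,
             "ffn": f_ffn, "both": f_total}[strategy]
    return extra / (3 * f_total)

for s in ("attention", "ffn", "both"):
    print(s, f"{recompute_overhead(100, 300, s):.1%}")
# attention 8.3%, ffn 25.0%, both 33.3%
```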
Attention checkpointing gives you the best bang for your buck: maximum memory savings with minimal compute overhead.
Attention activations scale as O(S²) in memory, but attention's compute is only a fraction of total FLOPs. Discarding attention saves the most memory per FLOP of recomputation.
FFN dominates compute (2× more FLOPs than attention projections) but its memory only scales O(S). Recomputing FFN is expensive for modest memory gains.
| Strategy | Memory Saved (8B, S=16K) | Compute Overhead | Efficiency Ratio |
|---|---|---|---|
As sequence length increases, attention's memory share grows quadratically while its compute share grows less steeply. This means the efficiency of attention checkpointing improves at longer sequences, exactly where you need it most.
Flash Attention is a fused, tiled attention kernel that's both faster and more memory-efficient, and it has deep connections to activation checkpointing.
A GPU doesn't have one uniform pool of memory. It has a hierarchy, just like a CPU has L1/L2 cache and main RAM. Understanding this hierarchy is the key to understanding why Flash Attention works.
| Property | SRAM (On-Chip) | HBM (Off-Chip) |
|---|---|---|
| What it is | Static RAM cache inside each Streaming Multiprocessor (SM). Also called "shared memory" or "L1 cache". | High Bandwidth Memory, the main GPU memory: multiple stacks of DRAM chips on the GPU package. |
| Size | ~192 KB per SM, ~20 MB total | 40-80 GB |
| Bandwidth | ~19 TB/s | ~2 TB/s (A100) |
| Latency | ~28 cycles | ~200-400 cycles |
| Programmer controls? | Yes: CUDA shared memory, explicitly managed | Yes: all torch.tensor allocations live here |
Standard attention writes the S × S attention matrix to HBM because it's too big for SRAM. With S=4096 and 32 heads, that's 32 × 4096 × 4096 × 2 bytes = 1 GB per layer, way too large for 20 MB of SRAM. So the GPU spends most of its time waiting for data to move between HBM and SRAM, rather than actually doing math.
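That 1 GB figure is worth verifying; it comes out to exactly 2³⁰ bytes:

```python
n_heads, S, bytes_per_elem = 32, 4096, 2  # bf16 attention scores
score_bytes = n_heads * S * S * bytes_per_elem
print(score_bytes / 2**30)  # 1.0 GiB per layer vs ~20 MB of SRAM
```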
Each step reads/writes the full matrix to HBM:
Tiles Q, K, V into blocks that fit in SRAM:
Softmax needs the entire row to compute the denominator sum(exp(x_i)). If you only have one tile, you don't see the full row. Flash Attention solves this with the online softmax trick:
Process tile 1: compute local softmax, keep running max and running sum(exp)
Process tile 2: if new max is larger, rescale the previous partial sum and output using correction factor exp(old_max - new_max)
Repeat for all tiles. At the end, the accumulated output is mathematically identical to full softmax, just computed incrementally.
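The three steps above can be sketched in plain Python, no GPU required (hypothetical tile size; the real kernel also carries a weighted-sum output accumulator for V, which is rescaled the same way):

```python
import math

def online_softmax(xs, tile=2):
    """Tile-by-tile softmax via a running max and running sum,
    as in Flash Attention's online softmax."""
    m, s, out = float("-inf"), 0.0, []
    for i in range(0, len(xs), tile):
        block = xs[i:i + tile]
        new_m = max(m, max(block))
        c = math.exp(m - new_m)  # correction factor for earlier tiles
        s = s * c + sum(math.exp(x - new_m) for x in block)
        out = [o * c for o in out] + [math.exp(x - new_m) for x in block]
        m = new_m
    return [o / s for o in out]

# Identical (up to float error) to a one-shot softmax over the full row
xs = [1.0, 3.0, 2.0, 0.5]
full = [math.exp(x - max(xs)) for x in xs]
full = [f / sum(full) for f in full]
assert all(abs(a - b) < 1e-12 for a, b in zip(online_softmax(xs), full))
```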
This is why Flash Attention is exact (not an approximation). It gives the same result as standard attention, just computed in a more memory-efficient order.
Flash Attention doesn't store the full S×S attention matrix during the forward pass. In the backward pass, it recomputes the attention scores tile-by-tile from the stored Q, K, V matrices. This is exactly the same idea as activation checkpointing, but baked into the CUDA kernel itself.
| Aspect | Standard Attention | Flash Attention |
|---|---|---|
| Forward memory | O(S²): full attention matrix in HBM | O(S): only Q, K, V, O + small stats |
| What's stored | Full S×S score matrix + softmax output | Just Q, K, V, output, and per-row (max, sum) for softmax |
| Backward pass | Reads stored S×S matrix from HBM | Recomputes S×S tile-by-tile in SRAM from Q, K, V |
| HBM reads/writes | O(S²), memory-bandwidth bound | O(S²/M) where M = SRAM size |
| Wall-clock speed | Slower (waiting for HBM) | 2-4× faster (compute-bound, not memory-bound) |
| Activation checkpointing | Needs explicit torch.utils.checkpoint.checkpoint() | Built-in: recomputation is part of the kernel |
Flash Attention breaks the usual trade-off. By exploiting the SRAM/HBM hierarchy, it's both faster AND uses less memory than standard attention.
Flash Attention handles the attention block's memory. You can still use layer-level checkpointing for FFN activations on top. They stack well.
By eliminating the O(S²) HBM memory cost, Flash Attention makes training with 32K-128K+ sequence lengths feasible. The S×S matrix stays in SRAM tiles.
For long sequences, activation memory can dwarf model parameters + optimizer. The quadratic S² term from attention is the main culprit.
Attention recomputation gives the best memory/compute ratio: ~11-14% compute overhead to eliminate the S² memory term. Always checkpoint attention before FFN.
FFN dominates compute (~57-67% of forward FLOPs depending on sequence length), so recomputing it adds ~19-22% overhead. Only use when you've exhausted other memory savings.
Even recomputing everything adds at most ~33% compute overhead (one extra forward pass out of the 3F total). This is a known, bounded cost.
Flash Attention applies the recomputation principle at the hardware level. It eliminates the S×S memory footprint while being faster, not slower.
Use Flash Attention (eliminates attention S² memory) + selective FFN checkpointing for the ultimate memory-efficient training setup.