The final frontier of data-parallel memory optimization. By sharding parameters, gradients, AND optimizer states, every GPU holds only 1/N of everything. Walk through every gather-compute-flush step with actual numbers.
The same 260-parameter transformer we used in ZeRO-1 and ZeRO-2 — now with full parameter sharding.
- d = 4 — hidden dimension
- heads = 2 — attention heads
- d_k = 2 — head dimension
- FFN = 16 — inner dimension (4x)
- vocab = 8 — vocabulary size
- T = 3 — sequence length
- GPUs = 2 — world size N

| Layer | Shape | Count |
|---|---|---|
| LayerNorm 1 (γ, β) | (4)+(4) | 8 |
| W_q, W_k, W_v | 3 x (4x4) | 48 |
| W_o | (4x4) | 16 |
| LayerNorm 2 (γ, β) | (4)+(4) | 8 |
| W₁ (FFN up) | (4x16) | 64 |
| b₁ | (16) | 16 |
| W₂ (FFN down) | (16x4) | 64 |
| b₂ | (4) | 4 |
| W_vocab | (4x8) | 32 |
| Total | — | 260 |
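As a sanity check, the table's counts can be tallied in a few lines of Python (the shape names are just labels for this sketch):

```python
# Parameter shapes of the toy transformer, copied from the table above.
shapes = {
    "ln1_gamma": (4,), "ln1_beta": (4,),
    "W_q": (4, 4), "W_k": (4, 4), "W_v": (4, 4),
    "W_o": (4, 4),
    "ln2_gamma": (4,), "ln2_beta": (4,),
    "W1": (4, 16), "b1": (16,),
    "W2": (16, 4), "b2": (4,),
    "W_vocab": (4, 8),
}

def numel(shape):
    n = 1
    for d in shape:
        n *= d
    return n

total = sum(numel(s) for s in shapes.values())
print(total)  # 260
```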
ZeRO-1 sharded optimizer states.
ZeRO-2 added gradient sharding.
ZeRO-3 goes all the way: parameters themselves are sharded.
No GPU holds the full model at rest. To compute any layer, GPUs must first all-gather that layer's parameters, compute, then flush the non-owned shards.
The optimizer states: m (momentum), v (variance), and p (FP32 master copy) — each 4 bytes per parameter, 12 bytes total.
See how memory per GPU drops as we shard more aggressively from No ZeRO through ZeRO-3.
- No ZeRO: 16 × 260 = 4,160 B
- ZeRO-1: 10 × 260 = 2,600 B
- ZeRO-2: 9 × 260 = 2,340 B
- ZeRO-3: 8 × 260 = 2,080 B
| Strategy | Params (bytes) | Grads (bytes) | Optimizer (bytes) | Total (bytes) | Savings vs No ZeRO |
|---|---|---|---|---|---|
| No ZeRO | 520 | 520 | 3,120 | 4,160 | -- |
| ZeRO-1 | 520 | 520 | 1,560 | 2,600 | 37.5% |
| ZeRO-2 | 520 | 260 | 1,560 | 2,340 | 43.8% |
| ZeRO-3 | 260 | 260 | 1,560 | 2,080 | 50.0% |
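Each row of the table follows from one sharding rule per tensor class. A minimal sketch that reproduces the per-GPU totals (the `per_gpu_bytes` helper and strategy names are illustrative, not a library API):

```python
P, N = 260, 2        # parameter count, world size
BF16, FP32 = 2, 4    # bytes per element

def per_gpu_bytes(strategy):
    # Params and grads live in BF16; optimizer states are m, v, FP32 copy (12 B/param).
    params = P * BF16 / (N if strategy == "zero3" else 1)
    grads  = P * BF16 / (N if strategy in ("zero2", "zero3") else 1)
    optim  = P * 3 * FP32 / (N if strategy != "none" else 1)
    return params + grads + optim

for s in ("none", "zero1", "zero2", "zero3"):
    print(s, per_gpu_bytes(s))  # 4160.0, 2600.0, 2340.0, 2080.0
```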
Memory per GPU under ZeRO-3: 16P/N. Double the GPUs → halve the memory. This is the theoretical optimum for data parallelism.
All 260 parameters are flattened into one vector and split at index 130. Each GPU stores only its shard.
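A sketch of that split, using numpy's `array_split` as a stand-in for the framework's flat-parameter sharding:

```python
import numpy as np

P, N = 260, 2
flat = np.arange(P, dtype=np.float32)  # stand-in for the flattened parameter vector
shards = np.array_split(flat, N)       # split at index 130

print([len(s) for s in shards])        # [130, 130]
print(shards[1][0])                    # 130.0 — GPU 1's shard starts at index 130
```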
The core FSDP loop. For each layer: all-gather its parameters, compute the output, then flush the non-owned shard.
| Layer | Params in layer | Temp memory (BF16) | All-Gather comm | Note |
|---|---|---|---|---|
| LN1 | 8 | 16 B | 16 B | Small layer |
| Attention | 64 | 128 B | 128 B | W_q + W_k + W_v + W_o |
| LN2 | 8 | 16 B | 16 B | Small layer |
| FFN | 148 | 296 B | 296 B | PEAK memory |
| Output | 32 | 64 B | 64 B | W_vocab |
| Total | 260 | — | 520 B | NEW vs ZeRO-1/2 |
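The loop above can be simulated to recover the table's peak and total figures (layer sizes copied from the table; this is a simplified memory model, not a real allocator):

```python
BF16 = 2
layers = {"LN1": 8, "Attention": 64, "LN2": 8, "FFN": 148, "Output": 32}

peak = 0
total_comm = 0
for name, n_params in layers.items():
    temp = n_params * BF16   # all-gather: full layer materialized in BF16
    total_comm += temp       # bytes moved by this layer's all-gather
    peak = max(peak, temp)
    # ... forward compute with the full layer happens here ...
    # then flush: free the non-owned shard, keeping only n_params * BF16 / N

print(peak, total_comm)  # 296 520
```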
Mirror of forward but in reverse layer order, with reduce-scatter to distribute gradients.
Same as forward — must re-gather each layer's full parameters to compute gradients.
Each GPU receives only its owned shard of gradients, averaged across all GPUs.
W_q has 16 parameters (indices 8–23). Both GPUs compute the full gradient ∇W_q[16].
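The reduce-scatter can be mimicked on two in-process arrays (the gradient values here are made up purely for illustration):

```python
import numpy as np

N = 2
# Both GPUs hold the full 16-element gradient for W_q after backward.
grad_gpu0 = np.ones(16)
grad_gpu1 = np.ones(16) * 3.0

avg = (grad_gpu0 + grad_gpu1) / N   # reduce: average across GPUs
owned = np.array_split(avg, N)      # scatter: each GPU keeps only its shard

print(owned[0].shape, owned[1].shape)  # (8,) (8,)
print(owned[0][0])                     # 2.0 — the averaged value
```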
ZeRO-3 trades 50% more communication for the ability to shard parameters.
| Phase | ZeRO-1 | ZeRO-2 | ZeRO-3 |
|---|---|---|---|
| Forward All-Gather | 0 | 0 | 520 B |
| Backward All-Gather | 0 | 0 | 520 B |
| Backward Reduce-Scatter | 520 B | 520 B | 520 B |
| Post-Optim All-Gather | 520 B | 520 B | 0 |
| Total | 1,040 B | 1,040 B | 1,560 B |
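The table reduces to one formula per stage. A sketch (the helper name and stage numbering are illustrative):

```python
P = 260
BF16 = 2

def comm_bytes(stage):
    fwd_ag  = P * BF16 if stage == 3 else 0   # forward all-gather (ZeRO-3 only)
    bwd_ag  = P * BF16 if stage == 3 else 0   # backward all-gather (ZeRO-3 only)
    bwd_rs  = P * BF16                        # reduce-scatter of gradients (all stages)
    post_ag = 0 if stage == 3 else P * BF16   # all-gather of updated params (ZeRO-1/2)
    return fwd_ag + bwd_ag + bwd_rs + post_ag

print(comm_bytes(1), comm_bytes(3))  # 1040 1560
```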
Overlap the all-gather of layer N+1 with the compute of layer N to hide latency. Frameworks expose this as prefetch settings (e.g., `prefetch_count` / `forward_prefetch`).
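A back-of-the-envelope latency model shows why prefetching helps (all timings are arbitrary units, not measurements):

```python
# Five layers, each with an all-gather cost and a compute cost.
gather = [1.0] * 5
compute = [3.0] * 5

naive = sum(gather) + sum(compute)  # gather then compute, strictly in sequence

# Prefetch: layer i+1's all-gather runs concurrently with layer i's compute,
# so only the first gather is fully exposed on the critical path.
overlapped = gather[0]
for i in range(len(compute)):
    nxt = gather[i + 1] if i + 1 < len(gather) else 0.0
    overlapped += max(compute[i], nxt)

print(naive, overlapped)  # 20.0 16.0
```

Here gather time is smaller than compute time, so the extra communication hides completely; when gathers dominate, the pipeline becomes communication-bound instead.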
Compare how memory usage changes during the forward pass for ZeRO-2 vs ZeRO-3.
See how ZeRO-3's perfect linear scaling makes large model training feasible on consumer hardware.
Model: 7B parameters (BF16). Example: N = 8 GPUs.

| Strategy | Formula | Memory per GPU | Fits in 24 GB? |
|---|---|---|---|
| No ZeRO | 16P | 112 GB | No |
| ZeRO-1 | (4 + 12/N)P | 38.5 GB | No |
| ZeRO-2 | (2 + 14/N)P | 26.25 GB | No |
| ZeRO-3 | 16P/N | 14 GB | Yes |
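The per-GPU figures can be computed for any N (treating 1 billion parameters × 1 byte/param as 1 GB, as the article's numbers do):

```python
def per_gpu_gb(P_billion, N):
    # 16 bytes/param total: 2 (BF16 params) + 2 (grads) + 12 (optimizer states)
    return {
        "No ZeRO": 16 * P_billion,
        "ZeRO-1": (4 + 12 / N) * P_billion,
        "ZeRO-2": (2 + 14 / N) * P_billion,
        "ZeRO-3": 16 / N * P_billion,
    }

mem = per_gpu_gb(7, 8)
print(mem)  # {'No ZeRO': 112, 'ZeRO-1': 38.5, 'ZeRO-2': 26.25, 'ZeRO-3': 14.0}
```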
From no sharding to full sharding — each stage trades communication for memory savings.
| Feature | No ZeRO | ZeRO-1 | ZeRO-2 | ZeRO-3 (FSDP) |
|---|---|---|---|---|
| Shard optimizer | No | Yes | Yes | Yes |
| Shard gradients | No | No | Yes | Yes |
| Shard parameters | No | No | No | Yes |
| Memory formula | 16P | (4 + 12/N)P | (2 + 14/N)P | 16P/N |
| Comm per step | 2P (AllReduce) | 2P | 2P | 3P |
| Forward all-gather | None | None | None | Required |
| Full model in memory | Yes (wasteful) | Yes | Yes | No |
| 7B on 8 GPUs (mem) | 112 GB | 38.5 GB | 26.25 GB | 14 GB |
- **No ZeRO:** Model fits easily in one GPU. Don't want any communication overhead. Quick experiments and prototyping.
- **ZeRO-1:** Model fits in memory but optimizer states are tight. Zero extra communication cost. Best default choice for multi-GPU.
- **ZeRO-2:** Need more memory savings than ZeRO-1. Still no extra communication. Gradients are large relative to optimizer states.
- **ZeRO-3:** Model doesn't fit in one GPU. Willing to pay 1.5x communication for parameter sharding. The only choice for truly large models with data parallelism alone.
ZeRO-3 achieves the theoretical optimum: total memory divided equally across all GPUs. Double the GPUs, halve the memory per GPU.
For every layer: all-gather the full parameters, run the computation, flush the non-owned shard. This pattern repeats for both forward and backward passes.
ZeRO-3 costs 3P total communication vs 2P for ZeRO-1/2. The extra all-gather during forward and backward is the cost of not storing full parameters.
By overlapping the all-gather of the next layer with the compute of the current layer, much of the extra communication latency can be hidden behind useful work.
Unlike ZeRO-1/2 where every GPU has all parameters, ZeRO-3 means no single GPU can run the model alone. Parameters exist only temporarily during compute.
ZeRO-3/FSDP is the go-to when your model's memory footprint exceeds what a single GPU can handle. It's the bridge between data parallelism and model parallelism.