ZeRO-2 partitions both optimizer states AND gradients across GPUs. Follow every byte through a complete training step on a tiny transformer with 260 parameters.
Same model as the ZeRO-1 walkthrough — a single transformer block with concrete, traceable values.
- d_model = 4
- 2 attention heads, d_k = 2
- d_ff = 16 (4× expansion)
- V = 8, T = 3 tokens
| Layer | Name | Shape | #Elements | GPU Owner |
|---|---|---|---|---|
| LayerNorm 1 | γ₁ | (4,) | 4 | GPU-0 |
| LayerNorm 1 | β₁ | (4,) | 4 | GPU-0 |
| Attention | W_q | (4, 4) | 16 | GPU-0 |
| Attention | W_k | (4, 4) | 16 | GPU-0 |
| Attention | W_v | (4, 4) | 16 | GPU-0 |
| Attention | W_o | (4, 4) | 16 | GPU-0 |
| LayerNorm 2 | γ₂ | (4,) | 4 | GPU-0 |
| LayerNorm 2 | β₂ | (4,) | 4 | GPU-0 |
| FFN | W₁ | (4, 16) | 64 (50 + 14) | Split |
| FFN | b₁ | (16,) | 16 | GPU-1 |
| FFN | W₂ | (16, 4) | 64 | GPU-1 |
| FFN | b₂ | (4,) | 4 | GPU-1 |
| Output | W_vocab | (4, 8) | 32 | GPU-1 |
| Total | | | 260 | 130 + 130 |
Parameter indices 0..129 go to GPU-0; indices 130..259 go to GPU-1. This is identical to ZeRO-1's partition: the boundary falls in the middle of W₁, which contributes 50 elements to GPU-0's slice and 14 to GPU-1's.
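The boundary arithmetic can be checked with a short sketch (layer names are illustrative; flat row-major ordering is assumed):

```python
# Sketch: flat partition of the 260-parameter model across 2 GPUs.
# Layer order and sizes follow the table above.
layers = [
    ("gamma1", 4), ("beta1", 4),
    ("W_q", 16), ("W_k", 16), ("W_v", 16), ("W_o", 16),
    ("gamma2", 4), ("beta2", 4),
    ("W1", 64), ("b1", 16), ("W2", 64), ("b2", 4),
    ("W_vocab", 32),
]
total = sum(n for _, n in layers)   # 260
boundary = total // 2               # 130: GPU-0 owns [0, 130), GPU-1 owns [130, 260)

split = {}
offset = 0
for name, n in layers:
    start, end = offset, offset + n
    if start < boundary < end:      # this parameter straddles the boundary
        split[name] = (boundary - start, end - boundary)
    offset = end

print(split)  # {'W1': (50, 14)}
```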
The key insight: there is no reason to keep the full gradient vector on every GPU. Each GPU only needs averaged gradients for its own optimizer slice.
All values in bytes for 260 parameters on 2 GPUs.
Gradient memory drops from 2P to 2P/N bytes. For our model: 520 bytes → 260 bytes. The extra saving is 260 bytes per GPU, and it comes entirely from the gradient partition.
In ZeRO-2, "ownership" means more: each GPU stores optimizer states AND averaged gradients for its slice.
Each GPU still holds the full parameters θ[260] in BF16 (needed for the forward and backward passes). That's it: nothing else is fully replicated.
The critical difference from ZeRO-1: reduce-scatter is fused INTO the backward pass. Gradients are communicated and discarded immediately.
Watch gradient memory on GPU-0 during backpropagation: ZeRO-1 accumulates every layer's gradients; ZeRO-2 discards each bucket as soon as it is reduced.
Both GPUs compute local gradients for W_q, then immediately reduce-scatter that bucket. Since W_q lies in GPU-0's slice, GPU-0 keeps the averaged result and GPU-1 frees its local copy.
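A minimal single-process simulation of that reduce-scatter (NumPy stands in for the NCCL collective; the gradient values are made up):

```python
import numpy as np

# Simulate reduce-scatter of the full 260-element gradient vector on 2 ranks.
# After the collective, each rank holds only the AVERAGED gradients for its
# own 130-element slice; the full local buffers are freed bucket-by-bucket.
P, world_size = 260, 2
shard = P // world_size                    # 130 elements per rank
g_rank0 = np.full(P, 1.0)                  # rank 0's local gradients (illustrative)
g_rank1 = np.full(P, 3.0)                  # rank 1's local gradients (illustrative)

avg = (g_rank0 + g_rank1) / world_size     # what a full all-reduce would give everyone
g_shard = [avg[r * shard:(r + 1) * shard] for r in range(world_size)]
# rank 0 keeps g_shard[0] (indices 0..129), rank 1 keeps g_shard[1] (130..259)
```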
These steps are identical to ZeRO-1. Each GPU runs Adam on its 130-element slice, then broadcasts the updated parameters.
After all-gather, both GPUs hold the full updated model θ[260].
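Continuing the sketch: one Adam step on each rank's slice, followed by the all-gather (hyperparameters and gradient values are illustrative; both ranks are simulated in one process):

```python
import numpy as np

# Each rank runs Adam on its 130-element slice only, then the updated
# slices are all-gathered so both ranks again hold the full theta[260].
P, shard = 260, 130
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
theta = np.zeros(P)
m, v = np.zeros(P), np.zeros(P)
g_avg = np.full(P, 2.0)            # averaged gradients from the reduce-scatter

updated = []
for rank in range(2):
    s = slice(rank * shard, (rank + 1) * shard)
    m[s] = beta1 * m[s] + (1 - beta1) * g_avg[s]
    v[s] = beta2 * v[s] + (1 - beta2) * g_avg[s] ** 2
    m_hat = m[s] / (1 - beta1)     # bias correction for step t = 1
    v_hat = v[s] / (1 - beta2)
    updated.append(theta[s] - lr * m_hat / (np.sqrt(v_hat) + eps))

theta = np.concatenate(updated)    # the all-gather: full updated theta[260]
```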
All three approaches have identical communication volume. ZeRO-2 just pipelines the reduce-scatter into the backward pass.
| Method | Reduce-Scatter | All-Gather | Total Comm. | Timing |
|---|---|---|---|---|
| Standard All-Reduce | 520 B | 520 B | 1,040 B | After backward |
| ZeRO-1 | 520 B | 520 B | 1,040 B | After backward + after optimizer |
| ZeRO-2 | 520 B | 520 B | 1,040 B | During backward + after optimizer |
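The table's numbers follow from the identity all-reduce = reduce-scatter + all-gather, counting the full gradient payload per phase (the accounting used above):

```python
# 260 gradients in BF16 (2 bytes each) = 520 B moved per collective phase.
params, bytes_per_elem = 260, 2
payload = params * bytes_per_elem       # 520 B
reduce_scatter = payload                # 520 B (during or after backward)
all_gather = payload                    # 520 B (after the optimizer step)
total = reduce_scatter + all_gather     # 1,040 B for every method in the table
print(total)  # 1040
```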
Scaling up: the table below shows how ZeRO-2 compares to ZeRO-1 and the baseline for a 7B-parameter model on 8 GPUs.
| Component | No ZeRO (per GPU) | ZeRO-1 (8 GPUs) | ZeRO-2 (8 GPUs) |
|---|---|---|---|
| Params (BF16) | 14.0 GB | 14.0 GB | 14.0 GB |
| Gradients (BF16) | 14.0 GB | 14.0 GB | 1.75 GB |
| Optimizer: m (FP32) | 28.0 GB | 3.5 GB | 3.5 GB |
| Optimizer: v (FP32) | 28.0 GB | 3.5 GB | 3.5 GB |
| Optimizer: master p (FP32) | 28.0 GB | 3.5 GB | 3.5 GB |
| Total per GPU | 112.0 GB | 38.5 GB | 26.25 GB |
| Saving vs baseline | — | 65.6% | 76.6% |
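A small calculator (a sketch, not a framework API) reproduces every column, using mixed-precision Adam's per-parameter costs: 2 B for BF16 params, 2 B for BF16 grads, and 12 B for FP32 optimizer state (m, v, master params):

```python
# Per-GPU memory in GB (1 GB = 1e9 B here) for a dense model under
# mixed-precision Adam, by ZeRO stage. Stage 0 = plain data parallelism.
def per_gpu_gb(n_params, n_gpus, zero_stage):
    params_b = 2 * n_params            # BF16 parameters, replicated up to stage 2
    grads_b = 2 * n_params             # BF16 gradients
    opt_b = 12 * n_params              # FP32 m + v + master copy
    if zero_stage >= 1:
        opt_b /= n_gpus                # ZeRO-1: partition optimizer states
    if zero_stage >= 2:
        grads_b /= n_gpus              # ZeRO-2: partition gradients too
    return (params_b + grads_b + opt_b) / 1e9

for stage in (0, 1, 2):
    print(stage, per_gpu_gb(7e9, 8, stage))   # 112.0, 38.5, 26.25
```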
Why not always use ZeRO-2? In practice, you should. The communication volume is identical, and the memory savings are free.
| Dimension | ZeRO-1 | ZeRO-2 |
|---|---|---|
| Gradient memory | Full (2P bytes) | Partitioned (2P/N bytes) |
| Communication timing | Reduce-scatter AFTER backward | Reduce-scatter DURING backward |
| Implementation | Simple (one RS call) | More complex (bucket-by-bucket, fused) |
| Communication volume | Identical | Identical |
| Optimizer step | Identical | Identical |
| All-gather | Identical | Identical |
ZeRO-3 goes further and removes the remaining 2P parameter term, but at the cost of extra all-gather communication during the forward pass.
ZeRO-2 adds gradient partitioning on top of ZeRO-1's optimizer partitioning. Each GPU only stores averaged gradients for its own optimizer slice.
The key mechanism: reduce-scatter is fused into the backward pass. Gradients are communicated and discarded layer-by-layer, never stored in full.
Communication volume is identical to ZeRO-1 and standard all-reduce. The memory saving is completely free — only implementation complexity increases.
Peak gradient memory drops from 2P to 2P/N bytes. For 7B parameters on 8 GPUs: 14 GB → 1.75 GB, saving 12.25 GB per GPU for free.
There is no real reason to use ZeRO-1 over ZeRO-2. Modern frameworks like DeepSpeed handle the implementation complexity transparently.
ZeRO-2 leaves the 2P parameter term. ZeRO-3 partitions parameters too, eliminating this — but adds all-gather communication during forward pass.