ZeRO Optimization · Stage 2

ZeRO-2: A Concrete Walkthrough with Actual Numbers

ZeRO-2 partitions both optimizer states AND gradients across GPUs. Follow every byte through a complete training step on a tiny transformer with 260 parameters.

260 total parameters · 43.8% memory saved · 0 extra communication cost · 2 GPUs

Our Tiny Transformer

Same model as the ZeRO-1 walkthrough — a single transformer block with concrete, traceable values.

Hidden dim: d = 4
Attention heads: 2 heads, d_k = 2
FFN inner dim: 16 (4x expansion)
Vocab & seq: V = 8, T = 3 tokens

Layer        Name     Shape    #Elements        GPU Owner
LayerNorm 1  γ₁       (4,)     4                GPU-0
LayerNorm 1  β₁       (4,)     4                GPU-0
Attention    W_q      (4, 4)   16               GPU-0
Attention    W_k      (4, 4)   16               GPU-0
Attention    W_v      (4, 4)   16               GPU-0
Attention    W_o      (4, 4)   16               GPU-0
LayerNorm 2  γ₂       (4,)     4                GPU-0
LayerNorm 2  β₂       (4,)     4                GPU-0
FFN          W₁       (4, 16)  64 (50 + 14)     Split
FFN          b₁       (16,)    16               GPU-1
FFN          W₂       (16, 4)  64               GPU-1
FFN          b₂       (4,)     4                GPU-1
Output       W_vocab  (4, 8)   32               GPU-1
Total                          260 (130 + 130)
Flat vector partition: We flatten all 260 parameters into a 1D vector. Indices 0..129 go to GPU-0, indices 130..259 go to GPU-1. This is identical to ZeRO-1 — the partition boundary falls in the middle of W₁.
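The flat-vector partition can be sketched in a few lines. The tensor names below are shorthand chosen for this sketch; the shapes and flattening order come from the parameter table above.

```python
import numpy as np

# Parameter shapes in flattening order (names are illustrative shorthand).
shapes = {
    "ln1.gamma": (4,), "ln1.beta": (4,),
    "W_q": (4, 4), "W_k": (4, 4), "W_v": (4, 4), "W_o": (4, 4),
    "ln2.gamma": (4,), "ln2.beta": (4,),
    "W1": (4, 16), "b1": (16,),
    "W2": (16, 4), "b2": (4,),
    "W_vocab": (4, 8),
}

N = 2                                           # number of GPUs
total = sum(int(np.prod(s)) for s in shapes.values())
boundary = total // N                           # 130: GPU-0 owns [0:130]

# Walk the flat index space and report how each tensor is assigned.
offset = 0
for name, shape in shapes.items():
    size = int(np.prod(shape))
    start, end = offset, offset + size
    on_gpu0 = max(0, min(end, boundary) - start)  # overlap with [0:boundary]
    if on_gpu0 == size:
        owner = "GPU-0"
    elif on_gpu0 == 0:
        owner = "GPU-1"
    else:
        owner = f"split {on_gpu0}+{size - on_gpu0}"
    print(f"{name:10s} [{start:3d}:{end:3d}]  {owner}")
    offset = end

print(total, boundary)  # 260 130
```

Running this reproduces the table: W1 spans flat indices [80:144], so it is reported as "split 50+14", exactly where the partition boundary falls.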

Memory Accounting

The key insight: there is no reason to keep the full gradient vector on every GPU. Each GPU only needs averaged gradients for its own optimizer slice.

Model & Gradient Storage

Parameters (BF16): 2 bytes/elem
Gradients (BF16): 2 bytes/elem

Adam Optimizer States (FP32)

Master params: 4 bytes/elem
First moment m: 4 bytes/elem
Second moment v: 4 bytes/elem

Per-GPU Memory: No ZeRO vs ZeRO-1 vs ZeRO-2

All values in bytes for 260 parameters on 2 GPUs.

Strategy            Per-GPU formula      Simplified    Bytes (P=260, N=2)
No ZeRO             2P + 2P + 12P        16P           4,160
ZeRO-1              2P + 2P + 12P/N      4P + 12P/N    2,600
ZeRO-2              2P + 2P/N + 12P/N    2P + 14P/N    2,340
The one change from ZeRO-1: the gradient term drops from 2P to 2P/N. For our model that is 520 bytes → 260 bytes, an extra saving of 260 bytes per GPU from gradient partitioning.
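A minimal calculator for the three per-GPU formulas, assuming BF16 parameters/gradients and FP32 Adam states as described above:

```python
def per_gpu_bytes(P, N, stage):
    """Per-GPU memory (bytes) with BF16 params/grads and FP32 Adam states.

    stage: 0 = no ZeRO, 1 = ZeRO-1, 2 = ZeRO-2.
    Optimizer states cost 12 bytes/param (FP32 master copy + m + v).
    """
    params = 2 * P                                # always fully replicated here
    grads = 2 * P if stage < 2 else 2 * P // N    # ZeRO-2 partitions gradients
    opt = 12 * P if stage < 1 else 12 * P // N    # ZeRO-1/2 partition optimizer
    return params + grads + opt

for stage in (0, 1, 2):
    print(stage, per_gpu_bytes(260, 2, stage))
# 0 4160
# 1 2600
# 2 2340
```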

Who Owns What

In ZeRO-2, "ownership" means more: each GPU stores optimizer states AND averaged gradients for its slice.

Flat parameter vector (260 elements)
GPU-0 owns indices 0–129
GPU-1 owns indices 130–259

GPU-0 owns slice [0:130]

Optimizer: m₀[130] (FP32)
Optimizer: v₀[130] (FP32)
Master params: p₀[130] (FP32)
Averaged gradients: avg_g[0:130] (BF16) ← NEW!

GPU-1 owns slice [130:260]

Optimizer: m₁[130] (FP32)
Optimizer: v₁[130] (FP32)
Master params: p₁[130] (FP32)
Averaged gradients: avg_g[130:260] (BF16) ← NEW!
Still replicated on BOTH GPUs: Full parameters θ[260] in BF16 (needed for forward and backward pass). That's it — nothing else is fully replicated.

Walking Through One Training Step

The critical difference from ZeRO-1: reduce-scatter is fused INTO the backward pass. Gradients are communicated and discarded immediately.

Steps 1+2: Forward & Backward with Fused Reduce-Scatter (the key change)

ZeRO-1 approach:

1. Full backward pass → accumulate all 260 local gradients
2. Separate reduce-scatter step after backward completes
3. Peak gradient memory: 260 values

ZeRO-2 approach:

1. Compute gradient for a layer → immediately reduce-scatter
2. Non-local gradients are discarded instantly
3. Peak gradient memory: ~130 values
Forward pass is identical to ZeRO-1. Each GPU independently runs the full model on its micro-batch: LN1 → Attention → residual → LN2 → FFN → residual → W_vocab → softmax → CrossEntropy loss.
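The two memory profiles can be sketched with a small simulation. This is a toy model of the bookkeeping, not real collectives; the layer sizes come from the parameter table above.

```python
import numpy as np

# Per-tensor element counts in flattening order (sums to 260).
layer_sizes = [4, 4, 16, 16, 16, 16, 4, 4, 64, 16, 64, 4, 32]
boundary = 130                     # GPU-0 owns flat indices [0:130]
offsets = np.cumsum([0] + layer_sizes)

peak_z1 = held_z1 = 0
peak_z2 = held_z2 = 0
# Backward pass visits tensors in reverse order.
for i in reversed(range(len(layer_sizes))):
    size = layer_sizes[i]
    start, end = offsets[i], offsets[i + 1]
    # ZeRO-1: every local gradient is held until backward finishes.
    held_z1 += size
    peak_z1 = max(peak_z1, held_z1)
    # ZeRO-2: the local gradient exists only until the fused reduce-scatter,
    # after which GPU-0 keeps just the part overlapping its slice.
    peak_z2 = max(peak_z2, held_z2 + size)
    held_z2 += max(0, min(end, boundary) - start)

print(peak_z1, peak_z2)  # 260 130
```

ZeRO-1's peak is the full 260 values; ZeRO-2 on this model peaks at exactly 130, because the transient per-layer gradient fits inside the headroom left by the not-yet-accumulated owned slice.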

The Backward Pass, Frame by Frame

Watch gradient memory on GPU-0 during backpropagation. ZeRO-1 accumulates everything; ZeRO-2 discards immediately.

Stepping through the backward pass layer by layer, ZeRO-1's gradient memory on GPU-0 climbs to a peak of 260 values, while ZeRO-2's stays at or below about 130 values.
Concrete Example: Backprop through W_q

Both GPUs compute local gradients for W_q, then immediately reduce-scatter. Since W_q is in GPU-0's slice, GPU-0 keeps the averaged result and GPU-1 discards everything.

GPU-0: g_q^(A) — local gradient (4×4)

 0.023  -0.011   0.045  -0.008
-0.031   0.019  -0.007   0.014
 0.012  -0.028   0.033  -0.005
-0.016   0.009  -0.021   0.038
GPU-1: g_q^(B) — local gradient (4×4)

 0.017  -0.025   0.031  -0.013
-0.009   0.041  -0.018   0.006
 0.028  -0.014   0.022  -0.035
-0.020   0.016  -0.012   0.027
↓ Reduce-Scatter ↓
GPU-0 receives: avg_g_q = (g_q^(A) + g_q^(B)) / 2

 0.020  -0.018   0.038  -0.011
-0.020   0.030  -0.013   0.010
 0.020  -0.021   0.028  -0.020
-0.018   0.013  -0.017   0.033
GPU-1 receives NOTHING for W_q (not its slice!) and discards g_q^(B) immediately. This is the core ZeRO-2 mechanism.
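A quick check of the averaging, with plain NumPy standing in for the actual reduce-scatter collective:

```python
import numpy as np

# Local W_q gradients from the example above, flattened row by row (4x4 each).
g_A = np.array([ 0.023, -0.011,  0.045, -0.008,
                -0.031,  0.019, -0.007,  0.014,
                 0.012, -0.028,  0.033, -0.005,
                -0.016,  0.009, -0.021,  0.038])
g_B = np.array([ 0.017, -0.025,  0.031, -0.013,
                -0.009,  0.041, -0.018,  0.006,
                 0.028, -0.014,  0.022, -0.035,
                -0.020,  0.016, -0.012,  0.027])

# With 2 GPUs, reduce-scatter degenerates to: average, then each rank keeps
# only its own slice. W_q lies entirely in GPU-0's slice, so GPU-0 keeps the
# averaged 4x4 block and GPU-1 keeps nothing.
avg = (g_A + g_B) / 2
print(avg.reshape(4, 4))
```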

Optimizer Step & All-Gather

These steps are identical to ZeRO-1. Each GPU runs Adam on its 130-element slice, then broadcasts the updated parameters.

Step 3: Adam Optimizer Update (Local Slice Only, same as ZeRO-1)

GPU-0 updates W_q[0,0] — Adam step with t=1, lr=0.001, β₁=0.9, β₂=0.999, ε=1e-8

p = 0.12 (current master param, FP32)
g = 0.020 (averaged gradient)

m_new = 0.9 × 0.0 + 0.1 × 0.020 = 0.002
v_new = 0.999 × 0.0 + 0.001 × 0.020² = 0.0000004

m̂ = 0.002 / (1 - 0.9) = 0.02
v̂ = 0.0000004 / (1 - 0.999) = 0.0004

p_new = 0.12 - 0.001 × 0.02 / (√0.0004 + 1e-8)
      = 0.12 - 0.001 × 0.02 / 0.02
      = 0.12 - 0.001 = 0.119
Cast to BF16: W_q[0,0] ≈ 0.1189 (nearest BF16 value)
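The same update as a runnable scalar sketch. This is a toy helper mirroring the arithmetic above, not a framework API:

```python
from math import sqrt

def adam_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update on a single scalar parameter."""
    m = b1 * m + (1 - b1) * g          # first-moment EMA
    v = b2 * v + (1 - b2) * g * g      # second-moment EMA
    m_hat = m / (1 - b1 ** t)          # bias correction
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (sqrt(v_hat) + eps)
    return p, m, v

p_new, m_new, v_new = adam_step(p=0.12, g=0.020, m=0.0, v=0.0, t=1)
print(round(p_new, 6))  # 0.119
```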
Step 4: All-Gather (Sync Updated Parameters, same as ZeRO-1)

GPU-0 contributes new_params[0:130]; GPU-1 contributes new_params[130:260].

After all-gather, both GPUs hold the full updated model θ[260].

Step 5: End of Step — What Lives Where

GPU-0 (slice 0:130):
θ[260] in BF16 (full model): 520 B
avg_g[0:130] in BF16: 260 B
m₀[130] in FP32: 520 B
v₀[130] in FP32: 520 B
p₀[130] in FP32: 520 B
Total: 2,340 bytes

GPU-1 (slice 130:260):
θ[260] in BF16 (full model): 520 B
avg_g[130:260] in BF16: 260 B
m₁[130] in FP32: 520 B
v₁[130] in FP32: 520 B
p₁[130] in FP32: 520 B
Total: 2,340 bytes

Communication Cost Analysis

All three approaches have identical communication volume. ZeRO-2 just pipelines the reduce-scatter into the backward pass.

Method               Reduce-Scatter   All-Gather   Total Comm.   Timing
Standard All-Reduce  520 B            520 B        1,040 B       After backward
ZeRO-1               520 B            520 B        1,040 B       After backward + after optimizer
ZeRO-2               520 B            520 B        1,040 B       During backward + after optimizer
Identical communication volume! The difference is purely about when communication happens and how long gradient data is kept in memory. ZeRO-2 pipelines the reduce-scatter into the backward pass, allowing immediate gradient memory reclamation at zero extra communication cost.

Scale It Up: Interactive Calculator

The same formulas scale directly. Example: a 7B-parameter model on 8 GPUs.

Component                   No ZeRO (per GPU)   ZeRO-1 (8 GPUs)   ZeRO-2 (8 GPUs)
Params (BF16)               14.0 GB             14.0 GB           14.0 GB
Gradients (BF16)            14.0 GB             14.0 GB           1.75 GB
Optimizer: m (FP32)         28.0 GB             3.5 GB            3.5 GB
Optimizer: v (FP32)         28.0 GB             3.5 GB            3.5 GB
Optimizer: master p (FP32)  28.0 GB             3.5 GB            3.5 GB
Total per GPU               112.0 GB            38.5 GB           26.25 GB
Saving vs baseline          -                   65.6%             76.6%
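These numbers can be reproduced with the per-GPU formulas. Decimal gigabytes (10⁹ bytes) are assumed here, matching the 14.0 GB figure for 7B BF16 parameters:

```python
GB = 1e9  # decimal gigabytes

def per_gpu_gb(P, N):
    """Per-GPU memory (GB) for baseline, ZeRO-1, ZeRO-2 (BF16 + FP32 Adam)."""
    base = (2 * P + 2 * P + 12 * P) / GB      # everything replicated
    z1 = (2 * P + 2 * P + 12 * P / N) / GB    # optimizer states sharded
    z2 = (2 * P + 2 * P / N + 12 * P / N) / GB  # gradients sharded too
    return base, z1, z2

base, z1, z2 = per_gpu_gb(7e9, 8)
print(base, z1, z2)                                    # 112.0 38.5 26.25
print(f"{1 - z1 / base:.1%}", f"{1 - z2 / base:.1%}")  # 65.6% 76.6%
```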

ZeRO-1 → ZeRO-2: The One Key Change

Why not always use ZeRO-2? In practice, you should. The communication volume is identical, and the memory savings are free.

Dimension             ZeRO-1                             ZeRO-2
Gradient memory       Full (P)                           Partitioned (P/N)
Communication timing  Reduce-scatter AFTER backward      Reduce-scatter DURING backward
Implementation        Simple (one reduce-scatter call)   More complex (bucket-by-bucket, fused)
Communication volume  Identical                          Identical
Optimizer step        Identical                          Identical
All-gather            Identical                          Identical

ZeRO-1 per-GPU formula: 2P + 2P + 12P/N (params + full grads + optimizer/N)
ZeRO-2 per-GPU formula: 2P + 2P/N + 12P/N (params + grads/N + optimizer/N)
Progression to ZeRO-3 is natural: If we can partition gradients (ZeRO-2), why not partition the parameters themselves? That eliminates the remaining 2P term — but at the cost of extra all-gather communication during the forward pass.

What You Should Remember

1. Gradient Partitioning: ZeRO-2 adds gradient partitioning on top of ZeRO-1's optimizer partitioning. Each GPU only stores averaged gradients for its own optimizer slice.

2. Fused Reduce-Scatter: the key mechanism is that reduce-scatter is fused into the backward pass. Gradients are communicated and discarded layer by layer, never stored in full.

3. Zero Extra Cost: communication volume is identical to ZeRO-1 and standard all-reduce. The memory saving is completely free; only implementation complexity increases.

4. Peak Memory Win: peak gradient memory drops from P to P/N. For 7B on 8 GPUs: 14 GB → 1.75 GB, saving 12.25 GB per GPU for free.

5. Always Use ZeRO-2: there is no real reason to use ZeRO-1 over ZeRO-2. Modern frameworks like DeepSpeed handle the implementation complexity transparently.

6. Path to ZeRO-3: ZeRO-2 leaves the 2P parameter term. ZeRO-3 partitions parameters too, eliminating it, but adds all-gather communication during the forward pass.