ZeRO Series — Stage 3

ZeRO-3 (FSDP)

The final frontier of data-parallel memory optimization. By sharding parameters, gradients, AND optimizer states, every GPU holds only 1/N of everything. We walk through every gather → compute → flush step with actual numbers.

  • 260 — total parameters (tiny model)
  • 50% — memory saved vs. baseline (ZeRO-3, N = 2)
  • 1.5x — communication overhead vs. ZeRO-1/2
  • 16P/N — perfect memory scaling

Our Tiny Transformer Setup

The same 260-parameter transformer we used in ZeRO-1 and ZeRO-2 — now with full parameter sharding.

Architecture

  • d = 4 — hidden dimension
  • heads = 2 — attention heads
  • d_k = 2 — head dimension
  • FFN = 16 — inner dimension (4x)
  • vocab = 8 — vocabulary size
  • T = 3 — sequence length
  • GPUs = 2 — world size N

Parameter Count

Layer                Shape        Count
LayerNorm 1 (γ, β)   (4)+(4)      8
W_q, W_k, W_v        3 x (4x4)    48
W_o                  (4x4)        16
LayerNorm 2 (γ, β)   (4)+(4)      8
W₁ (FFN up)          (4x16)       64
b₁                   (16)         16
W₂ (FFN down)        (16x4)       64
b₂                   (4)          4
W_vocab              (4x8)        32
Total                             260

What's New in ZeRO-3

ZeRO-1 sharded optimizer states.
ZeRO-2 added gradient sharding.
ZeRO-3 goes all the way: parameters themselves are sharded.

No GPU holds the full model at rest. To compute any layer, GPUs must first all-gather that layer's parameters, compute, then flush the non-owned shards.

FSDP = ZeRO-3. PyTorch's FullyShardedDataParallel (FSDP) implements the same core algorithm as DeepSpeed ZeRO Stage 3; the two differ in implementation details, not in the sharding strategy.
BF16 precision: Each parameter = 2 bytes. So 260 parameters = 520 bytes of model weights. AdamW stores m (momentum), v (variance), and p (FP32 copy) — each 4 bytes per parameter.
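The 16-bytes-per-parameter figure used throughout the rest of this article can be tallied directly. A minimal sketch (variable names are illustrative, the byte sizes are the ones stated above):

```python
# Per-parameter memory under BF16 training with an AdamW optimizer.
P = 260  # parameters in the tiny transformer

bf16_weight = 2   # model weight (BF16)
bf16_grad   = 2   # gradient (BF16)
fp32_m      = 4   # AdamW momentum
fp32_v      = 4   # AdamW variance
fp32_master = 4   # FP32 master copy of the weights

bytes_per_param = bf16_weight + bf16_grad + fp32_m + fp32_v + fp32_master
total_bytes = bytes_per_param * P

print(bytes_per_param)  # 16
print(total_bytes)      # 4160 -> the "No ZeRO" per-GPU footprint below
```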

Memory Accounting: The Full ZeRO Progression

See how memory per GPU drops as we shard more aggressively from No ZeRO through ZeRO-3.

Memory per GPU (bytes) — 260 Parameters, 2 GPUs


Each bar breaks down into parameters (2P), gradients (2P), and optimizer states (m + v + p = 12P).

No ZeRO — memory per GPU = 16P = 16 x 260 = 4,160 B

ZeRO-1 — memory per GPU = (4 + 12/N)P = 10 x 260 = 2,600 B

ZeRO-2 — memory per GPU = (2 + 14/N)P = 9 x 260 = 2,340 B

ZeRO-3 — memory per GPU = 16P/N = 8 x 260 = 2,080 B

Strategy   Params (B)   Grads (B)   Optimizer (B)   Total (B)   Savings vs No ZeRO
No ZeRO    520          520         3,120           4,160       --
ZeRO-1     520          520         1,560           2,600       37.5%
ZeRO-2     520          260         1,560           2,340       43.8%
ZeRO-3     260          260         1,560           2,080       50.0%
Perfect scaling: With ZeRO-3, memory per GPU = 16P/N. Double the GPUs → halve the memory. This is the theoretical optimum for data parallelism.
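The four formulas above fit in one small function. A sketch, where `stage = 0` means no ZeRO:

```python
def per_gpu_bytes(P: int, N: int, stage: int) -> float:
    """Per-GPU memory in bytes for P parameters on N GPUs (BF16 + AdamW).

    stage: 0 = no ZeRO, 1/2/3 = ZeRO stage. Each stage shards one more
    component: optimizer states (>=1), gradients (>=2), parameters (>=3).
    """
    params, grads, optim = 2 * P, 2 * P, 12 * P
    if stage >= 1:
        optim /= N
    if stage >= 2:
        grads /= N
    if stage >= 3:
        params /= N
    return params + grads + optim

# Reproduces the table: 4,160 / 2,600 / 2,340 / 2,080 bytes for P=260, N=2
for s in range(4):
    print(s, per_gpu_bytes(260, 2, s))
```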

Who Owns What — At Rest

All 260 parameters are flattened into one vector and split at index 130. Each GPU stores only its shard.

GPU-0 shard (indices 0–129, 130 params):
  • γ₁, β₁ — 8 params, idx 0–7
  • W_q — 16 params, idx 8–23
  • W_k — 16 params, idx 24–39
  • W_v — 16 params, idx 40–55
  • W_o — 16 params, idx 56–71
  • γ₂, β₂ — 8 params, idx 72–79
  • W₁ (first 50 of 64) — idx 80–129

GPU-1 shard (indices 130–259, 130 params):
  • W₁ (remaining 14) — idx 130–143
  • b₁ — 16 params, idx 144–159
  • W₂ — 64 params, idx 160–223
  • b₂ — 4 params, idx 224–227
  • W_vocab — 32 params, idx 228–259

Note that W₁ is split across the two shards: the split follows flat indices, not layer boundaries.
GPU-0 At Rest
  θ₀[130] (BF16 params)     260 B
  m₀[130] (FP32 momentum)   520 B
  v₀[130] (FP32 variance)   520 B
  p₀[130] (FP32 copy)       520 B
  Total at rest           1,820 B

GPU-1 At Rest
  θ₁[130] (BF16 params)     260 B
  m₁[130] (FP32 momentum)   520 B
  v₁[130] (FP32 variance)   520 B
  p₁[130] (FP32 copy)       520 B
  Total at rest           1,820 B
Neither GPU can run the model alone! GPU-0 has only 130 of 260 parameters. To compute any layer, it must first all-gather the missing shard from GPU-1 (and vice versa). This is the fundamental trade-off of ZeRO-3: less memory, more communication.
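The flat layout above can be reproduced with a running offset. A sketch using the layer sizes from the parameter-count table (dictionary keys are illustrative names):

```python
# Layer sizes from the parameter-count table, in flattening order.
layer_sizes = [("ln1", 8), ("W_q", 16), ("W_k", 16), ("W_v", 16),
               ("W_o", 16), ("ln2", 8), ("W1", 64), ("b1", 16),
               ("W2", 64), ("b2", 4), ("W_vocab", 32)]

# Assign each layer a half-open index range [start, end) in the flat vector.
ranges, offset = {}, 0
for name, size in layer_sizes:
    ranges[name] = (offset, offset + size)
    offset += size

total = offset             # 260
boundary = total // 2      # 130 -> GPU-0 owns [0, 130), GPU-1 owns [130, 260)

# W1 occupies [80, 144): it straddles the boundary, so 50 of its elements
# live on GPU-0 and the remaining 14 on GPU-1 -- the split ignores layers.
print(ranges["W1"], boundary)  # (80, 144) 130
```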

Gather → Compute → Flush

The core FSDP loop. For each layer: all-gather its parameters, compute the output, then flush the non-owned shard.

During the forward pass, GPU-0's memory holds three pieces: its own parameter shard (260 B), its optimizer states (1,560 B), and a temporary buffer for the layer currently being gathered.
Layer       Params   Temp memory (BF16)   All-Gather comm   Note
LN1         8        16 B                 16 B              Small layer
Attention   64       128 B                128 B             W_q + W_k + W_v + W_o
LN2         8        16 B                 16 B              Small layer
FFN         148      296 B                296 B             PEAK memory
Output      32       64 B                 64 B              W_vocab
Total       260      --                   520 B             New cost vs ZeRO-1/2
Key insight: In ZeRO-1 and ZeRO-2, the forward pass required zero extra communication because every GPU held the full model. In ZeRO-3, each forward pass costs an all-gather of all 260 parameters (520 bytes total communication), but the temporary memory is bounded by the largest single layer (FFN: 296 B).
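The loop can be simulated without any real collectives. A sketch, with the per-layer sizes taken from the table above and 2 bytes per parameter for BF16:

```python
# Per-layer parameter counts from the forward-pass table.
LAYERS = [("LN1", 8), ("Attention", 64), ("LN2", 8),
          ("FFN", 148), ("Output", 32)]

def forward_fsdp(layers, bytes_per_param=2):
    """Simulated gather-compute-flush loop on one GPU.

    Returns (total all-gather bytes, peak temporary bytes).
    """
    total_comm, peak_temp = 0, 0
    for name, n_params in layers:
        gathered = n_params * bytes_per_param  # GATHER: materialize full layer
        total_comm += gathered
        peak_temp = max(peak_temp, gathered)
        # COMPUTE: run the layer with the temporarily materialized params.
        # FLUSH: free the non-owned shard; temp memory returns to baseline.
    return total_comm, peak_temp

print(forward_fsdp(LAYERS))  # (520, 296): 520 B of comm, FFN sets the peak
```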

Gather → Backprop → Flush + Reduce-Scatter

Mirror of forward but in reverse layer order, with reduce-scatter to distribute gradients.


All-Gather (parameters) — new in ZeRO-3, 520 B. Same as forward: each layer's full parameters must be re-gathered to compute its gradients.

Reduce-Scatter (gradients) — same as ZeRO-2, 520 B. Each GPU receives only its owned shard of the gradients, averaged across all GPUs.

Concrete Example: Gradient for W_q

W_q has 16 parameters (indices 8–23 in the flat layout). For this example, assume each layer is gathered as its own unit and sharded evenly across GPUs — 8 elements each — which is how FSDP shards a per-layer flat parameter. Both GPUs compute the full gradient ∇W_q[16].

1. All-Gather W_q: GPU-0 sends its 8 elements of W_q, GPU-1 sends its 8. Both now have all 16.
2. Compute ∇W_q: Each GPU computes the full 16-element gradient using its own mini-batch data.
3. Reduce-Scatter: Average the 16 gradients, then GPU-0 gets ∇W_q[0:8], GPU-1 gets ∇W_q[8:16]. Each owns only its shard.
4. Flush: Discard the gathered W_q parameters (the non-owned shard) and the non-owned gradient shard. Memory returns to baseline.
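The four steps can be mimicked with plain lists. A toy sketch — the gradient values are made up purely for illustration:

```python
# Step 2's output: each GPU holds the full 16-element gradient of W_q,
# computed from its own mini-batch (values here are illustrative).
grad_gpu0 = [1.0] * 16
grad_gpu1 = [3.0] * 16

# Step 3, reduce: elementwise average across GPUs...
avg = [(a + b) / 2 for a, b in zip(grad_gpu0, grad_gpu1)]

# ...and scatter: each GPU keeps only the shard it owns.
shard_gpu0 = avg[0:8]    # GPU-0 owns grad W_q[0:8]
shard_gpu1 = avg[8:16]   # GPU-1 owns grad W_q[8:16]

# Step 4, flush: the full `avg` and the gathered W_q would now be freed.
print(shard_gpu0[0], len(shard_gpu0), len(shard_gpu1))  # 2.0 8 8
```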

The Price of ZeRO-3

ZeRO-3 trades 50% more communication for the ability to shard parameters.

Communication Volume per Training Step — total bytes sent per GPU through the network:
Phase                     ZeRO-1    ZeRO-2    ZeRO-3
Forward All-Gather        0         0         520 B
Backward All-Gather       0         0         520 B
Backward Reduce-Scatter   520 B     520 B     520 B
Post-Optim All-Gather     520 B     520 B     0
Total                     1,040 B   1,040 B   1,560 B
The trade-off explained: ZeRO-1 and ZeRO-2 both need a post-optimizer all-gather (520 B) to rebuild full parameters. ZeRO-3 eliminates this but adds two all-gathers (forward + backward = 1,040 B). Net extra = 1 additional all-gather (520 B), making ZeRO-3 cost 1.5x the communication of ZeRO-1/2.
Why it's worth it: That extra 520 B of communication buys you parameter sharding — the difference between fitting a model in GPU memory or not. For large models, this is the only way to train with data parallelism alone.
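The per-step totals from the table reduce to one expression. A sketch, assuming BF16 (2 bytes per parameter):

```python
def comm_bytes_per_step(P: int, stage: int, bytes_per_param: int = 2) -> int:
    """Per-GPU communication volume per training step for ZeRO stages 1-3."""
    param_bytes = P * bytes_per_param
    if stage in (1, 2):
        # grad reduce-scatter + post-optimizer all-gather of parameters
        return 2 * param_bytes
    if stage == 3:
        # forward all-gather + backward all-gather + grad reduce-scatter
        return 3 * param_bytes
    raise ValueError("stage must be 1, 2, or 3")

print(comm_bytes_per_step(260, 1))  # 1040
print(comm_bytes_per_step(260, 3))  # 1560 -> 1.5x the ZeRO-1/2 volume
```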

Prefetching: Hiding the Communication Cost

Overlap the all-gather of layer N+1 with the compute of layer N to hide latency.

Forward Pass with Prefetching

While computing layer N, simultaneously gather layer N+1's parameters. The compute stream runs LN1 → Attn → LN2 → FFN → Out while the network stream stays one gather ahead, so only the initial gather (LN1) is exposed as latency.
Backward Pass with Prefetching

Same overlap strategy in reverse order: backprop runs Out → FFN → LN2 → Attn → LN1, the all-gather stream stays one layer ahead, and each layer's gradient reduce-scatter is interleaved on the network stream as soon as that layer's backprop finishes.
Effective prefetching limit: Prefetching hides the cost only while compute time per layer ≥ communication time per layer. On very large clusters (>512 GPUs), the all-gather may become the bottleneck, reducing the effectiveness of overlap. Both DeepSpeed and FSDP expose prefetch configuration (e.g. FSDP's forward_prefetch and backward_prefetch options).
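The benefit of overlap can be estimated with a toy timeline. The time units below are arbitrary illustrative numbers, not measurements:

```python
# Per-layer compute and all-gather costs in arbitrary time units (made up).
compute = [4, 4, 4, 4, 4]   # LN1, Attn, LN2, FFN, Out
gather  = [2, 2, 2, 2, 2]

# No overlap: every gather is exposed before its layer's compute.
serial = sum(compute) + sum(gather)

# Prefetching: gather of layer i+1 runs under compute of layer i,
# so only the very first gather is exposed.
overlapped = gather[0]
for i in range(len(compute)):
    next_gather = gather[i + 1] if i + 1 < len(gather) else 0
    overlapped += max(compute[i], next_gather)  # slower stream dominates

print(serial, overlapped)  # 30 22
```

If gather time ever exceeds compute time, the `max()` term is dominated by communication and the overlap stops helping — the large-cluster limit described above.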

Frame-by-Frame Memory During Forward Pass

Compare how memory usage changes during the forward pass for ZeRO-2 vs ZeRO-3.

Key difference: In ZeRO-2, the parameter memory is constant at 520 B throughout the forward pass. In ZeRO-3, the base is only 260 B, but temporary gathered parameters spike during each layer computation. Peak temp memory is bounded by the largest layer (FFN: 296 B extra).

Scaling to 7B Parameters on 8 GPUs

See how ZeRO-3's perfect linear scaling makes large model training feasible on consumer hardware.

Worked Example

Model: 7B parameters (BF16), N = 8 GPUs.

Strategy   Formula        Memory per GPU   Fits in 24 GB?
No ZeRO    16P            112 GB           No
ZeRO-1     (4 + 12/N)P    38.5 GB          No
ZeRO-2     (2 + 14/N)P    26.25 GB         No
ZeRO-3     16P/N          14 GB            Yes
14 GB per GPU! Train a 7-billion parameter model on 8 consumer GPUs (RTX 4090, 24 GB each). Without ZeRO-3, you'd need 112 GB per GPU — impossible even on A100-80GB. ZeRO-3 makes data-parallel training of large models accessible.
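The 14 GB figure follows from the same memory formulas at scale. A sketch, where "GB" means 10^9 bytes, so 1e9 parameters x (bytes per parameter) comes out directly in GB:

```python
def per_gpu_gb(p_billion: float, n_gpus: int, stage: int) -> float:
    """Per-GPU memory in GB for a p_billion-parameter model (BF16 + AdamW)."""
    params, grads, optim = 2.0, 2.0, 12.0   # bytes per parameter
    if stage >= 1:
        optim /= n_gpus
    if stage >= 2:
        grads /= n_gpus
    if stage >= 3:
        params /= n_gpus
    # 1e9 params x (bytes/param) = that many 10^9-byte GB
    return (params + grads + optim) * p_billion

for s in range(4):
    print(s, per_gpu_gb(7, 8, s))  # 112.0, 38.5, 26.25, 14.0
```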

The Full ZeRO Progression

From no sharding to full sharding — each stage trades communication for memory savings.

Feature                  No ZeRO          ZeRO-1      ZeRO-2      ZeRO-3 (FSDP)
Shard optimizer          No               Yes         Yes         Yes
Shard gradients          No               No          Yes         Yes
Shard parameters         No               No          No          Yes
Memory formula           16P              (4+12/N)P   (2+14/N)P   16P/N
Comm per step            2P (AllReduce)   2P          2P          3P
Forward all-gather       None             None        None        Required
Full model in memory     Yes (wasteful)   Yes         Yes         No
7B on 8 GPUs (mem/GPU)   112 GB           38.5 GB     26.25 GB    14 GB

When to Use Which?

No ZeRO

Model fits easily in one GPU. Don't want any communication overhead. Quick experiments and prototyping.

ZeRO-1

Model fits in memory but optimizer is tight. Zero extra communication cost. Best default choice for multi-GPU.

ZeRO-2

Need more memory savings than ZeRO-1. Still no extra communication. Gradients are large relative to optimizer states.

ZeRO-3 / FSDP

Model doesn't fit in one GPU. Willing to pay 1.5x communication for parameter sharding. The only choice for truly large models with data parallelism alone.

Key Takeaways

01

16P/N — Perfect Linear Memory Scaling

ZeRO-3 achieves the theoretical optimum: total memory divided equally across all GPUs. Double the GPUs, halve the memory per GPU.

02

Gather-Compute-Flush — The Core FSDP Loop

For every layer: all-gather the full parameters, run the computation, flush the non-owned shard. This pattern repeats for both forward and backward passes.

03

1.5x Communication — The Price of Full Sharding

ZeRO-3 costs 3P total communication vs 2P for ZeRO-1/2. The extra all-gather during forward and backward is the cost of not storing full parameters.

04

Prefetching — Hiding the Extra Cost

By overlapping the all-gather of the next layer with the compute of the current layer, much of the extra communication latency can be hidden behind useful work.

05

No Full Replica — Each GPU Holds Only Its Shard

Unlike ZeRO-1/2 where every GPU has all parameters, ZeRO-3 means no single GPU can run the model alone. Parameters exist only temporarily during compute.

06

When to Use — Model Doesn't Fit in One GPU

ZeRO-3/FSDP is the go-to when your model's memory footprint exceeds what a single GPU can handle. It's the bridge between data parallelism and model parallelism.