DeepSpeed ZeRO — Configuration Deep Dive

ZeRO Config Parameters

DeepSpeed's ZeRO optimizer has over a dozen tuning parameters that control communication granularity, memory layout, and prefetching. Instead of memorizing defaults, let's build intuition for what each parameter does and why.

3 Levers
Buckets, Overlap, Persistence
500M
Default bucket size (elements)
~1 GB
Buffer per bucket (BF16)
12+
Tunable parameters

Three Levers of ZeRO Config

Every ZeRO config parameter falls into one of three categories. Master these three levers and the rest follows naturally.

📦

Communication Granularity

How many parameter elements get batched together into a "bucket" before a collective communication call fires. Larger buckets = fewer calls, more latency before the first one.

reduce_bucket_size allgather_bucket_size stage3_prefetch_bucket_size
🔄

Compute-Comm Overlap

How much communication hides behind computation. When overlap is enabled, communication runs concurrently with backward passes, making the transfer time "invisible".

overlap_comm reduce_scatter
💾

Persistence & Caching

Which parameters are worth sharding vs. keeping replicated. Tiny params cost negligible memory to replicate but incur real latency to gather every time.

stage3_param_persistence_threshold stage3_max_live_parameters stage3_max_reuse_distance contiguous_gradients
The core tradeoff: Communicating one parameter at a time gives maximum overlap but maximum overhead per message. Communicating everything at once gives minimum overhead but zero overlap. Buckets are the middle ground — batch enough elements before firing, but not so many that you waste the overlap opportunity.

Bucket Sizes

Buckets control communication granularity — how many elements accumulate before a collective fires.

Bucket
reduce_bucket_size
500,000,000

Gradients accumulate into a buffer of this size before the reduce-scatter fires. Used in both Stage 2 and Stage 3.

Memory per buffer (BF16) 500M × 2 bytes = ~1 GB
Bucket
allgather_bucket_size
500,000,000

Same idea, but for the all-gather step that broadcasts updated parameters back to every GPU after the optimizer step. Controls how many elements are gathered per call.

Memory per buffer (BF16) 500M × 2 bytes = ~1 GB
Bucket
stage3_prefetch_bucket_size
200,000,000

ZeRO-3 only. How many parameters to prefetch (all-gather in advance) for the next layer while the current layer computes. Intentionally smaller — you want prefetching to start early.

Memory per buffer (BF16) 200M × 2 bytes = ~400 MB

Interactive: Bucket Size Tradeoff

Drag the slider to see how bucket size affects the number of communication calls and buffer memory for a 1B parameter model.

Bucket Size: 500M
Dtype:
2
Communication Calls
1.0 GB
Buffer Memory
Medium
First-Call Latency
Good
Overlap Potential
Why is stage3_prefetch_bucket_size smaller (200M vs 500M)?
Prefetching should start early to maximize overlap with compute. A smaller bucket fills faster, so the first all-gather fires sooner. If the prefetch bucket were 500M, you'd wait longer before any prefetching begins — defeating the purpose.
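The arithmetic behind the slider is simple enough to sketch. The toy Python function below (not DeepSpeed internals; BF16 assumed at 2 bytes per element) reproduces the call-count and buffer-memory numbers:

```python
import math

def bucket_stats(total_params, bucket_size, bytes_per_elem=2):
    """Toy model of the bucket tradeoff (not DeepSpeed code).

    bytes_per_elem=2 assumes BF16. Returns the number of collective
    calls needed to cover the model and the buffer size in GB.
    """
    calls = math.ceil(total_params / bucket_size)
    buffer_gb = bucket_size * bytes_per_elem / 1e9
    return calls, buffer_gb

# 1B-parameter model, default 500M bucket -> 2 calls, 1.0 GB buffer
print(bucket_stats(1_000_000_000, 500_000_000))
# 7B model -> 14 reduce-scatter calls per backward pass
print(bucket_stats(7_000_000_000, 500_000_000)[0])
```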

But Wait — Don't Buckets Kill Overlap?

If the bucket is 500M elements but a single layer only produces ~220M gradients, doesn't the bucket need multiple layers to fill? And doesn't that delay the first communication call?

The worry: If we accumulate gradients into a big 500M-element buffer before firing the reduce-scatter, and a single layer only produces ~220M gradients, then the first communication won't start until 2-3 layers of backward have finished. Aren't we losing overlap by waiting?
How Buckets Actually Fill During Backward
For a 7B model with 32 layers and bucket = 500M, each layer has ~220M params. Watch how buckets fill and fire:
Model Size:
Bucket Size: 500M
Backward (computing gradients)
Gradients in bucket (waiting)
Reduce-scatter FIRED
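A rough simulation of that fill-and-fire behavior, assuming uniform layers (real buckets fill tensor by tensor, so this is only a sketch):

```python
def bucket_fire_schedule(num_layers, params_per_layer, bucket_size):
    """Toy simulation of gradient accumulation during backward.

    Returns, for each reduce-scatter that fires, the (1-indexed)
    layer of backward during which it fired. Assumes uniform layer
    sizes; real DeepSpeed buckets fill tensor by tensor.
    """
    fired, filled = [], 0
    for layer in range(1, num_layers + 1):
        filled += params_per_layer
        while filled >= bucket_size:
            fired.append(layer)
            filled -= bucket_size
    if filled > 0:          # leftover gradients flush at the end
        fired.append(num_layers)
    return fired

# 7B model: 32 layers x ~220M params, 500M bucket.
schedule = bucket_fire_schedule(32, 220_000_000, 500_000_000)
print(len(schedule), schedule[0])  # first fire during layer 3's backward

# 300M model, same 500M bucket: the single bucket only fires at the
# very end of backward -> zero overlap
print(bucket_fire_schedule(24, 12_500_000, 500_000_000))  # [24]
```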

Why 14 Buckets Is Actually Enough

The key insight: on modern hardware with NVLink, communication is fast relative to compute. Compute (backward) is the slower one — which is exactly what you want, because it means the network finishes before the GPU does, and the overlap is free.

Compute vs Communication: 7B on 8× H100 (NVLink)
We compare the time to compute one bucket's worth of backward vs the time to reduce-scatter that bucket.
GPUs:
Interconnect:

Compute (Backward per Bucket)

Communication (RS per Bucket)

Time per Bucket: Compute vs Communication

Backward compute
1.5ms
Reduce-scatter
1.25ms
Bucket-by-Bucket Overlap Timeline
Each bucket's reduce-scatter runs while the next bucket's backward computes. Since they're in the same ballpark (~1.5ms compute vs ~1.25ms comm), the communication fits neatly behind compute.
The punchline: With NVLink, compute is the slower side (~1.5ms backward per bucket vs ~1.25ms for the reduce-scatter). That's the ideal regime — the network finishes before the GPU does. So while 1 GB of gradient data is flying across NVLink, the GPU is still busy with the next bucket's backward and doesn't even notice the transfer. You don't need hundreds of tiny overlapping calls — each 500M-element bucket already fully saturates the link, and the compute time per bucket is just long enough to completely hide it.
But switch to PCIe and the story flips — now communication becomes the bottleneck (50 GB/s vs 800 GB/s). Each 1 GB bucket takes ~20ms over PCIe, far exceeding the ~1.5ms of compute. Toggle "PCIe 5.0" above to see the bars flip. This is why NVLink matters for multi-GPU training.
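You can reproduce the bars with back-of-envelope bandwidth math. The bandwidths and the 1.5 ms compute figure are the illustrative values from above, and this ignores the (N-1)/N ring factor and launch latency of a real reduce-scatter:

```python
def per_bucket_comm_ms(bucket_elems=500_000_000, bytes_per_elem=2, bw_gbps=800):
    """Transfer time for one bucket in ms, ignoring the (N-1)/N ring
    factor and latency. bw_gbps=800 approximates NVLink; 50 approximates
    PCIe 5.0. BF16 assumed (2 bytes per element)."""
    bytes_total = bucket_elems * bytes_per_elem
    return bytes_total / (bw_gbps * 1e6)  # bytes / (GB/s) -> milliseconds

compute_ms = 1.5  # illustrative backward time per bucket
for name, bw in [("NVLink", 800), ("PCIe", 50)]:
    comm = per_bucket_comm_ms(bw_gbps=bw)
    print(name, comm, "hidden" if comm <= compute_ms else "comm-bound")
# NVLink: 1.25 ms (hidden); PCIe: 20.0 ms (comm-bound)
```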
The edge case to watch: This works beautifully for large models (7B+) where total params >> bucket size, giving you many buckets and lots of overlap. But for a small model (300M params) with a 500M bucket, all gradients fit in a single bucket that never fills during backward. The reduce-scatter only fires after the entire backward pass — zero overlap, identical to overlap_comm = false. For small models, reduce your bucket size to ~100M.

Rule of Thumb for Bucket Sizing

Set reduce_bucket_size to roughly 2-3 layers' worth of parameters. This ensures:

  • The first bucket fires within 2-3 layers of backward — early enough for good overlap
  • Each bucket is large enough to saturate the network link
  • You get at least 10+ buckets for a large model — plenty of overlap opportunities
Guideline: bucket_size ≈ 2–3 × params_per_layer = 2–3 × (total_params / num_layers)
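The guideline as a tiny helper function (an illustrative sketch, not a DeepSpeed API):

```python
def suggest_bucket_size(total_params, num_layers, layers_per_bucket=2.5):
    """Rule-of-thumb bucket size: roughly 2-3 layers' worth of params."""
    return int(total_params / num_layers * layers_per_bucket)

# 7B model, 32 layers -> ~547M elements, close to the 500M default
print(suggest_bucket_size(7_000_000_000, 32))
```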

Communication Overlap

The magic that makes distributed training fast: hide communication behind computation.

Overlap
overlap_comm
true

When enabled, gradient communication (reduce-scatter) runs concurrently with the backward pass. While one layer's gradients are being communicated, the next layer's backward is already computing.

bwd(layer N) + reduce-scatter(layer N+1) // concurrent!
Overlap
reduce_scatter
true (Stage 2)

Use reduce-scatter instead of all-reduce for gradients. In reduce-scatter, each GPU only receives its own slice of the averaged gradients — exactly what ZeRO needs for gradient partitioning.

Communication Volume All-Reduce: 2P  |  Reduce-Scatter: P (each GPU gets 1/N)
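The 2P-vs-P comparison comes from the communication volumes of ring collectives. A quick check, assuming ring all-reduce and ring reduce-scatter:

```python
def comm_volume_per_gpu(P, N):
    """Elements moved per GPU by ring collectives on N GPUs for a
    P-element gradient. Ring all-reduce moves 2P(N-1)/N ~ 2P; ring
    reduce-scatter moves P(N-1)/N ~ P, and each GPU ends up holding
    only its own 1/N slice -- exactly what ZeRO's partitioning needs.
    """
    all_reduce = 2 * P * (N - 1) / N
    reduce_scatter = P * (N - 1) / N
    return all_reduce, reduce_scatter

ar, rs = comm_volume_per_gpu(P=7_000_000_000, N=8)
print(ar / rs)  # 2.0 -- all-reduce moves twice the data
```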
Overlap OFF vs ON
Toggle to see how overlap_comm transforms the timeline. When ON, communication is hidden behind backward compute.
Forward
Backward
Reduce-Scatter
All-Gather
Optimizer
Idle
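The OFF/ON toggle can be approximated with a small pipeline model. A sketch under simplifying assumptions (uniform buckets, reduce-scatter only; all-gather and optimizer time ignored):

```python
def backward_ms(n_buckets, compute_ms, comm_ms, overlap):
    """Total backward + gradient-communication time (toy pipeline model).

    Without overlap, compute and communication serialize. With
    overlap_comm, bucket i's reduce-scatter runs during bucket i+1's
    backward, so only the final transfer sticks out past compute.
    """
    if not overlap:
        return n_buckets * (compute_ms + comm_ms)
    return n_buckets * max(compute_ms, comm_ms) + min(compute_ms, comm_ms)

# 7B model: 14 buckets, ~1.5 ms compute and ~1.25 ms comm per bucket
print(backward_ms(14, 1.5, 1.25, overlap=False))  # 38.5 ms
print(backward_ms(14, 1.5, 1.25, overlap=True))   # 22.25 ms
```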

Memory Layout & Contiguous Gradients

How DeepSpeed arranges memory to make communication faster and reduce fragmentation.

Memory
contiguous_gradients = true

Contiguous Gradient Buffer

Pre-allocate one big contiguous block of memory for all gradients instead of letting them scatter across the heap.

Benefits:

  • No memory fragmentation from scattered allocations
  • Faster network transfers (DMA prefers contiguous memory)
  • Predictable memory usage

Scattered vs Contiguous

Visualizing how gradients are laid out in GPU memory:

Scattered contiguous_gradients = false
g1
g2
g3
g4
Gaps = fragmentation. Each send requires separate DMA.
Contiguous contiguous_gradients = true
g1
g2
g3
g4
One contiguous block. Single DMA transfer for each bucket.
Memory cost of communication buffers: With both reduce_bucket_size and allgather_bucket_size at 500M elements in BF16, you're allocating ~2 GB just for communication buffers. On a 40 GB A100, that's 5% of VRAM — worth it for the communication efficiency, but be aware on smaller GPUs.

Parameter Sharding & Lifecycle

Stage 3 partitions parameters across GPUs, so each layer must be gathered before use and flushed after. These parameters control that lifecycle.

Parameter Persistence Threshold

Parameters smaller than stage3_param_persistence_threshold (100,000 elements) are kept on all GPUs permanently instead of being gathered/flushed. Drag the threshold to see which layers stay replicated.

Threshold: 100K
Kept replicated (below threshold)
Sharded (above threshold — gathered on demand)
4
Replicated Params
6
Sharded Params
0.4 MB
Replication Memory Cost
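A toy version of the threshold logic. The parameter sizes below are hypothetical (think small bias/LayerNorm tensors vs. large attention/MLP matrices), chosen to match the numbers in the widget:

```python
def split_by_persistence(param_sizes, threshold=100_000, bytes_per_elem=2):
    """Mimics stage3_param_persistence_threshold: tensors below the
    threshold stay replicated on every GPU; the rest are sharded and
    gathered on demand. Returns counts and per-GPU replication cost
    in MB (BF16 assumed)."""
    replicated = [s for s in param_sizes if s < threshold]
    sharded = [s for s in param_sizes if s >= threshold]
    mem_mb = sum(replicated) * bytes_per_elem / 1e6
    return len(replicated), len(sharded), mem_mb

# Hypothetical layer sizes: four small tensors, six big matrices
sizes = [50_000] * 4 + [16_777_216] * 3 + [67_108_864] * 3
print(split_by_persistence(sizes))  # (4, 6, 0.4)
```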
ZeRO-3
stage3_max_live_parameters
1,000M (1 Billion elements)

Maximum number of parameter elements that can be materialized (gathered) on each GPU at once during forward/backward.

This caps peak memory from temporary all-gathers. If a layer needs more than this, it gets broken into smaller chunks and gathered piece by piece.

Peak memory cap (BF16) 1,000M × 2 bytes = ~2 GB max live
ZeRO-3
stage3_max_reuse_distance
1,000M (1 Billion elements)

If a parameter will be needed again within this many elements of compute, keep it around instead of flushing and re-gathering.

High value = more aggressive caching. Trades memory for reduced communication. Useful when the same parameter appears in multiple operations close together.

ZeRO-3
sub_group_size
1,000M (1 Billion elements)

During the optimizer step, parameters are processed in sub-groups of this size to control peak memory.

Smaller = less peak memory during optimizer step, but more overhead from processing many small groups. Default 1,000M means "process everything at once" for most models.

ZeRO-3
stage3_gather_16bit_weights_on_model_save
true

When saving a checkpoint, temporarily all-gather all sharded parameters into a full model in FP16/BF16.

This gives you a single complete checkpoint file instead of N shard files. Set to false if you always load with the same number of GPUs and want faster saves.

ZeRO-3 Parameter Lifecycle (Interactive)

Step through the lifecycle of a parameter in ZeRO-3. See how parameters are gathered, used, and flushed across 4 GPUs.

Communication Timeline Builder

Configure ZeRO parameters and see how they affect the training timeline for a 7B parameter model on 4 GPUs.

reduce_bucket: 500M
allgather_bucket: 500M
prefetch_bucket: 200M
overlap_comm:
ZeRO Stage:
14
Reduce-Scatter Calls
2.0 GB
Total Buffer Memory
100%
Relative Step Time
85%
Compute Efficiency

Full ZeRO Config

Here's what all these parameters look like in a DeepSpeed configuration file, with annotations.

// DeepSpeed ZeRO Stage 3 config
{
  "zero_optimization": {
    "stage": 3,

    // === Bucket Sizes (Communication Granularity) ===
    "reduce_bucket_size":          500000000,  // 500M elements → ~1 GB (BF16)
    "allgather_bucket_size":       500000000,  // 500M elements → ~1 GB (BF16)
    "stage3_prefetch_bucket_size": 200000000,  // 200M — smaller for early prefetch

    // === Communication Overlap ===
    "overlap_comm":    true,   // Overlap comm with backward compute
    "reduce_scatter":  true,   // Use reduce-scatter (not all-reduce)

    // === Memory Layout ===
    "contiguous_gradients": true,  // One big buffer, less fragmentation

    // === ZeRO-3: Persistence & Caching ===
    "stage3_param_persistence_threshold": 100000,      // <100K → keep replicated
    "stage3_max_live_parameters":        1000000000,  // 1,000M max gathered at once
    "stage3_max_reuse_distance":         1000000000,  // Cache if reused within 1,000M elems
    "sub_group_size":                    1000000000,  // 1,000M optimizer sub-group

    // === Checkpointing ===
    "stage3_gather_16bit_weights_on_model_save": true  // Full checkpoint
  }
}
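One practical note: strict JSON forbids the // comments above, so the file as shown won't parse with a standard JSON loader. A common workaround is to build the config as a Python dict (recent DeepSpeed versions and the HF Trainer accept a dict in place of a file path) and keep the annotations as Python comments. A sketch:

```python
import json

# Same ZeRO-3 settings as the annotated block above, as a Python dict.
# Keeping it in Python sidesteps JSON's lack of comments.
zero_config = {
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": 500_000_000,           # ~1 GB buffer (BF16)
        "allgather_bucket_size": 500_000_000,        # ~1 GB buffer (BF16)
        "stage3_prefetch_bucket_size": 200_000_000,  # smaller for early prefetch
        "overlap_comm": True,
        "reduce_scatter": True,
        "contiguous_gradients": True,
        "stage3_param_persistence_threshold": 100_000,
        "stage3_max_live_parameters": 1_000_000_000,
        "stage3_max_reuse_distance": 1_000_000_000,
        "sub_group_size": 1_000_000_000,
        "stage3_gather_16bit_weights_on_model_save": True,
    }
}

print(json.dumps(zero_config))  # serializes to valid, comment-free JSON
```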

Parameter Quick Reference

Parameter | Default | Category | Effect of Increasing | Effect of Decreasing
reduce_bucket_size | 500M | Bucket | Fewer comm calls | Earlier first call, better overlap
allgather_bucket_size | 500M | Bucket | Fewer all-gather calls | More granular parameter broadcast
stage3_prefetch_bucket_size | 200M | Bucket | Fewer prefetch calls | Earlier prefetch start
overlap_comm | true | Overlap | Always keep true unless debugging
reduce_scatter | true | Overlap | Required for ZeRO gradient partitioning
contiguous_gradients | true | Memory | Always keep true for performance
stage3_param_persistence_threshold | 100K | ZeRO-3 | More params replicated (more mem, less comm) | More params sharded (less mem, more comm)
stage3_max_live_parameters | 1,000M | ZeRO-3 | Higher peak memory | Lower peak memory, more chunks
stage3_max_reuse_distance | 1,000M | ZeRO-3 | More caching (mem ↑, comm ↓) | Less caching (mem ↓, comm ↑)
sub_group_size | 1,000M | ZeRO-3 | Higher peak optimizer memory | Lower peak, more overhead

What to Remember

1

Buckets are the Goldilocks zone

Not too small (overhead from many calls), not too big (delays overlap opportunity). The 500M default is tuned for most multi-GPU setups.

2

Overlap is the #1 performance lever

overlap_comm = true hides communication behind backward compute. This is what makes distributed training practical at scale.

3

Prefetch starts small on purpose

stage3_prefetch_bucket_size is 200M (not 500M) because you want prefetching to start early, overlapping with current-layer compute.

4

Small params: just replicate them

LayerNorm weights (~1K params) cost nothing to keep on every GPU. Gathering them every forward pass wastes more time than the memory they save.

5

Contiguous memory = faster transfers

DMA engines and NCCL prefer contiguous buffers. contiguous_gradients = true is a free performance win.

6

Buffer memory adds up

Two 500M buckets in BF16 = ~2 GB of buffer memory. On smaller GPUs (16 GB), consider reducing bucket sizes to 200M.