DeepSpeed's ZeRO optimizer has over a dozen tuning parameters that control communication granularity, memory layout, and prefetching. Instead of memorizing defaults, let's build intuition for what each parameter does and why.
Every ZeRO config parameter falls into one of three categories. Master these three levers and the rest follows naturally.
**Bucket size.** How many parameter elements get batched together into a "bucket" before a collective communication call fires. Larger buckets = fewer calls, but more latency before the first one.
**Overlap.** How much communication hides behind computation. When overlap is enabled, communication runs concurrently with backward passes, making the transfer time "invisible".
**Persistence.** Which parameters are worth sharding vs. keeping replicated. Tiny params cost negligible memory to replicate but incur real latency to gather every time.
Buckets control communication granularity — how many elements accumulate before a collective fires.
reduce_bucket_size — gradients accumulate into a buffer of this size before the reduce-scatter fires. Used in both Stage 2 and Stage 3.
allgather_bucket_size — the same idea for the all-gather step, when updated parameters are broadcast back after the optimizer step. Controls how many elements are gathered per call.
stage3_prefetch_bucket_size — ZeRO-3 only. How many parameter elements to prefetch (all-gather in advance) for upcoming layers while the current layer computes. Intentionally smaller than the other buckets — you want prefetching to start early.
Drag the slider to see how bucket size affects the number of communication calls and buffer memory for a 1B parameter model.
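The same trade-off can be sketched with back-of-the-envelope arithmetic (the formulas below are simplifications, not DeepSpeed internals): call count falls linearly as buckets grow, while buffer memory rises linearly.

```python
import math

def bucket_stats(model_params: int, bucket_elems: int, bytes_per_elem: int = 2):
    """Estimate collective-call count and buffer memory for one pass (BF16 = 2 bytes)."""
    num_calls = math.ceil(model_params / bucket_elems)  # one collective per full bucket
    buffer_gb = bucket_elems * bytes_per_elem / 1e9     # buffer held in GPU memory
    return num_calls, buffer_gb

# 1B-parameter model at three bucket sizes:
for bucket in (50_000_000, 200_000_000, 500_000_000):
    calls, gb = bucket_stats(1_000_000_000, bucket)
    print(f"bucket={bucket / 1e6:.0f}M -> {calls} calls, {gb:.1f} GB buffer")
# 50M -> 20 calls / 0.1 GB; 200M -> 5 calls / 0.4 GB; 500M -> 2 calls / 1.0 GB
```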
Why is stage3_prefetch_bucket_size smaller (200M vs 500M)? If the bucket is 500M elements but a single layer only produces ~220M gradients, doesn't the bucket need multiple layers to fill? And doesn't that delay the first communication call?
The key insight: on modern hardware with NVLink, communication is fast relative to compute. Compute (backward) is the slower one — which is exactly what you want, because it means the network finishes before the GPU does, and the overlap is free.
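A rough per-bucket comparison makes this concrete. All hardware numbers below are illustrative assumptions, not measurements: ~300 GB/s effective interconnect bandwidth, ~300 sustained BF16 TFLOP/s, and a backward pass costing roughly 4 FLOPs per parameter per token.

```python
# Assumed hardware numbers -- illustrative, not measurements.
bandwidth_bps = 300e9      # effective NVLink bandwidth, bytes/s
gpu_flops = 300e12         # sustained BF16 FLOP/s
tokens_per_step = 4096     # micro-batch size in tokens

bucket = 500_000_000       # elements per bucket
comm_s = bucket * 2 / bandwidth_bps                   # BF16 bytes over the wire
compute_s = 4 * bucket * tokens_per_step / gpu_flops  # backward ~4 FLOPs/param/token
print(f"comm {comm_s * 1e3:.1f} ms vs compute {compute_s * 1e3:.1f} ms")
# comm 3.3 ms vs compute 27.3 ms -- compute dominates, so the transfer hides for free
```

Under these assumptions the backward pass for a bucket's worth of parameters takes several times longer than shipping that bucket over the wire, which is why overlap makes the communication effectively free.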
Time per Bucket: Compute vs Communication
With overlap_comm = false, communication serializes after compute instead of hiding behind it. For small models, reduce your bucket size to ~100M.
Set reduce_bucket_size to roughly 2-3 layers' worth of parameters. This ensures the bucket fills after a few layers of backward compute — early enough to start overlapping communication with the remaining layers, without the overhead of firing a collective for every layer.
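That rule of thumb is just multiplication. A quick sketch with an illustrative layer size (not from any real model):

```python
# Hypothetical transformer layer: ~200M parameters each.
layer_params = 200_000_000
target_layers = 2.5  # aim for 2-3 layers' worth per bucket

reduce_bucket_size = int(layer_params * target_layers)
print(reduce_bucket_size)  # 500000000 -- matches the 500M default for this layer size
```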
The magic that makes distributed training fast: hide communication behind computation.
overlap_comm — when enabled, gradient communication (reduce-scatter) runs concurrently with the backward pass. While one layer's gradients are being communicated, the next layer's backward is already computing.
reduce_scatter — use reduce-scatter instead of all-reduce for gradients. In a reduce-scatter, each GPU receives only its own slice of the averaged gradients — exactly what ZeRO needs for gradient partitioning.
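A toy simulation of the difference, with plain Python lists standing in for per-GPU gradient tensors (4 "GPUs", 8 gradient elements — a sketch of the semantics, not of NCCL):

```python
def all_reduce(grads_per_gpu):
    """Every GPU ends up with the FULL averaged gradient."""
    n = len(grads_per_gpu)
    avg = [sum(g[i] for g in grads_per_gpu) / n for i in range(len(grads_per_gpu[0]))]
    return [avg[:] for _ in range(n)]

def reduce_scatter(grads_per_gpu):
    """Each GPU ends up with only ITS OWN slice of the averaged gradient."""
    n = len(grads_per_gpu)
    avg = [sum(g[i] for g in grads_per_gpu) / n for i in range(len(grads_per_gpu[0]))]
    slice_len = len(avg) // n
    return [avg[r * slice_len:(r + 1) * slice_len] for r in range(n)]

grads = [[float(r)] * 8 for r in range(4)]  # GPU r holds a gradient filled with r
print(reduce_scatter(grads))  # each GPU keeps a 2-element slice of the mean (1.5)
```

Reduce-scatter moves the same averaged data but leaves each rank holding an Nth of it — which is all a gradient-partitioned optimizer needs.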
overlap_comm transforms the timeline: when ON, communication is hidden behind backward compute.

How DeepSpeed arranges memory to make communication faster and reduce fragmentation.
contiguous_gradients — pre-allocate one big contiguous block of memory for all gradients instead of letting them scatter across the heap.
Benefits:
- Less fragmentation — one allocation instead of many small ones scattered across the heap.
- Faster transfers — DMA engines and NCCL work best on contiguous buffers.
- A fixed, predictable buffer footprint.
Visualizing how gradients are laid out in GPU memory:
With reduce_bucket_size and allgather_bucket_size both at 500M elements in BF16, you're allocating ~2 GB just for communication buffers (500M × 2 bytes each). On a 40 GB A100, that's 5% of VRAM — worth it for the communication efficiency, but be aware of it on smaller GPUs.
Stage 3 partitions parameters across GPUs, so each layer must be gathered before use and flushed after. These parameters control that lifecycle.
Parameters smaller than stage3_param_persistence_threshold (default: 100,000 elements) are kept replicated on all GPUs permanently instead of being gathered and flushed each pass. Drag the threshold to see which layers stay replicated.
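The decision itself is a simple size comparison. A sketch using made-up parameter shapes loosely typical of a transformer block (not taken from any real model):

```python
# Hypothetical parameter sizes (element counts) -- illustrative only.
params = {
    "layernorm.weight": 4_096,
    "attention.qkv":    50_331_648,
    "mlp.up_proj":      45_088_768,
    "final_bias":       4_096,
}
threshold = 100_000  # stage3_param_persistence_threshold

persisted = {name for name, n in params.items() if n < threshold}
sharded = set(params) - persisted
print(sorted(persisted))  # tiny params stay replicated on every GPU
print(sorted(sharded))    # large params are partitioned and gathered on demand
```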
stage3_max_live_parameters — the maximum number of parameter elements that can be materialized (gathered) across all GPUs at once during forward/backward.
This caps peak memory from temporary all-gathers. If a layer needs more than this, it gets broken into smaller chunks and gathered piece by piece.
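A sketch of that chunking (a simplification of the idea, not DeepSpeed's actual partitioning code): an oversized gather is split into pieces that each fit under the cap.

```python
import math

def gather_chunks(layer_elems: int, max_live: int):
    """Split one layer's gather into chunks that each fit under the live-parameter cap."""
    n_chunks = math.ceil(layer_elems / max_live)
    base = layer_elems // n_chunks
    sizes = [base] * n_chunks
    sizes[-1] += layer_elems - base * n_chunks  # remainder goes to the last chunk
    return sizes

# A hypothetical 2.5B-element layer under a 1,000M cap gets gathered in 3 pieces.
print(gather_chunks(2_500_000_000, 1_000_000_000))
```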
stage3_max_reuse_distance — if a parameter will be needed again within this many elements of compute, keep it around instead of flushing and re-gathering it.
High value = more aggressive caching. Trades memory for reduced communication. Useful when the same parameter appears in multiple operations close together.
sub_group_size — during the optimizer step, parameters are processed in sub-groups of this size to control peak memory.
Smaller = less peak memory during the optimizer step, but more overhead from processing many small groups. The default of 1,000M (1e9) effectively means "process everything at once" for most models.
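A rough sketch of that trade-off, assuming Adam with FP32 states at ~12 bytes per element (FP32 master weight + momentum + variance — an assumed figure, not a DeepSpeed measurement):

```python
import math

def optimizer_step_stats(shard_elems: int, sub_group_size: int, bytes_per_elem: int = 12):
    """Number of sub-group passes per step and the peak optimizer-state working set."""
    n_groups = math.ceil(shard_elems / sub_group_size)
    peak_gb = min(shard_elems, sub_group_size) * bytes_per_elem / 1e9
    return n_groups, peak_gb

# Hypothetical 7B model sharded over 4 GPUs -> ~1.75B elements per GPU shard.
for sub in (100_000_000, 1_000_000_000):
    groups, gb = optimizer_step_stats(1_750_000_000, sub)
    print(f"sub_group={sub / 1e6:.0f}M -> {groups} passes, peak ~{gb:.1f} GB of states")
```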
stage3_gather_16bit_weights_on_model_save — when saving a checkpoint, temporarily all-gather all sharded parameters into a full FP16/BF16 model.
This gives you a single complete checkpoint file instead of N shard files. Set to false if you always load with the same number of GPUs and want faster saves.
Step through the lifecycle of a parameter in ZeRO-3. See how parameters are gathered, used, and flushed across 4 GPUs.
Configure ZeRO parameters and see how they affect the training timeline for a 7B parameter model on 4 GPUs.
Here's what all these parameters look like in a DeepSpeed configuration file, with annotations.
```jsonc
// DeepSpeed ZeRO Stage 3 config
{
  "zero_optimization": {
    "stage": 3,

    // === Bucket Sizes (Communication Granularity) ===
    "reduce_bucket_size": 500000000,           // 500M elements → ~1 GB (BF16)
    "allgather_bucket_size": 500000000,        // 500M elements → ~1 GB (BF16)
    "stage3_prefetch_bucket_size": 200000000,  // 200M — smaller for early prefetch

    // === Communication Overlap ===
    "overlap_comm": true,                      // Overlap comm with backward compute
    "reduce_scatter": true,                    // Use reduce-scatter (not all-reduce)

    // === Memory Layout ===
    "contiguous_gradients": true,              // One big buffer, less fragmentation

    // === ZeRO-3: Persistence & Caching ===
    "stage3_param_persistence_threshold": 100000,  // <100K → keep replicated
    "stage3_max_live_parameters": 1000000000,      // 1,000M max gathered at once
    "stage3_max_reuse_distance": 1000000000,       // Cache if reused within 1,000M elems
    "sub_group_size": 1000000000,                  // 1,000M optimizer sub-group

    // === Checkpointing ===
    "stage3_gather_16bit_weights_on_model_save": true  // Full checkpoint
  }
}
```
| Parameter | Default | Category | Effect of Increasing | Effect of Decreasing |
|---|---|---|---|---|
| `reduce_bucket_size` | 500M | Bucket | Fewer comm calls | Earlier first call, better overlap |
| `allgather_bucket_size` | 500M | Bucket | Fewer all-gather calls | More granular parameter broadcast |
| `stage3_prefetch_bucket_size` | 200M | Bucket | Fewer prefetch calls | Earlier prefetch start |
| `overlap_comm` | true | Overlap | Always keep true unless debugging | |
| `reduce_scatter` | true | Overlap | Required for ZeRO gradient partitioning | |
| `contiguous_gradients` | true | Memory | Always keep true for performance | |
| `stage3_param_persistence_threshold` | 100K | ZeRO-3 | More params replicated (more mem, less comm) | More params sharded (less mem, more comm) |
| `stage3_max_live_parameters` | 1,000M | ZeRO-3 | Higher peak memory | Lower peak memory, more chunks |
| `stage3_max_reuse_distance` | 1,000M | ZeRO-3 | More caching (mem ↑, comm ↓) | Less caching (mem ↓, comm ↑) |
| `sub_group_size` | 1,000M | ZeRO-3 | Higher peak optimizer memory | Lower peak, more overhead |
Not too small (overhead from many calls), not too big (delays overlap opportunity). The 500M default is tuned for most multi-GPU setups.
overlap_comm = true hides communication behind backward compute. This is what makes distributed training practical at scale.
stage3_prefetch_bucket_size is 200M (not 500M) because you want prefetching to start early, overlapping with current-layer compute.
LayerNorm weights (~1K params) cost nothing to keep on every GPU. Gathering them every forward pass wastes more time than the memory they save.
DMA engines and NCCL prefer contiguous buffers. contiguous_gradients = true is a free performance win.
Two 500M buckets in BF16 = ~2 GB of buffer memory. On smaller GPUs (16 GB), consider reducing bucket sizes to 200M.