A visual, from-first-principles walkthrough of data-parallel training — from a single GPU baseline to production-grade PyTorch DDP.
Four different ways to train the same model — each progressively smarter about overlapping communication with computation.
Standard training on 1 GPU. No communication needed. This is our baseline.
N GPUs. Full forward + backward, then blocking AllReduce. Communication and computation are separate phases.
N GPUs. Non-blocking AllReduce fires per layer during backward. Communication overlaps with computation.
Production-grade interleaved DP with gradient bucketing for fewer, larger AllReduce calls.
One GPU processes the entire dataset. Training is slow because a single GPU can only process one batch at a time.
Each GPU gets a shard of the batch. Both GPUs compute forward + backward in parallel — ~2× faster compute!
Each GPU computed gradients from different data, so the results differ. We must average them — (grad_A + grad_B) / 2 — so that all GPUs update their weights identically. This averaging operation is AllReduce.
When should the AllReduce happen — after all gradients are computed (Naive), or during backward, per-layer (Interleaved)? That choice is what this notebook explores.
The most important collective communication operation in distributed training.
Every GPU ends up with the same averaged gradient. This ensures weight updates are identical across all GPUs, keeping models in sync.
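The semantics can be sketched in plain Python, no GPUs needed: each "rank" holds its own gradient vector, and after the reduce-plus-broadcast every rank holds the identical element-wise mean. This is a toy model of the operation's result, not of NCCL's implementation.

```python
def all_reduce_mean(grads_per_rank):
    """Toy AllReduce: every rank ends up with the element-wise mean."""
    n_ranks = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]   # reduce (SUM)
    averaged = [s / n_ranks for s in summed]                # divide by N
    return [list(averaged) for _ in range(n_ranks)]         # broadcast to all

# Two ranks computed different gradients from different data shards
grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
result = all_reduce_mean(grads)
print(result)  # every rank holds the same average: [2.0, 3.0, 4.0]
```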
Before anything, verify that we have GPUs and NCCL (the GPU communication library).
import torch
import os
# How many GPUs does this machine have?
NUM_GPUS = torch.cuda.device_count()
print(f"PyTorch version : {torch.__version__}")
print(f"CUDA available : {torch.cuda.is_available()}")
print(f"NCCL available : {torch.distributed.is_nccl_available()}")
print(f"GPUs found : {NUM_GPUS}")
for i in range(NUM_GPUS):
    props = torch.cuda.get_device_properties(i)
    mem = props.total_memory / 1024**3
    print(f"  GPU {i}: {props.name} ({mem:.1f} GB)")
PyTorch version : 2.4.1+cu124
CUDA available  : True
NCCL available  : True
GPUs found      : 4
  GPU 0: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 1: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 2: NVIDIA A100 80GB PCIe (79.3 GB)
  GPU 3: NVIDIA A100 80GB PCIe (79.3 GB)
Returns the number of NVIDIA GPUs available. We need at least 2 for data parallelism (one model copy per GPU).
NVIDIA Collective Communications Library — the library that handles GPU-to-GPU communication. It implements AllReduce, Broadcast, and other operations using high-speed interconnects like NVLink or PCIe.
We're running on 4× A100 80GB — NVIDIA's datacenter GPU with 80GB of HBM2e memory. Each one has ~312 TFLOPS of BF16 compute. Having 4 GPUs lets us test both 2-GPU and 4-GPU scaling.
A deliberately deep model so that interleaved AllReduce has many layers to overlap with.
import torch
import torch.nn as nn
def build_model(size='medium'):
    configs = {
        'small':  [784, 1024, 512, 256, 10],                     # ~1.5M params
        'medium': [784, 2048, 2048, 1024, 512, 10],              # ~8.4M params
        'large':  [784, 4096, 4096, 2048,
                   2048, 1024, 512, 10],                         # ~35.2M params
        # widths assumed; chosen to match the 52M count in the benchmarks
        'xlarge': [784, 4096, 4096, 4096, 2048,
                   2048, 1024, 512, 10],                         # ~52M params
    }
    sizes = configs[size]
    layers = []
    for i in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[i], sizes[i+1]))
        if i < len(sizes) - 2:
            layers.append(nn.ReLU())
    model = nn.Sequential(*layers)
    n_params = sum(p.numel() for p in model.parameters())
    return model, n_params
def make_dataset(n_samples=32768):
    torch.manual_seed(42)
    X = torch.randn(n_samples, 784)
    y = torch.randint(0, 10, (n_samples,))
    return X, y
The model is a Multi-Layer Perceptron (fully-connected neural network). It's made deliberately deep (many layers) because interleaved AllReduce benefits from having many layers — each layer's gradient can overlap with the next layer's communication.
32,768 samples of random data shaped like MNIST (784 features, 10 classes). Using a fixed seed ensures every GPU generates identical data — critical for fair sharding.
| Size | Layers | Parameters |
|---|---|---|
| Small | 4 | 1.5M |
| Medium | 5 | 8.4M |
| Large | 7 | 35.2M |
| XLarge | 8 | 52M |
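The parameter counts in the table can be reproduced from the layer widths alone — each Linear layer contributes in×out weights plus out biases. A quick sanity check on the three sizes whose widths appear in the code above:

```python
def mlp_param_count(widths):
    """Weights (in*out) plus biases (out) for each Linear layer."""
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))

configs = {
    'small':  [784, 1024, 512, 256, 10],
    'medium': [784, 2048, 2048, 1024, 512, 10],
    'large':  [784, 4096, 4096, 2048, 2048, 1024, 512, 10],
}
for name, widths in configs.items():
    print(f"{name}: {mlp_param_count(widths) / 1e6:.1f}M params")
# small: 1.5M, medium: 8.4M, large: 35.2M — matching the table
```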
The simplest possible training loop. No communication, no parallelism. Our speed baseline.
# Standard training loop — the simplest possible form
device = torch.device('cuda:0')
model, n_params = build_model(config['model_size'])
model = model.to(device) # Move model to GPU
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
X, y = make_dataset(32768)
X, y = X.to(device), y.to(device) # Data → GPU
# --- Warmup (so CUDA kernels are compiled) ---
for _ in range(3):
    idx = torch.randint(0, len(X), (bs,))
    out = model(X[idx])
    loss_fn(out, y[idx]).backward()
    optimizer.zero_grad()
torch.cuda.synchronize()

# --- Training loop ---
for epoch in range(n_epochs):
    for start in range(0, len(X), bs):
        xb = X[start:start+bs]
        yb = y[start:start+bs]
        optimizer.zero_grad()       # 1. Clear old gradients
        out = model(xb)             # 2. Forward pass
        loss = loss_fn(out, yb)     # 3. Compute loss
        loss.backward()             # 4. Backward pass (gradients)
        optimizer.step()            # 5. Update weights
Every neural-network training loop follows the same five steps: clear gradients, forward pass, loss, backward pass, weight update.
The first few CUDA operations pay one-time costs — kernel loading, cuDNN autotuning, memory-pool allocation — that would skew our benchmarks. Running 3 dummy batches ensures these costs are paid before we start timing.
GPU operations are asynchronous — the CPU queues work and moves on. synchronize() makes the CPU wait until all GPU work is done. Essential for accurate timing.
Split the batch across GPUs, compute gradients independently, then synchronize everything at once.
# --- Setup distributed process group ---
dist.init_process_group(backend='nccl')
rank = dist.get_rank() # Which GPU am I? (0 or 1)
world_size = dist.get_world_size() # How many GPUs total?
device = torch.device(f'cuda:{rank}')
torch.cuda.set_device(device)
# All ranks build SAME model (same seed = same weights)
torch.manual_seed(42)
model, n_params = build_model(config['model_size'])
model = model.to(device)
# Each rank loads full dataset, then shards per batch
X_all, y_all = make_dataset(32768)
X_all, y_all = X_all.to(device), y_all.to(device)
for start in range(0, len(X_all), bs):
    xb, yb = X_all[start:start+bs], y_all[start:start+bs]

    # ① Shard the batch across GPUs
    chunk = len(xb) // world_size
    s = rank * chunk
    e = s + chunk
    x_local, y_local = xb[s:e], yb[s:e]

    # ② COMPUTATION PHASE: full forward + backward
    optimizer.zero_grad()
    out = model(x_local)
    loss = loss_fn(out, y_local)
    loss.backward()                 # ALL gradients computed

    # ③ COMMUNICATION PHASE: AllReduce ALL grads (BLOCKING)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size    # Average

    # ④ WEIGHT UPDATE
    optimizer.step()
torchrun launches one process per GPU. This call connects them into a group so they can communicate. NCCL is the backend that handles the actual GPU-to-GPU data transfer.
rank = this process's ID (0 or 1). world_size = total processes (2). Each process is assigned to a different GPU.
A batch of 1024 samples gets split in half: GPU 0 processes samples 0-511, GPU 1 processes samples 512-1023. Each GPU does ~half the compute.
After loss.backward() finishes, we loop through every parameter and call dist.all_reduce(). This is blocking — the GPU sits completely idle while gradients are transferred over the network. For 35M parameters, this can waste 20-40% of total time.
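A back-of-the-envelope lower bound for that blocking phase (pure arithmetic; the ~25 GB/s effective PCIe Gen4 figure is the same assumption used later in this notebook — measured times are higher because of per-tensor launch latency and imperfect bandwidth):

```python
# Rough lower bound for a blocking AllReduce of 35.2M float32 gradients.
# Bandwidth figure is an assumption (~25 GB/s effective PCIe Gen4).
n_params = 35_200_000
bytes_total = n_params * 4                  # float32 → 140.8 MB of gradients
n_gpus = 4
ring_factor = 2 * (n_gpus - 1) / n_gpus     # Ring AllReduce moves 1.5× the data
bandwidth = 25e9                            # bytes/sec

comm_seconds = ring_factor * bytes_total / bandwidth
print(f"Gradient bytes: {bytes_total / 1e6:.1f} MB")
print(f"Lower-bound AllReduce time: {comm_seconds * 1e3:.1f} ms per step")
```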
Fire non-blocking AllReduce per layer during backward — overlap communication with computation!
class InterleavedDP:
    def __init__(self, model, world_size):
        self.model = model
        self.world_size = world_size
        self._handles = []  # async AllReduce handles
        # Register a hook on EVERY parameter.
        # When a gradient is ready → fire async AllReduce.
        for param in self.model.parameters():
            param.register_post_accumulate_grad_hook(
                self._make_hook(param)
            )

    def _make_hook(self, param):
        def hook(p):
            # 🔑 KEY: Non-blocking AllReduce!
            # Fire immediately, don't wait
            handle = dist.all_reduce(
                p.grad,
                op=dist.ReduceOp.SUM,
                async_op=True  # ← NON-BLOCKING!
            )
            self._handles.append((handle, p))
        return hook

    def finish_allreduce(self):
        # Wait for ALL async ops before weight update
        for handle, param in self._handles:
            handle.wait()
            param.grad /= self.world_size
        self._handles.clear()
dp = InterleavedDP(model, world_size)
# Training loop — almost identical to naive!
optimizer.zero_grad()
out = model(x_local)
loss = loss_fn(out, y_local)
# Backward fires async AllReduce per layer via hooks
loss.backward()
# Wait for all async AllReduces to complete
dp.finish_allreduce()
optimizer.step()
Instead of blocking AllReduce after all gradients are computed, we fire a non-blocking AllReduce the instant each layer's gradient is ready. While NCCL transfers layer 6's gradients, the GPU is already computing layer 5's gradients.
register_post_accumulate_grad_hook tells PyTorch: "Call this function the moment this parameter's gradient is computed." This is exactly how ShallowSpeed does it with MPI:
autograd.register_grad_hook(backprop_allreduce_gradient)
Before optimizer.step(), we must ensure ALL async AllReduces are complete. handle.wait() blocks until that specific operation finishes. This mirrors ShallowSpeed's MPI.Request.Waitall().
The same interleaving idea, but with gradient bucketing for even fewer AllReduce calls.
from torch.nn.parallel import DistributedDataParallel as DDP
# Setup (same as naive)
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
device = torch.device(f'cuda:{rank}')
torch.manual_seed(42)
model, n_params = build_model(config['model_size'])
model = model.to(device)
# ✨ ONE LINE — wraps model with DDP magic
model = DDP(model, device_ids=[rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# Training loop — IDENTICAL to single GPU!
optimizer.zero_grad()
out = model(x_local)
loss = loss_fn(out, y_local)
loss.backward() # DDP handles AllReduce automatically!
optimizer.step()
PyTorch's DistributedDataParallel wraps your model and automatically:
Each NCCL AllReduce call has fixed launch overhead (~10-50μs). The large model's 7 Linear layers give 14 parameter tensors — 14 separate launches per step. DDP groups them into 2-3 large buckets → fewer launches, better bandwidth utilization.
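The savings from bucketing can be estimated with simple arithmetic. The 30 μs per-launch figure below is an assumed midpoint of the 10-50 μs range, and the tensor/bucket counts follow the large-model shape:

```python
# Estimated launch-overhead savings from gradient bucketing.
# launch_overhead is an assumption (midpoint of the 10-50 µs range).
launch_overhead = 30e-6      # seconds per NCCL call
n_tensors = 14               # weight + bias per Linear layer (7 layers)
n_buckets = 3                # DDP's default ~25 MB buckets, roughly

per_tensor = n_tensors * launch_overhead
bucketed   = n_buckets * launch_overhead
saved_us   = (per_tensor - bucketed) * 1e6
print(f"Per-tensor AllReduce overhead: {per_tensor * 1e6:.0f} µs/step")
print(f"Bucketed overhead:             {bucketed * 1e6:.0f} µs/step")
print(f"Saved per step:                {saved_us:.0f} µs")
```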
A script that launches all 4 training modes with torchrun and collects timing results.
# Configuration for all benchmark runs
base = {
    'n_samples': 32768,     # Dataset size
    'batch_size': 1024,     # Samples per batch
    'lr': 0.01,             # Learning rate
    'n_epochs': 5,          # Training epochs
}
model_sizes = ['small', 'medium', 'large', 'xlarge']
gpu_counts = [2, 4]
# For each model size, run the single-GPU baseline plus
# all 3 multi-GPU modes at each GPU count (4 × 7 = 28 runs):
for msz in model_sizes:
    # 1. Single GPU baseline
    run(['python', 'train_single_gpu.py', cfg_str])
    # 2-4. Multi-GPU modes via torchrun
    for ng in gpu_counts:
        tr = ['torchrun', '--nproc_per_node', str(ng)]
        run(tr + ['train_dp_naive.py', cfg_str])
        run(tr + ['train_dp_interleaved.py', cfg_str])
        run(tr + ['train_ddp_builtin.py', cfg_str])
PyTorch's launcher for distributed training. --nproc_per_node 2 starts 2 identical processes, one per GPU. Each process gets assigned a unique rank (0 or 1) and knows the total world_size (2).
We test small (1.5M), medium (8.4M), large (35.2M), and xlarge (52M) models across 1, 2, and 4 GPUs. That's 28 benchmark runs. Bigger models have more parameters to communicate — making the naive vs interleaved gap more visible.
Each training script prints results as a JSON string prefixed with RESULTS_JSON:. The benchmark runner captures stdout and parses this line. A simple, reliable way to collect metrics from subprocess runs.
Real numbers from 4× A100 80GB GPUs. 28 benchmark runs across 4 model sizes, testing 2-GPU and 4-GPU scaling.
| Model | Mode | GPUs | Params | Avg Epoch | Comm Time | Throughput | Speedup |
|---|---|---|---|---|---|---|---|
| Small | Single GPU | 1 | 1.5M | 0.0361s | — | 908,941/s | 1.00× |
| | Naive DP | 2 | 1.5M | 0.0602s | 0.0284s | 544,589/s | 0.60× |
| | Interleaved | 2 | 1.5M | 0.0620s | — | 528,861/s | 0.58× |
| | PyTorch DDP | 2 | 1.5M | 0.0436s | — | 751,415/s | 0.83× |
| | Naive DP | 4 | 1.5M | 0.0653s | 0.0307s | 501,482/s | 0.55× |
| | Interleaved | 4 | 1.5M | 0.0643s | — | 510,000/s | 0.56× |
| | PyTorch DDP | 4 | 1.5M | 0.0482s | — | 679,871/s | 0.75× |
| Medium | Single GPU | 1 | 8.4M | 0.1180s | — | 277,724/s | 1.00× |
| | Naive DP | 2 | 8.4M | 0.1071s | 0.0328s | 305,929/s | 1.10× |
| | Interleaved | 2 | 8.4M | 0.0898s | — | 364,829/s | 1.31× |
| | PyTorch DDP | 2 | 8.4M | 0.0871s | — | 376,138/s | 1.35× |
| | Naive DP | 4 | 8.4M | 0.0863s | 0.0370s | 379,604/s | 1.37× |
| | Interleaved | 4 | 8.4M | 0.0871s | — | 376,152/s | 1.35× |
| | PyTorch DDP | 4 | 8.4M | 0.0674s | — | 486,144/s | 1.75× |
| Large | Single GPU | 1 | 35.2M | 0.4112s | — | 79,681/s | 1.00× |
| | Naive DP | 2 | 35.2M | 0.2850s | 0.0501s | 114,977/s | 1.44× |
| | Interleaved | 2 | 35.2M | 0.2644s | — | 123,937/s | 1.56× |
| | PyTorch DDP | 2 | 35.2M | 0.2597s | — | 126,171/s | 1.58× |
| | Naive DP | 4 | 35.2M | 0.2143s | 0.0744s | 152,925/s | 1.92× |
| | Interleaved | 4 | 35.2M | 0.1837s | — | 178,394/s | 2.24× |
| | PyTorch DDP | 4 | 35.2M | 0.1784s | — | 183,684/s | 2.31× |
| XLarge | Single GPU | 1 | 52.0M | 0.5956s | — | 55,016/s | 1.00× |
| | Naive DP | 2 | 52.0M | 0.4045s | 0.0700s | 81,008/s | 1.47× |
| | Interleaved | 2 | 52.0M | 0.3669s | — | 89,303/s | 1.62× |
| | PyTorch DDP | 2 | 52.0M | 0.3620s | — | 90,531/s | 1.65× |
| | Naive DP | 4 | 52.0M | 0.2898s | 0.0940s | 113,077/s | 2.06× |
| | Interleaved | 4 | 52.0M | 0.2485s | — | 131,844/s | 2.40× |
| | PyTorch DDP | 4 | 52.0M | 0.2496s | — | 131,300/s | 2.39× |
With only 1.5M parameters, computation finishes in 36ms on one GPU. Every parallel mode is slower — even 4-GPU DDP achieves only 0.75× of the single-GPU speed. The overhead of process setup, batch sharding, and AllReduce communication exceeds the compute savings. The small model simply doesn't have enough work to keep the GPUs busy while communication happens.
As models grow from 8.4M to 52M parameters, scaling efficiency dramatically improves. The XLarge model achieves 2.40× speedup on 4 GPUs (60% scaling efficiency) vs medium's 1.75×. Larger models have more FLOPs per layer, giving NCCL more time to finish AllReduce before the next gradient is needed.
Across all model sizes with 4 GPUs: Large naive=1.92× vs interleaved=2.24× (14% faster). XLarge naive=2.06× vs interleaved=2.40× (14% faster). The improvement comes purely from overlapping communication with computation — the exact optimization ShallowSpeed teaches.
PyTorch DDP matches our interleaved implementation within 1-2% on the larger models and clearly beats it when there are many small parameter tensors relative to compute — e.g., Medium 4G: DDP=1.75× vs interleaved=1.35×. DDP's gradient bucketing reduces NCCL launch overhead, which matters most in exactly that regime.
Naive DP's communication time jumps from 0.050s (2G) to 0.074s (4G) for the large model — a 49% increase. Ring AllReduce transfers 2(N-1)/N × D bytes per GPU over 2(N-1) sequential steps. More GPUs = more data moved and more steps, making interleaving even more critical.
Our best result: XLarge model, 4 GPUs, interleaved DP = 2.40× speedup, processing 131,844 samples/sec vs 55,016/sec baseline. At this scale, the compute-to-communication ratio is high enough that interleaving hides nearly all communication latency.
The notebook generates 6 charts. Here's what each one reveals about our 4× A100 benchmark.
The Researcher's Mindset — Why does the crossover happen?
Every distributed training step has two costs: computation (forward + backward FLOPs) and communication (transferring gradients via AllReduce). Speedup only happens when the computation saved by splitting data across GPUs exceeds the communication overhead added.
For the small model (1.5M params), a single forward+backward pass takes only ~5ms of GPU compute — the A100 chews through it almost instantly. But AllReduce of 1.5M parameters over PCIe takes ~28ms regardless. So the communication is 5-6× more expensive than the compute it's trying to parallelize. Adding GPUs just adds overhead with no meaningful compute savings.
For the XLarge model (52M params), a single forward+backward takes ~500ms — a heavy workload. AllReduce of 52M params takes ~70-94ms. Now the compute-to-communication ratio has flipped: computation is 5-7× larger than communication. Splitting this across 4 GPUs saves ~375ms of compute while adding ~94ms of communication — a net win of ~280ms per epoch. This is why you see the dark bar (0.596s) shrink to the dark green bar (0.250s).
The fundamental law: Data parallelism pays off when T_compute / T_communication >> 1. Bigger models have more FLOPs per parameter, pushing this ratio higher. This is why in the real world, data parallelism is used for models with billions of parameters, not thousands.
The Researcher's Mindset — Why can't we reach ideal linear scaling, and why does the naive-interleaved gap widen?
1. Why not ideal scaling? Ideal linear scaling (4 GPUs = 4×) requires zero communication overhead. In practice, Ring AllReduce requires 2(N-1)/N × D data transfers (see the Ring AllReduce section above). With 4 GPUs and 52M parameters (each 4 bytes = 208MB of gradient data), that's 1.5 × 208MB = 312MB transferred per GPU. Over PCIe Gen4 (~25 GB/s effective), this takes at minimum ~12ms just for raw transfer — plus per-step latency for 6 sequential ring steps. This communication tax is irreducible overhead that prevents reaching the ideal.
2. Why does naive (red) always lag behind interleaved (blue/green)? The math is precise:
Naive total time: T_naive = T_compute/N + T_allreduce (sequential: compute first, then communicate)
Interleaved total time: T_interleaved ≈ max(T_compute/N, T_allreduce) (overlapped: compute and communicate in parallel)
The gap between them = T_naive - T_interleaved = T_compute/N + T_allreduce - max(T_compute/N, T_allreduce) = min(T_compute/N, T_allreduce). In other words, interleaving saves whichever is smaller: the compute time or the communication time.
3. Why does this gap WIDEN for larger models? Let's plug in real numbers:
Small (4G): T_compute/4 ≈ 3ms, T_allreduce ≈ 31ms → Gap = min(3, 31) = 3ms (tiny — compute is so small there's almost nothing to overlap)
Medium (4G): T_compute/4 ≈ 12ms, T_allreduce ≈ 37ms → Gap = min(12, 37) = 12ms (growing — more compute to overlap with)
Large (4G): T_compute/4 ≈ 35ms, T_allreduce ≈ 74ms → Gap = min(35, 74) = 35ms (significant — interleaving hides all the compute behind communication)
XLarge (4G): T_compute/4 ≈ 50ms, T_allreduce ≈ 94ms → Gap = min(50, 94) = 50ms (massive — 50ms of GPU idle time eliminated per epoch)
As models grow, T_compute/N grows, so min(T_compute/N, T_allreduce) grows — meaning interleaving saves more and more time. The gap keeps widening because there's more backward computation available to run concurrently with AllReduce. This is the central insight: interleaving is most valuable when both compute and communication are significant, and its benefit scales with model size.
4. Why do lines curve upward from 2→4 GPUs for large models but stay flat for small? Going from 2→4 GPUs halves per-GPU compute (saving ~100ms for XLarge) but increases AllReduce time by ~24ms (more ring steps). For large models, the net savings is ~76ms — a huge win. For small models, halving ~5ms of compute while adding ~3ms of communication yields a net savings of only ~2ms, swamped by the fixed overhead of additional process coordination. You're communication-bound — adding GPUs adds latency with negligible compute savings.
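The gap formula can be checked in a few lines of code, using the Large 4-GPU estimates from the list above (T_compute/4 ≈ 35 ms, T_allreduce ≈ 74 ms — rough figures, for illustration):

```python
def step_times(t_compute_per_gpu, t_allreduce):
    """Naive = sequential phases; interleaved = overlapped phases."""
    t_naive = t_compute_per_gpu + t_allreduce
    t_interleaved = max(t_compute_per_gpu, t_allreduce)
    gap = t_naive - t_interleaved   # equals min(t_compute, t_allreduce)
    return t_naive, t_interleaved, gap

# Large model on 4 GPUs: 35 ms of per-GPU compute, 74 ms of AllReduce
t_naive, t_inter, gap = step_times(0.035, 0.074)
print(f"naive={t_naive * 1e3:.0f}ms  interleaved={t_inter * 1e3:.0f}ms  gap={gap * 1e3:.0f}ms")
```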
The Researcher's Mindset: Why does communication % increase with more GPUs? Two effects compound: (1) Ring AllReduce does 2(N-1)/N data passes — going from 2→4 GPUs increases total data transferred from 1× to 1.5× the gradient size. (2) Each GPU's compute halves (processing 1/4 vs 1/2 of the batch), so the denominator shrinks while the numerator grows. For small models (47% comm), the compute per GPU is only ~3ms while AllReduce takes ~28ms — the system is communication-bound. For XLarge (17% at 2G), compute dominates. This ratio is the single most important number in distributed training design.
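Both effects fall out of the ring-transfer factor 2(N-1)/N — a quick numeric sketch (the compute-halving side is implicit: each GPU processes 1/N of the batch):

```python
def ring_data_factor(n_gpus):
    """Fraction of the full gradient each GPU transfers in Ring AllReduce."""
    return 2 * (n_gpus - 1) / n_gpus

for n in (2, 4, 8):
    print(f"N={n}: each GPU moves {ring_data_factor(n):.2f}x the gradient size")
# Going 2 → 4 GPUs raises the per-GPU transfer from 1.0x to 1.5x the gradient
# size, while per-GPU compute halves — communication's share of the step
# grows on both ends of the ratio.
```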
The Researcher's Mindset: Why does interleaving save more at 4 GPUs than 2? Because AllReduce time grows with GPU count (more Ring hops), creating more idle time for naive DP to waste. Interleaving hides this growing cost by running communication concurrently with backward computation. The 14% savings at 4G represents ~30ms of overlapped communication per epoch — time the naive approach spends with GPUs completely idle. DDP slightly beats our interleaved because bucketing reduces NCCL kernel launch overhead: 2-3 large AllReduces are faster than 12+ small ones, even with the same total data volume. Each NCCL call has ~10-50μs of fixed overhead, which adds up.
The Researcher's Mindset: There are three distinct clusters, grouped by GPU count: single GPU (1G) sits highest, 2-GPU modes (dp_naive, dp_interleaved, ddp_builtin) cluster together in the middle, and 4-GPU modes cluster at the bottom.
Why do modes within the same GPU count produce identical curves? Because they're all doing the same math. Regardless of whether you use naive, interleaved, or DDP — every mode does AllReduce(SUM)/N on the gradients. The gradient update is mathematically identical. The only difference is when the AllReduce happens (after all layers vs. per-layer), not what it computes.
But why do different GPU counts produce different clusters? This is actually a logging artifact, not a real training difference. Here's what the code does:
The effective batch size is actually the same across all GPU counts (1024 samples). The batch is split across GPUs, each GPU computes gradients on its shard, and AllReduce(SUM)/N produces the exact same averaged gradient as a single GPU processing all 1024 samples. The models are learning identically.
The visual separation happens because loss.item() captures rank 0's local micro-batch loss (computed on 1024, 512, or 256 samples depending on GPU count) — and rank 0 always sees the same fixed slice of each batch (indices [0:chunk]) since the dataset isn't shuffled. These fixed subsets have systematically different average loss values.
Why not log the true average loss? You'd need an extra dist.all_reduce(loss_tensor) call every step just for monitoring. This adds communication overhead that would distort the timing benchmarks — which is the whole point of this notebook. For production training you'd want the averaged loss, but for benchmarking speed, this shortcut is acceptable.
Key takeaway: If you plotted loss vs total samples seen (instead of loss vs step), or if you averaged the loss across all ranks before logging, all three clusters would collapse onto the same curve. The training is identical — only the monitoring differs.
The Researcher's Mindset: Whether adding GPUs helps depends entirely on the compute-to-communication ratio. Parallel epoch time = T_compute/N + T_comm. If T_comm exceeds the compute saved, you lose.
Small (1.5M): Single GPU = 0.036s. With 4 GPUs: compute drops to ~0.009s but AllReduce adds ~0.030s → total 0.065s. Slower, because communication (0.030s) exceeds compute saved (0.027s). All parallel modes produce 0.55–0.83× throughput.
XLarge (52M): Single GPU = 0.596s. With 4 GPUs: compute drops to ~0.149s and AllReduce costs ~0.094s → total 0.249s. Much faster, because compute saved (0.447s) dwarfs communication cost. Interleaved achieves 2.40× speedup.
Where does f come from? We measure it directly from the naive DP benchmarks, where compute and communication are separate phases. For XLarge at 4 GPUs: T_compute = epoch_time − comm_time = 0.290 − 0.094 = 0.196s and T_comm = 0.094s. Relative to the original single-GPU time: f = T_compute_single / (T_compute_single + T_comm) = 0.596 / (0.596 + 0.094) ≈ 0.86. Plugging into Amdahl's Law: speedup = 1/((1-f) + f/N). The intuition: your total work has a parallelizable fraction f (compute, which splits across N GPUs → f/N) and a serial fraction 1-f (communication, which doesn't shrink with more GPUs). No matter how many GPUs you add, you can never eliminate that 1-f serial portion — it sets a hard ceiling on speedup. For XLarge: 1/(0.14 + 0.86/4) ≈ 2.8× theoretical max — we achieve 2.4× (~85% of the ceiling). For the small model: f = 0.036/(0.036 + 0.030) ≈ 0.55, giving a theoretical max of ≈ 1.7× on 4 GPUs — but fixed overhead (process launch, NCCL init, barriers) pushes the actual speedup below 1×.
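As a numeric check, Amdahl's ceiling for the measured fractions can be computed directly (the f values below are derived from this notebook's benchmark numbers; treat the sketch as illustrative):

```python
def amdahl_speedup(f, n):
    """Maximum speedup with parallel fraction f on n workers."""
    return 1.0 / ((1.0 - f) + f / n)

# Parallel fractions from the naive-DP measurements:
f_xlarge = 0.596 / (0.596 + 0.094)   # ≈ 0.86
f_small  = 0.036 / (0.036 + 0.030)   # ≈ 0.55

print(f"XLarge ceiling on 4 GPUs: {amdahl_speedup(f_xlarge, 4):.2f}x")
print(f"Small  ceiling on 4 GPUs: {amdahl_speedup(f_small, 4):.2f}x")
```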
Real-world implication: LLMs with billions of parameters have f > 0.999 — communication is negligible relative to compute. This is why training on thousands of GPUs works efficiently.
AllReduce is the operation that synchronizes gradients across GPUs. Understanding how it works explains every number in our benchmarks.
In data-parallel training, each GPU processes a different shard of the batch and computes its own gradient vector. But to update the weights, every GPU must end up with the same gradient — the average across all GPUs.
Each gradient vector Gi is a huge array of numbers — one number per model parameter. For our XLarge model, that's 52 million float32 numbers = 208 MB per GPU. The question is: how do we get the average of four 208 MB arrays onto all four GPUs as fast as possible?
The simplest approach: pick one GPU (say GPU 0) as the "leader." Every other GPU sends its entire gradient to GPU 0. GPU 0 adds them all up, divides by N, and sends the result back.
GPU 0's network link must handle 624 MB in + 624 MB out = 1,248 MB total. Meanwhile, GPU 1, 2, 3 are sitting idle — they only use their links once to send and once to receive.
With N GPUs, GPU 0 must transfer 2 × (N-1) × D data, where D is the gradient size. This grows linearly with N. With 100 GPUs, GPU 0 would need to transfer 198 × 208 MB ≈ 41 GB — more than a second of raw transfer per step even at PCIe speeds, dwarfing the compute it synchronizes.
The core problem: one GPU does all the work while the others wait. We're wasting N-1 perfectly good network links.
Here's the insight that makes Ring AllReduce work: you don't have to send the entire gradient as one big blob.
Think of each gradient vector as a long array of 52 million numbers. We can split it into N equal pieces (one piece per GPU). For 4 GPUs:
Why split? Because now we can make each GPU responsible for averaging just one chunk. GPU 0 collects and averages Chunk 0 from all GPUs. GPU 1 collects and averages Chunk 1 from all GPUs. And so on. All four GPUs work in parallel — no bottleneck!
But there's a problem: if GPU 0 needs Chunk 0 from all other GPUs, and GPU 1 needs Chunk 1 from all other GPUs, etc., we'd need all-to-all communication — which has the same bottleneck. We need a smarter pattern for how the chunks travel between GPUs. That's where the ring topology comes in.
Arrange the GPUs in a logical ring. Each GPU has exactly two connections:
The crucial property: At every step, all four GPUs send and receive simultaneously. No GPU is ever idle. No GPU is ever overloaded. Every network link in the system is used at full bandwidth, all the time.
The algorithm runs in two phases:
Let's walk through both phases step by step.
Recall: each GPU's gradient is split into 4 chunks. We label them: GPU 0 has chunks [A₀, A₁, A₂, A₃], GPU 1 has [B₀, B₁, B₂, B₃], etc. The subscript is the chunk index, the letter is which GPU it came from.
The goal of Phase 1: make GPU 0 hold the sum A₀+B₀+C₀+D₀ (all GPUs' chunk 0), GPU 1 hold A₁+B₁+C₁+D₁, etc.
Each GPU has its own gradient, split into 4 equal chunks (52 MB each for our XLarge model).
Key rule: Each GPU sends one chunk to the right, receives one chunk from the left, and adds the received chunk to its own. Notice: all 4 GPUs send/receive simultaneously.
After step 2, the highlighted chunks contain partial sums from 3 out of 4 GPUs. One more step to go.
Step 3: GPU 0 sends A₂+D₂+C₂ → GPU 1. GPU 1 sends B₃+A₃+D₃ → GPU 2. GPU 2 sends C₀+B₀+A₀ → GPU 3. GPU 3 sends D₁+C₁+B₁ → GPU 0. Why does this work? The accumulation rotated through chunks: Step 1 updated chunk 3/0/1/2, Step 2 updated chunk 2/3/0/1, Step 3 updated chunk 1/2/3/0. Each GPU had one chunk untouched for steps 1-2 — the chunk that received the final 3-GPU sum in Step 3, completing the full sum. Each GPU now holds 1/4 of the complete result. Phase 2 distributes these.
After Scatter-Reduce, each GPU has the complete sum of one chunk. Now we need to distribute those completed chunks so every GPU has all of them.
We use the same ring pattern — each GPU sends its completed chunk to the right, receives a completed chunk from the left — but this time we replace instead of adding (the incoming chunk is already the final sum, so we just overwrite).
Each GPU now has 2 completed chunks. 2 more steps to distribute the remaining ones...
Every GPU now has the identical, fully-summed gradient vector. Divide by N (=4) to get the average, then update weights. Every GPU updates identically.
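The two phases can be simulated end-to-end in plain Python — list arithmetic standing in for NCCL, ranks as list indices, and a single float standing in for each 52 MB chunk:

```python
def ring_all_reduce_mean(grads):
    """Simulate Ring AllReduce averaging. grads[r][c] is rank r's chunk c."""
    n = len(grads)
    data = [list(g) for g in grads]

    # Phase 1 — Scatter-Reduce: N-1 steps; every rank simultaneously sends
    # one chunk to its right neighbor and ADDS the chunk from its left.
    for step in range(n - 1):
        # Snapshot sends first: all transfers in a step happen simultaneously
        sends = {r: ((r - step) % n, data[r][(r - step) % n]) for r in range(n)}
        for r in range(n):
            chunk, value = sends[(r - 1) % n]   # arrives from left neighbor
            data[r][chunk] += value
    # Now rank r holds the COMPLETE sum of chunk (r+1) % n.

    # Phase 2 — AllGather: N-1 more steps; completed chunks rotate around
    # the ring, OVERWRITING (not adding) as they arrive.
    for step in range(n - 1):
        sends = {r: ((r + 1 - step) % n, data[r][(r + 1 - step) % n]) for r in range(n)}
        for r in range(n):
            chunk, value = sends[(r - 1) % n]
            data[r][chunk] = value              # incoming chunk is already final

    return [[v / n for v in row] for row in data]  # divide by N to average

# 4 "GPUs", 4 chunks each — every rank must end with the same averaged vector
grads = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [9.0, 10.0, 11.0, 12.0],
         [13.0, 14.0, 15.0, 16.0]]
result = ring_all_reduce_mean(grads)
print(result[0])   # [7.0, 8.0, 9.0, 10.0] on every rank
```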
Let's compare the two approaches concretely for our XLarge model (D = 208 MB) on 4 GPUs:
| | Naive (Central Server) | Ring AllReduce |
|---|---|---|
| Data GPU 0 must send | (N-1) × D = 3 × 208 = 624 MB | 2 × (N-1) × D/N = 2 × 3/4 × 208 = 312 MB |
| Data GPU 0 must receive | (N-1) × D = 624 MB | 2 × (N-1) × D/N = 312 MB |
| Bottleneck GPU's total load | 1,248 MB (GPU 0 does everything) | 624 MB (spread equally across all GPUs) |
| GPUs idle during transfer? | Yes — GPU 1,2,3 idle while GPU 0 works | No — all GPUs send/receive simultaneously |
| Scales to 100 GPUs? | GPU 0 transfers 99 × 208 ≈ 20.6 GB | Each GPU transfers ~2 × 208 = 416 MB |
| Steps required | 2 (gather + broadcast) | 2 × (N-1) = 6 steps (more latency per step) |
The critical difference: with naive, the bottleneck GPU's load grows linearly with N. With Ring AllReduce, each GPU's load approaches 2D regardless of N. At N=100, the naive approach is 50× slower than Ring AllReduce.
Gradient has D bytes total, split into N chunks → each chunk is D/N bytes.
Each GPU sends exactly one chunk (D/N bytes) to its right neighbor and receives one chunk from its left.
Takes N-1 steps. Each step, every GPU sends D/N. Total sent per GPU: (N-1) × D/N
Takes another N-1 steps with D/N per step. Total sent per GPU: (N-1) × D/N
2 × (N-1) × D/N = 2(N-1)/N × D
As N→∞, the total approaches 2D — essentially constant, independent of how many GPUs you have!
This is the beauty of Ring AllReduce: Adding more GPUs barely increases the per-GPU communication cost. Compare to the naive approach where the bottleneck GPU must handle 2(N-1) × D — which grows linearly with N. Ring AllReduce is why distributed training can scale to hundreds or thousands of GPUs.
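Both trends — per-GPU data plateauing near 2D while the step count grows linearly — fall directly out of the formula. A quick numeric sketch with D = 208 MB, as in the XLarge example:

```python
D = 208  # gradient size in MB (XLarge model)

for n in (2, 4, 8, 100):
    data_per_gpu = 2 * (n - 1) / n * D   # MB each GPU transfers in total
    steps = 2 * (n - 1)                  # sequential ring steps (both phases)
    print(f"N={n:3d}: {data_per_gpu:6.1f} MB per GPU, {steps} steps")
# Data per GPU plateaus near 2D = 416 MB; step count (latency) keeps growing.
```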
The catch: More GPUs means more sequential steps (N-1 per phase). Each step has fixed latency (~10-50μs for NCCL kernel launch + network round-trip). With 4 GPUs: 6 steps. With 8 GPUs: 14 steps. This per-step latency is why AllReduce time does grow somewhat with GPU count — our benchmarks show communication going from 50ms (2 GPUs) to 94ms (4 GPUs) for XLarge. But it grows much slower than the naive approach would.
Every GPU has an identical copy of the model. The batch is split across GPUs. Each computes gradients on its shard, then they average gradients via AllReduce.
Computing all gradients first, then syncing them, means the GPU sits idle during AllReduce. This is the "two-stage" problem.
Firing async AllReduce per layer during backward lets the GPU compute the next layer's gradients while the network transfers the current layer's.
More layers = more opportunities to overlap. A 2-layer model has almost no benefit. A 7-layer model has 6 opportunities to overlap.
PyTorch DDP groups small gradient tensors into ~25MB buckets, reducing the number of AllReduce kernel launches. Production-grade and battle-tested.
NVLink (~600 GB/s) makes AllReduce fast — naive and interleaved gap shrinks. PCIe (~32 GB/s) makes AllReduce slow — interleaving saves much more time.