Distributed Inference

Why Tensor Parallelism Beats ZeRO at Inference

A prompt comes in. Tokens must be generated one at a time. Watch what happens inside the GPUs under each strategy, and why one is roughly 6,400× more communication-efficient than the other.

1 token generated at a time during decoding
80 layers traversed in the forward pass per token (Llama 70B)
~200 MB gathered by ZeRO per layer per token
~16 KB all-reduced by TP per layer per token

A Prompt Arrives

A user sends a prompt. The model must generate a response, one token at a time. Each new token requires a complete forward pass through all 80 layers.

User Prompt
"Explain how black holes form in simple terms."
The model generates the response one token at a time:
Each generated token requires one full forward pass through all 80 layers.
The critical question: the model is spread across 8 GPUs, and every layer of every token's forward pass requires the GPUs to coordinate. How they coordinate is what makes TP and ZeRO fundamentally different at inference time.
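The per-token cost structure can be sketched in a few lines. This is a toy model, not any real framework's API: `toy_layer` stands in for one transformer layer, and the point is simply that layer executions (each carrying a communication cost) scale as tokens × layers.

```python
# Toy autoregressive decode loop: every generated token triggers a full
# forward pass through all layers. All names here are illustrative.
NUM_LAYERS = 80

def toy_layer(hidden):
    # Stand-in for one transformer layer's work; in a sharded model this
    # is also where the GPUs must coordinate.
    return hidden + 1

def forward_pass(hidden):
    for _ in range(NUM_LAYERS):
        hidden = toy_layer(hidden)
    return hidden

def generate(n_tokens):
    layer_calls = 0
    hidden = 0
    for _ in range(n_tokens):          # one full forward pass per token
        hidden = forward_pass(hidden)
        layer_calls += NUM_LAYERS
    return layer_calls
```

A 200-token response therefore executes 200 × 80 = 16,000 layers, and each execution pays the parallelism strategy's per-layer communication cost.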

How the Weights Live on GPUs

Before any token is generated, the 70B model's weights must be distributed across the 8 GPUs. ZeRO and TP distribute them very differently.

ZeRO-3 Parameter Sharding

Each GPU holds 1/8 of the parameters, but these are row shards, so no GPU has a complete layer. To compute anything, a GPU must first collect shards from all the other GPUs.

Problem: No single GPU can compute a layer on its own. It must all-gather the full weight matrix from 7 other GPUs first.

TP Column-Parallel Sharding

Each GPU holds its column slice of EVERY layer — it can immediately compute a partial result for any layer without fetching anything.

Advantage: Every GPU can start computing immediately. Only a tiny all-reduce on the output activation is needed to combine results.
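The section's picture of TP can be made concrete with NumPy at toy sizes. In this layout each simulated "GPU" holds a column slice of the weight matrix and multiplies it by the matching slice of the input, so summing the partial outputs (the all-reduce) recovers the exact full-matrix result:

```python
import numpy as np

# Toy simulation of TP column sharding (sizes are illustrative, not Llama's).
rng = np.random.default_rng(0)
N = 8                                     # number of simulated GPUs
d_out, d_in = 16, 32
W = rng.standard_normal((d_out, d_in))    # full weight matrix
x = rng.standard_normal(d_in)             # one token's hidden state

shards = np.split(W, N, axis=1)           # column shards, resident per GPU
x_slices = np.split(x, N)                 # matching input slices
partials = [w @ xs for w, xs in zip(shards, x_slices)]
y = np.sum(partials, axis=0)              # the all-reduce: sum of partials

assert np.allclose(y, W @ x)              # sharded result == full matmul
```

Nothing had to be fetched before computing: each shard was already "resident", and only the small output vector crossed between the simulated GPUs.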

Generating Tokens — The Main Event

Here is what happens inside the GPUs as each token is generated, stepping through the forward pass token by token.

[Animation: the prompt "Explain how black holes form..." is decoded token by token, with side-by-side ZeRO-3 and TP forward passes tracking communication time and total data moved.]
Notice: For every single token, ZeRO-3 must all-gather ~200 MB of weights per layer. The red blocks (communication) completely dominate. Meanwhile, TP's green all-reduce blocks are so tiny they're barely visible — because they only move ~16 KB of activation data per layer.

Zooming Into One Layer

Let's see exactly what each GPU does during a single transformer layer's forward pass, for a single token being decoded.

ZeRO-3 One Layer, One Token

1. All-gather weights. Each GPU sends its 1/8 shard to all others, and every GPU reconstructs the full weight matrix. ~200 MB moved across the network.
2. Compute Y = W × x. Multiply the now-complete weight matrix with the single token's hidden state. A few million FLOPs, which is tiny.
3. Discard the gathered weights. Free the memory; there is no room to keep them, and the next layer will need its own all-gather.
Repeat for all 80 layers, for every token.
Time breakdown per layer: all-gather ~83%, compute ~17%.
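The gather/compute/discard cycle can be sketched with NumPy at toy sizes. All names and sizes are illustrative; the point is the shape of the per-layer pattern, not the scale:

```python
import numpy as np

# Sketch of ZeRO-3's per-layer pattern during single-token decoding.
rng = np.random.default_rng(1)
N, d = 8, 32
row_shards = [rng.standard_normal((d // N, d)) for _ in range(N)]  # 1/N each
x = rng.standard_normal(d)                 # the one token's hidden state

# Step 1: all-gather. Every GPU reconstructs the full (d, d) weight matrix.
W_full = np.concatenate(row_shards, axis=0)
bytes_gathered = W_full.nbytes             # this transfer dominates at batch=1

# Step 2: compute. A single matvec for the one token being decoded.
y = W_full @ x

# Step 3: discard. Free the gathered copy; the next layer gathers again.
del W_full
```

Even in this toy, the gathered matrix (d × d values) dwarfs the vector the layer actually produces (d values), which is exactly the imbalance the real numbers below quantify.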

TP One Layer, One Token

1. Weights already resident. Each GPU's column shard for this layer is already in HBM; no fetching is needed. 0 bytes moved, because the weights are always hot.
2. Compute Y_partial = W_shard × x. Each GPU multiplies its column shard with the input, producing a partial output vector. A few million FLOPs, the same compute as ZeRO.
3. All-reduce to sum the partials. Summing the partial outputs across GPUs yields the full activation Y at a tiny data volume: ~16 KB moved (8,192 floats × 2 bytes).
Time breakdown per layer: compute ~95%, all-reduce ~5%.
ZeRO-3, data per layer (Llama 70B): a layer's full weights are roughly 70B params ÷ 80 layers ≈ 875M params, or ~1.7 GB in fp16 (the d × d_ff = 8192 × 28672 MLP matrices account for most of that). Each GPU's 1/8 shard is therefore ~200 MB, and that shard must cross the network at every layer's all-gather. × 80 layers ≈ 16 GB per token.
TP, data per layer (Llama 70B): the activation vector is batch(1) × seq(1) × d_model(8192) × 2 bytes (fp16) = 16 KB; × 2 all-reduces per layer ≈ 32 KB. × 80 layers ≈ 2.5 MB per token.
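These figures are easy to re-derive. The sketch below uses the same round numbers as the text and treats each GPU's shard as the per-layer ZeRO traffic; real collective algorithms move somewhat different amounts, so treat this as back-of-envelope only:

```python
# Back-of-envelope check of the per-token data volumes (fp16 = 2 bytes).
PARAMS, LAYERS, GPUS, D_MODEL, FP16 = 70e9, 80, 8, 8192, 2

# ZeRO-3: each GPU's 1/8 shard of a layer's weights crosses the network
# at every layer's all-gather.
zero_per_layer = PARAMS / LAYERS * FP16 / GPUS   # ~219 MB, i.e. "~200 MB"
zero_per_token = zero_per_layer * LAYERS         # ~17.5 GB ("~16 GB")

# TP: two all-reduces of one d_model-sized fp16 vector per layer.
tp_per_layer = D_MODEL * FP16 * 2                # 32 KB
tp_per_token = tp_per_layer * LAYERS             # ~2.6 MB ("~2.5 MB")

ratio = zero_per_token / tp_per_token            # ~6,700x with round numbers
```

With these round inputs the ratio comes out near 6,700×, consistent with the ~6,400× figure used elsewhere in the text.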

Token After Token — It Adds Up

A typical response is 200+ tokens. Each one pays the full communication cost. Drag the slider to see how total data transferred grows with response length.

[Interactive chart: total data moved vs. response length. At every length, ZeRO-3's total is ~6,400× larger than TP's, since it moves ~16 GB per token against TP's ~2.5 MB.]

Why ZeRO Works Great for Training

ZeRO isn't a bad idea — the same all-gather that kills inference is barely noticeable during training. Here's why.

Large Batches Amortize Cost

With batch=32, seq=2048, each layer does 2×32×2048×d×d_ff FLOPs. The ~200 MB all-gather becomes a tiny fraction of total time.

Time breakdown in training: compute ~95%, communication ~5%.
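The amortization argument is just a ratio of two times, one fixed and one that scales with batch × seq. The sketch below makes it explicit; the hardware numbers (300 TFLOP/s of usable compute, 300 GB/s of interconnect bandwidth) are illustrative assumptions, not measurements:

```python
# Why the fixed-size weight gather fades as the batch grows:
# compute scales with batch x seq while the gather stays ~200 MB.
D, D_FF = 8192, 28672
GATHER_BYTES = 200e6                 # per-layer weight gather (~200 MB)
FLOPS_PER_SEC = 300e12               # assumed usable GPU throughput
BYTES_PER_SEC = 300e9                # assumed interconnect bandwidth

def comm_fraction(batch, seq):
    """Fraction of one layer's time spent on the weight gather."""
    flops = 2 * batch * seq * D * D_FF       # one matmul's FLOPs
    t_compute = flops / FLOPS_PER_SEC
    t_comm = GATHER_BYTES / BYTES_PER_SEC
    return t_comm / (t_comm + t_compute)

decode_frac = comm_fraction(1, 1)        # close to 1: gather dominates
train_frac = comm_fraction(32, 2048)     # small: 65,536x more compute
```

The exact training-time fraction depends on the assumed hardware, but the direction is robust: at batch=1, seq=1 the gather is nearly all of the time; at batch=32, seq=2048 it is a small sliver.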

Backward Pass Overlaps

During training, you also run backward passes. ZeRO can overlap the next layer's weight gather with the current layer's backward computation, hiding the gather latency.
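The overlap pattern itself is simple to sketch. The snippet below fakes "gather" and "compute" with sleeps and uses a thread as the prefetch engine; only the scheduling pattern is the point, not the mechanism (real frameworks use separate CUDA streams, not Python threads):

```python
import threading
import time

def all_gather(layer):           # stand-in for the network transfer
    time.sleep(0.01)

def compute(layer):              # stand-in for the layer's math
    time.sleep(0.01)

def run(num_layers, overlap):
    start = time.perf_counter()
    if not overlap:
        for i in range(num_layers):      # serial: gather, then compute
            all_gather(i)
            compute(i)
    else:
        prefetch = threading.Thread(target=all_gather, args=(0,))
        prefetch.start()
        for i in range(num_layers):
            prefetch.join()              # wait for layer i's weights
            if i + 1 < num_layers:       # start gathering layer i+1 ...
                prefetch = threading.Thread(target=all_gather, args=(i + 1,))
                prefetch.start()
            compute(i)                   # ... while layer i computes
    return time.perf_counter() - start
```

With equal fake durations, `run(8, overlap=True)` finishes in roughly half the time of `run(8, overlap=False)`, because all but the first gather hides behind compute. At batch=1 decoding there is no such compute to hide behind, which is why the same trick does not rescue ZeRO at inference.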

Memory is the Bottleneck

ZeRO's real value is memory savings: each GPU stores only 1/N of parameters, gradients, and optimizer states. This lets you train models that can't fit on one GPU.

The same ~200 MB all-gather that eats ~83% of per-layer time during single-token decoding only consumes ~5% during training. The communication cost is constant — only the compute changes. That's why strategy choice depends entirely on the workload.

Try It Yourself

Adjust model parameters to see how communication costs change for each strategy.

Adjustable parameters (defaults): d_model = 8,192; d_ff = 28,672; layers = 80; GPUs = 8; a fifth slider defaults to 600.
[Interactive table: each metric for ZeRO-3 vs. Tensor Parallelism, with the ratio between them.]

The Bottom Line

01

Every Token = Full Forward Pass

During autoregressive decoding, each new token triggers a complete walk through all layers. The parallelism strategy's communication cost is paid per layer, per token.

02

ZeRO Moves Weights (~200 MB/layer)

ZeRO-3 must all-gather full weight matrices before each layer can compute. For a 70B model, that's ~16 GB of network traffic per token.

03

TP Moves Activations (~16 KB/layer)

TP only all-reduces tiny activation vectors. Weights are always resident. Communication per token is ~2.5 MB — roughly 6,400× less than ZeRO-3.

04

Compute is Tiny During Decoding

With batch=1, seq=1, the actual matmul per layer takes microseconds. ZeRO-3's all-gather takes milliseconds. Communication completely dominates.

05

Training Amortizes the Cost

The same all-gather that's 83% of inference time is only 5% during training, because large batches create 65,000× more compute per layer.

06

Production Uses TP for Inference

vLLM, TGI, TensorRT-LLM all use Tensor Parallelism within a node. ZeRO is reserved for training where memory savings justify the communication.