Context Parallelism

Context Parallelism Visualized

Context Parallelism splits long sequences across GPUs so each device stores only its chunk of the sequence and computes only its share of attention, enabling million-token contexts without running out of memory.


Long Sequences Don't Fit

Self-attention memory scales quadratically with sequence length. At 128K tokens, a single GPU simply cannot hold the attention matrix.

A Single GPU vs. Long Sequences

(Interactive demo: a slider sweeps the sequence length S; fp16 attention memory = S² × 2 bytes. At S = 8K that is 128 MB, already 2 GB at 32K, against a single A100's 80 GB of HBM.)
Key Takeaway

At S = 128K, the attention matrix alone is 32 GB — nearly half an A100's memory, leaving no room for model weights, gradients, or activations. Context Parallelism splits this across GPUs.
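These numbers are simple arithmetic to reproduce: an fp16 score matrix has S² entries at 2 bytes each. A minimal sketch (the function name is illustrative):

```python
def attn_matrix_gb(seq_len: int, bytes_per_elem: int = 2) -> float:
    """Memory for a full S x S attention score matrix, in GiB (fp16 by default)."""
    return seq_len ** 2 * bytes_per_elem / 2**30

for s in (8_192, 32_768, 131_072):
    print(f"S = {s:>7,}: {attn_matrix_gb(s):6.2f} GiB")
```

At S = 128K this gives 32 GiB, matching the takeaway above.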

Ring Attention — The Core Idea

Arrange GPUs in a ring. Each GPU holds Q for its chunk and passes K,V around the ring, computing partial attention at each step.
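A compact way to state the schedule: after t ring hops, GPU r is holding the K,V chunk that originated on GPU (r − t) mod P. A tiny sketch of that rotation, for an assumed P = 4:

```python
P = 4  # number of GPUs in the ring

# At step t, each GPU holds the K,V chunk that originated
# t hops "behind" it in the ring; after P steps it has seen them all.
for step in range(P):
    held = [(rank - step) % P for rank in range(P)]
    print(f"step {step}: GPU r holds K,V chunk {held}")
```

Every row is a rotation of the previous one, so over P steps each GPU sees every chunk exactly once while only ever talking to its neighbor.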

(Interactive animation, step 0 of 3: no communication yet; each GPU starts with its own K,V chunk and computes local Q × Kᵀ attention. Attention blocks computed at this step: GPU 0: Q0 × K0, GPU 1: Q1 × K1, GPU 2: Q2 × K2, GPU 3: Q3 × K3.)
Key Insight

At each step, only 1/P of the K,V is in flight between neighbors. Per GPU, the attention block is O(S²/P²) and the resident K,V is O(S/P), so a fixed per-GPU memory budget supports a maximum context length that grows linearly with the number of GPUs.
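What makes the per-step partial results combinable is the online (streaming) softmax: each incoming K,V chunk updates a running row max, a running denominator, and a running weighted sum of V, so the full S × S matrix never materializes. A minimal single-process NumPy sketch of the non-causal case (function and variable names are illustrative, not from any library):

```python
import numpy as np

def ring_attention(Q, K, V, P):
    """Single-head attention computed ring-style: split S into P chunks,
    stream K,V chunks past each Q chunk, and merge partial softmax
    results with a running (max, denominator) pair."""
    S, d = Q.shape
    Qs, Ks, Vs = np.split(Q, P), np.split(K, P), np.split(V, P)
    out = []
    for q in Qs:                           # the work one GPU does for its Q chunk
        m = np.full(q.shape[0], -np.inf)   # running row max of the scores
        l = np.zeros(q.shape[0])           # running softmax denominator
        acc = np.zeros_like(q)             # running weighted sum of V
        for k, v in zip(Ks, Vs):           # K,V chunks arriving around the ring
            scores = q @ k.T / np.sqrt(d)
            m_new = np.maximum(m, scores.max(axis=1))
            scale = np.exp(m - m_new)      # rescale old partials to the new max
            p = np.exp(scores - m_new[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ v
            m = m_new
        out.append(acc / l[:, None])
    return np.vstack(out)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
# Reference: ordinary full-matrix softmax attention.
s = Q @ K.T / np.sqrt(8)
w = np.exp(s - s.max(axis=1, keepdims=True))
ref = (w / w.sum(axis=1, keepdims=True)) @ V
print(np.allclose(ring_attention(Q, K, V, 4), ref))
```

The rescaling by `exp(m - m_new)` is the same trick FlashAttention uses; the ring only changes where each K,V chunk lives.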

Zigzag Ring Attention — Better Balance

Naive chunking + causal masking creates severe load imbalance. Zigzag interleaving fixes this elegantly.

With naive chunking, GPU 0 gets the earliest tokens and must attend to all chunks (the top-left triangle is full). GPU 3 gets the latest tokens but its causal mask blocks most of the attention — leaving it mostly idle.

(Interactive figure: token assignment of 16 tokens across 4 GPUs; the causal attention matrix, where row = query token, column = key token, colored = computed, dark = masked; and the resulting workload per GPU, counted in Q × K pairs.)

Why Zigzag Works

By assigning tokens in a zigzag pattern (first + last, second + second-to-last, ...), each GPU gets a mix of early and late tokens. Under causal masking, this ensures each GPU performs roughly the same number of attention computations — eliminating the load imbalance problem.
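The balance claim is easy to verify by counting: under a causal mask, the query at position t attends to t + 1 keys, so that is its cost in Q × K pairs. A small sketch comparing the two assignments for the 16-token, 4-GPU example (names are illustrative):

```python
def causal_work(tokens):
    """Q x K pairs a GPU computes under causal masking:
    the query at position t costs t + 1 pairs."""
    return sum(t + 1 for t in tokens)

S, P = 16, 4
# Naive: contiguous chunks. Zigzag: GPU i gets tokens i, 2P-1-i, 2P+i, ...
naive  = [list(range(i * S // P, (i + 1) * S // P)) for i in range(P)]
zigzag = [sorted(set(range(i, S, 2 * P)) | set(range(2 * P - 1 - i, S, 2 * P)))
          for i in range(P)]

for name, assign in (("naive", naive), ("zigzag", zigzag)):
    print(name, [causal_work(g) for g in assign])
# naive  [10, 26, 42, 58]  <- GPU 3 does ~6x the work of GPU 0
# zigzag [34, 34, 34, 34]  <- perfectly balanced
```

Pairing token i with token 2P − 1 − i works because their causal costs sum to a constant, so every GPU's total comes out the same.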

Ring vs Zigzag Comparison

Both methods use the same ring communication topology. Zigzag adds smarter token assignment for causal models.

Naive Ring Attention

Contiguous chunks → Uneven causal work

Load imbalance: ~40%

Zigzag Ring Attention

Interleaved tokens → Balanced causal work

Load imbalance: ~2%
| Metric | Naive Ring Attention | Zigzag Ring Attention |
| --- | --- | --- |
| Token assignment | Contiguous chunks: [0..S/P−1], [S/P..2S/P−1], ... | Interleaved: GPU i gets tokens i, 2P−1−i, 2P+i, ... |
| Load balance (causal) | Poor: GPU 0 does most work, GPU P−1 least | Excellent: near-perfect balance |
| Memory per GPU | O(S²/P²) attention + O(S/P) for K,V | Same: O(S²/P²) attention + O(S/P) for K,V |
| Communication volume | P−1 rounds of K,V transfer, each O(S/P) | Same: P−1 rounds of K,V transfer |
| Communication pattern | Ring (each GPU sends to its next neighbor) | Same ring topology |
| Causal mask efficiency | Wastes compute on masked positions | Minimal wasted compute |
| Implementation complexity | Simple | Slightly more complex token remapping |
| Best for | Bidirectional attention (BERT, encoders) | Causal / autoregressive models (GPT, LLaMA) |
Bottom Line

For bidirectional attention (BERT-style), naive Ring Attention is fine — every GPU does the same work. For causal/autoregressive models (GPT, LLaMA, etc.), Zigzag Ring Attention is strictly better: same communication cost, same memory, but near-perfectly balanced workloads.