4.2. Model Parallelism and Sharding
🪄 Step 1: Intuition & Motivation
- Core Idea: Modern Transformers can have hundreds of billions of parameters. No single GPU (not even the mighty A100 or H100) can fit them all in memory — let alone train them efficiently.
So, how do we train such massive models? We split the model, data, and computation across many GPUs — this is parallelism and sharding.
But simply slicing and scattering the work isn’t enough; GPUs need to communicate constantly to stay synchronized. That’s where clever engineering (and libraries like DeepSpeed, Megatron-LM, and FSDP) comes into play.
- Simple Analogy: Imagine a factory building cars 🚗:
- Data parallelism: each team builds the same car using different batches of parts (data).
- Model parallelism: each team builds a different part of the same car.
- Pipeline parallelism: each team works on different assembly stages — one attaches doors, another paints, another checks quality.
All teams must coordinate perfectly to finish one car per cycle — fast and efficiently.
🌱 Step 2: Core Concept
Large-scale training ≈ dividing the Transformer training workload across multiple GPUs or nodes. There are three main strategies (often combined in hybrid setups):
- Data Parallelism
- Tensor (Model) Parallelism
- Pipeline Parallelism
1️⃣ Data Parallelism — Multiple Cooks, Same Recipe
Each GPU holds a complete copy of the model, but works on a different mini-batch of data.
At the end of every iteration, gradients are synchronized (averaged) across GPUs to keep model weights consistent.
Formally:
$$ W_{t+1} = W_t - \eta \cdot \frac{1}{N} \sum_{i=1}^N \nabla_i $$
where $N$ = number of GPUs, and $\nabla_i$ = gradient from GPU $i$.
Pros:
- Easy to implement (e.g., `torch.nn.DataParallel`, `DistributedDataParallel`).
- Scales well when the model fits in a single GPU's memory.
Cons:
- Each GPU holds the entire model → not suitable for huge models.
- Communication overhead increases with number of GPUs.
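To make this concrete, here is a minimal `DistributedDataParallel` sketch, assuming a single-node multi-GPU setup launched with `torchrun` and placeholder `model`/`dataset` objects. DDP keeps a full replica on each rank and averages gradients during `backward()`:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train(model, dataset, epochs=1, lr=1e-4):
    dist.init_process_group("nccl")            # one process per GPU (launched via torchrun)
    rank = dist.get_rank()                     # single-node assumption: rank == local GPU index
    torch.cuda.set_device(rank)

    model = DDP(model.cuda(rank), device_ids=[rank])   # full model replica on each GPU
    sampler = DistributedSampler(dataset)              # each rank sees a different data shard
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)               # reshuffle shards each epoch
        for x, y in loader:
            loss = torch.nn.functional.cross_entropy(model(x.cuda(rank)), y.cuda(rank))
            loss.backward()                    # gradients are all-reduced (averaged) here
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()
```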
2️⃣ Tensor (Model) Parallelism — Dividing the Recipe Itself
When the model is too large for one GPU, we split the layers or weight matrices themselves across multiple GPUs.
Example: If a fully connected layer has weight $W \in \mathbb{R}^{4096 \times 4096}$, we can split it along its rows or columns. For a column split:
- GPU 1 handles $W_{[:, :2048]}$,
- GPU 2 handles $W_{[:, 2048:]}$.
During forward pass, each GPU computes its partial output; during backward, gradients are synced.
This is used heavily in Megatron-LM and GPT-3 training.
Pros:
- Handles extremely large models.
- Scales to many GPUs when they share a fast interconnect (e.g., NVLink within a node).
Cons:
- Requires constant inter-GPU communication.
- Slower if GPUs are far apart (network bottlenecks).
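As a toy illustration of the column split described above (not Megatron-LM's actual implementation, which uses collective communication rather than explicit `.to()` copies), assuming two GPUs on one host:

```python
import torch
import torch.nn as nn

class ColumnParallelLinear(nn.Module):
    """Each GPU owns half of the output columns of a 4096 x 4096 weight."""
    def __init__(self, in_features=4096, out_features=4096):
        super().__init__()
        half = out_features // 2
        self.w0 = nn.Parameter(torch.randn(in_features, half, device="cuda:0") * 0.02)  # W[:, :2048]
        self.w1 = nn.Parameter(torch.randn(in_features, half, device="cuda:1") * 0.02)  # W[:, 2048:]

    def forward(self, x):                                   # x lives on cuda:0
        y0 = x @ self.w0                                    # partial output on GPU 0
        y1 = x.to("cuda:1") @ self.w1                       # partial output on GPU 1
        return torch.cat([y0, y1.to("cuda:0")], dim=-1)     # gather the two halves

x = torch.randn(8, 4096, device="cuda:0")
y = ColumnParallelLinear()(x)                               # shape: (8, 4096)
```

In the backward pass, autograd routes each partial gradient back to the GPU that owns the corresponding weight shard.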
3️⃣ Pipeline Parallelism — The Assembly Line
The model is split by depth across its layers. Each GPU holds a consecutive chunk of layers and passes activations to the next.
While GPU 1 processes micro-batch i+1, GPU 2 can already work on micro-batch i, creating a pipeline of overlapping computation.
Pros:
- Efficient utilization of GPUs.
- Enables training very deep models.
Cons:
- Pipeline bubbles (idle time at the start and end of each batch).
- Backward passes require careful scheduling and activation storage (often combined with activation checkpointing).
Libraries like DeepSpeed handle automatic pipeline scheduling to minimize idle time.
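A naive two-stage split, assuming two GPUs on one host, shows the basic idea; real pipeline engines add micro-batch scheduling so the stages overlap instead of waiting on each other as they do here:

```python
import torch
import torch.nn as nn

stage1 = nn.Sequential(nn.Linear(1024, 4096), nn.GELU()).to("cuda:0")  # first chunk of layers
stage2 = nn.Sequential(nn.Linear(4096, 1024), nn.GELU()).to("cuda:1")  # remaining layers

def forward(x):
    h = stage1(x.to("cuda:0"))       # GPU 0 computes its layers...
    return stage2(h.to("cuda:1"))    # ...then hands the activations to GPU 1

out = forward(torch.randn(8, 1024)) # without micro-batches, GPU 0 idles while GPU 1 works
```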
📐 Step 3: Mathematical Foundation
Gradient Synchronization in Data Parallelism
Each GPU computes its own gradients: $\nabla_i = \frac{\partial L_i}{\partial W}$
Then all GPUs perform all-reduce (summing and averaging):
$$ \nabla = \frac{1}{N} \sum_{i=1}^N \nabla_i $$
This ensures all replicas update weights identically.
Implemented efficiently with NCCL (NVIDIA Collective Communications Library).
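The all-reduce step can also be written out by hand. This sketch (assuming `torch.distributed` is already initialized with the NCCL backend) is essentially what DDP does automatically during `backward()`:

```python
import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module):
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)  # sum gradients across all ranks
            p.grad /= world_size                           # average -> identical updates everywhere
```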
Sharding Optimizer States — ZeRO Optimization
DeepSpeed’s ZeRO (Zero Redundancy Optimizer) splits model states across devices:
| Stage | What is Sharded | Memory Reduction (vs. plain data parallelism) |
|---|---|---|
| ZeRO-1 | Optimizer states | up to ~4× |
| ZeRO-2 | Gradients + optimizer states | up to ~8× |
| ZeRO-3 | Model weights + gradients + optimizer states | scales with the number of GPUs |
This reduces redundancy — each GPU only stores part of the model or optimizer.
Memory savings let us train trillion-parameter models on clusters of modest GPUs.
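For orientation, here is an illustrative DeepSpeed configuration enabling ZeRO stage 3 with BF16. The keys follow DeepSpeed's documented config schema, but treat the values as placeholders to tune for your own cluster:

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,              # shard weights, gradients, and optimizer states
        "overlap_comm": True,    # overlap communication with backward computation
    },
}

# Typical usage (model is a placeholder):
# import deepspeed
# engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config
# )
```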
Mixed Precision Training — Keeping Things Fast
Instead of using 32-bit floats everywhere, we use 16-bit precision for most computations (FP16 or BF16).
- Cuts memory use by ~2×.
- Speeds up matrix multiplications substantially on tensor cores (often close to 2×).
- FP16 requires loss scaling to avoid gradient underflow (BF16 typically does not).
Paired with gradient accumulation, it keeps global batch sizes large even when GPU memory is small.
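A compact FP16 training loop with loss scaling and gradient accumulation, using PyTorch's built-in AMP utilities and a tiny stand-in model for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 10).cuda()                       # tiny stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                    # dynamic loss scaling against underflow
accum_steps = 8                                         # simulate an 8x larger global batch

for step in range(32):
    x = torch.randn(16, 512, device="cuda")             # synthetic mini-batch
    y = torch.randint(0, 10, (16,), device="cuda")

    with torch.cuda.amp.autocast():                     # run the forward pass in FP16 where safe
        loss = nn.functional.cross_entropy(model(x), y)

    scaler.scale(loss / accum_steps).backward()         # scaled, accumulated gradients

    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)                          # unscales; skips the step on overflow
        scaler.update()
        optimizer.zero_grad()
```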
🧠 Step 4: Key Ideas
- Data Parallelism: Duplicate model, distribute data.
- Model (Tensor) Parallelism: Split model parameters across GPUs.
- Pipeline Parallelism: Split model layers sequentially.
- ZeRO Optimization: Shard gradients, weights, and optimizer states.
- Mixed Precision: Trade precision for speed and memory savings.
⚖️ Step 5: Strengths, Limitations & Trade-offs
Strengths:
- Enables training ultra-large models.
- Efficient GPU utilization via overlapping compute/communication.
- Supported by powerful frameworks (DeepSpeed, Megatron-LM, FSDP).
Limitations:
- Communication bottlenecks across GPUs.
- Pipeline bubbles reduce utilization.
- Harder to debug and checkpoint.
🚧 Step 6: Common Misunderstandings
- “Parallelism automatically speeds up training.” Not always — communication can offset compute gains if not optimized.
- “ZeRO removes the need for parallelism.” It complements it — ZeRO reduces memory use, not computational cost.
- “Mixed precision hurts accuracy.” When used correctly (with scaling), it maintains near-identical accuracy while speeding up training.
🧩 Step 7: Mini Summary
🧠 What You Learned: Large-scale Transformer training relies on combining multiple parallelism strategies to overcome GPU memory and communication limits.
⚙️ How It Works: Data parallelism duplicates models, tensor parallelism splits weights, and pipeline parallelism stages layers. ZeRO, gradient accumulation, and mixed precision keep everything memory-efficient.
🎯 Why It Matters: Understanding these strategies is crucial for scaling from million-parameter prototypes to trillion-parameter systems — where efficient communication determines whether your model trains or crashes.