3.1. Model Sharding & Distributed Inference
🪄 Step 1: Intuition & Motivation
Core Idea (in 1 short paragraph): Sometimes, your model is too big for one GPU — like trying to fit an elephant into a suitcase. Distributed inference solves this by slicing the model (or its workload) across multiple GPUs or machines. Each GPU handles part of the computation or stores part of the weights, and together they act as a single, powerful system. The art lies in how we slice, schedule, and communicate so the elephant still walks gracefully.
Simple Analogy (one only): Imagine a bakery that needs to bake a 10-foot wedding cake.
- One oven can’t fit it (too big).
- So, you bake layers in parallel ovens (tensor parallelism) or pass the cake layer by layer down a production line (pipeline parallelism). Each baker (GPU) handles their portion, but timing and coordination are everything — otherwise, you end up with frosting chaos.
🌱 Step 2: Core Concept
There are three main strategies for splitting models too large for one device.
What’s Happening Under the Hood?
1️⃣ Tensor Parallelism
- Splits individual matrix operations (like in Transformer layers) across GPUs.
- Example: If a linear layer has a $[1024 \times 4096]$ weight matrix, you can split it column-wise across 4 GPUs, each holding ¼ of the weight columns and computing ¼ of the output activations (see the sketch below).
- GPUs must communicate partial results (e.g., all-reduce or gather ops).
Use case: models with extremely large layers (e.g., 13B–70B parameter models).
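Below is a minimal single-process sketch of the column-wise split described above, using plain PyTorch tensors as stand-ins for GPU shards (the shapes and the 4-way split are illustrative):

```python
import torch

# Single-process sketch of column-wise tensor parallelism.
# A [1024 x 4096] weight matrix is split into 4 column shards; each
# "GPU" computes a slice of the output, and the slices are concatenated
# (the step real systems perform with a gather across devices).
torch.manual_seed(0)
in_features, out_features, n_shards = 1024, 4096, 4

x = torch.randn(8, in_features)             # a batch of activations
W = torch.randn(in_features, out_features)  # full weight matrix

shards = torch.chunk(W, n_shards, dim=1)    # each shard: [1024, 1024]

# Each shard produces its partial output independently.
partial_outputs = [x @ W_i for W_i in shards]   # each: [8, 1024]

# Concatenating the slices recovers the full result.
y_parallel = torch.cat(partial_outputs, dim=1)  # [8, 4096]
assert torch.allclose(y_parallel, x @ W, atol=1e-5)
```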
2️⃣ Pipeline Parallelism
- Splits different layers across GPUs.
- Example: Layer 1–10 on GPU1, Layer 11–20 on GPU2.
- As input data flows through, GPU1 computes its part, passes results downstream, and starts the next batch while GPU2 works on the previous batch’s next step — like a conveyor belt.
Use case: Deep models where layer boundaries are cleanly separable.
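A minimal sketch of this layer-wise split in PyTorch follows; the two 10-layer stages, layer width, and device names are illustrative, and with fewer than two GPUs both stages simply land on the same device:

```python
import torch
import torch.nn as nn

# Pipeline split: "layers 1-10" on one device, "layers 11-20" on another.
dev0 = torch.device("cuda:0" if torch.cuda.device_count() > 0 else "cpu")
dev1 = torch.device("cuda:1" if torch.cuda.device_count() > 1 else dev0)

stage1 = nn.Sequential(*[nn.Linear(512, 512) for _ in range(10)]).to(dev0)
stage2 = nn.Sequential(*[nn.Linear(512, 512) for _ in range(10)]).to(dev1)

def forward(x: torch.Tensor) -> torch.Tensor:
    h = stage1(x.to(dev0))       # stage 1 computes its layers...
    return stage2(h.to(dev1))    # ...and only the activation crosses devices.

# Microbatching: split the batch so stage 2 can work on microbatch i while
# stage 1 moves on to microbatch i+1. Some overlap happens implicitly via
# asynchronous CUDA execution; schedulers like GPipe/1F1B make it explicit.
batch = torch.randn(32, 512)
with torch.no_grad():
    outputs = [forward(mb) for mb in batch.chunk(4)]
result = torch.cat([o.to(dev0) for o in outputs])
```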
3️⃣ ZeRO Partitioning (Zero Redundancy Optimizer)
- Introduced in DeepSpeed, ZeRO partitions optimizer states, gradients, and weights within data parallelism, reducing memory duplication.
- Instead of each GPU storing the full model copy, they hold only shards and reconstruct parameters on-the-fly during compute.
Use case: Efficient training and inference of multi-billion parameter models without huge redundancy.
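Here is a conceptual single-process sketch of the "hold shards, reconstruct on the fly" idea; the rank count, shapes, and gather helper are all illustrative, whereas real ZeRO implementations (e.g., DeepSpeed) do this with torch.distributed collectives plus careful prefetching:

```python
import torch

n_ranks = 4
full_weight = torch.randn(1024, 1024)   # one layer's weights

# Partition: flatten and give each "rank" an equal 1/n_ranks shard.
shards = list(torch.chunk(full_weight.flatten(), n_ranks))

def gather_for_compute(shards, shape):
    # Stand-in for an all-gather: reassemble the full parameter on demand.
    return torch.cat(shards).reshape(shape)

x = torch.randn(8, 1024)
W = gather_for_compute(shards, full_weight.shape)  # materialize briefly
y = x @ W                                          # use it
del W                                              # drop it; only shards persist

# Steady-state memory per rank for this layer is ~1/n_ranks of the weights,
# plus a transient full copy during the compute step.
```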
Why It Works This Way
The fundamental problem: GPU memory is limited. A 40B parameter model in FP16 uses roughly:
$$ 40 \times 10^9 \times 2 \text{ bytes} = 80 \text{ GB} $$
for the weights alone. Even the largest single GPUs (24–80 GB) can't hold that comfortably once activations and the KV cache are added, so splitting is the only way forward.
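The same back-of-the-envelope arithmetic for a few common precisions (weights only; activations and the KV cache come on top):

```python
# Weight memory for a 40B-parameter model at different precisions.
params = 40e9
bytes_per_param = {"fp32": 4, "fp16/bf16": 2, "int8": 1, "int4": 0.5}

for dtype, nbytes in bytes_per_param.items():
    print(f"{dtype:>9}: {params * nbytes / 1e9:.0f} GB")
# fp16/bf16 -> 80 GB of weights alone, already at the ceiling of the
# largest single GPUs before any activations or KV cache are counted.
```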
But — splitting means communication. So, efficiency comes from minimizing communication overhead (via clever scheduling, mixed precision, and caching).
Modern frameworks like vLLM, DeepSpeed, and Ray Serve optimize these coordination mechanics — deciding when and what to send between GPUs to keep them busy, not waiting.
How It Fits in ML Thinking
At inference time, distributed serving enables models that would otherwise be impossible to deploy. It also improves throughput by parallelizing across hardware.
For example:
- vLLM uses PagedAttention to share KV caches efficiently across requests (see the serving sketch after this list).
- DeepSpeed Inference applies ZeRO-Inference + tensor slicing to save memory and reduce communication.
- Ray Serve orchestrates multiple replicas for parallel request handling, abstracting away GPU placement logic.
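As a rough illustration of how little of this machinery the user sees, here is a hedged sketch of multi-GPU serving with vLLM's tensor parallelism; the model id and argument values are placeholders, and exact parameter names can vary between vLLM versions:

```python
from vllm import LLM, SamplingParams

# Shard the model's weights across 4 GPUs with tensor parallelism.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",  # placeholder: any HF-style model id
    tensor_parallel_size=4,
    dtype="float16",
)

params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.generate(["Explain tensor parallelism in one sentence."], params)
print(outputs[0].outputs[0].text)
```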
📐 Step 3: Mathematical Foundation
Parallelism & Memory Scaling
Let:
- $M$ = model memory footprint
- $N$ = number of GPUs
If partitioned ideally, each GPU stores roughly $\frac{M}{N}$ of the model, but you also pay a communication/buffer overhead ($C$):
$$ \text{Effective memory per GPU} = \frac{M}{N} + C $$
where $C$ grows with how frequently GPUs need to sync (e.g., during all-reduce or gather operations).
Goal: minimize $C$ so GPUs stay compute-bound, not communication-bound.
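A tiny illustration of the $\frac{M}{N} + C$ model with made-up numbers (the overhead value is an assumption, not a measurement):

```python
# Effective memory per GPU = M/N + C, with illustrative numbers.
M = 80.0   # GB: total footprint (e.g., 40B params in fp16)
C = 2.0    # GB: assumed per-GPU communication/activation buffers

for N in (2, 4, 8):
    print(f"{N} GPUs -> {M / N + C:.0f} GB per GPU")
# Doubling N halves the M/N term, but C does not shrink, and the *time*
# spent communicating grows as syncs become more frequent.
```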
Pipeline Parallelism Timing
If each stage takes time $t_i$, the total time per batch with $k$ microbatches and $p$ pipeline stages is roughly:
$$ T \approx \sum_{i=1}^{p} t_i + (k - 1) \cdot \max_i t_i $$
Early microbatches experience pipeline fill latency.
Steady-state throughput improves as pipeline fills up.
Balancing $t_i$ (compute per stage) is crucial — one slow stage throttles the rest.
Like an assembly line: if the decorator (stage 3) is slower than the baker (stage 1), orders pile up waiting for frosting.
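A quick worked example of the timing formula, showing how a single slow stage dominates (stage times and microbatch count are made up):

```python
# T ≈ sum(t_i) + (k - 1) * max(t_i)
def pipeline_time(stage_times, k):
    return sum(stage_times) + (k - 1) * max(stage_times)

balanced   = [10, 10, 10, 10]   # ms per stage
unbalanced = [5, 5, 5, 25]      # same total work, one slow stage

for name, times in [("balanced", balanced), ("unbalanced", unbalanced)]:
    print(f"{name:>10}: {pipeline_time(times, k=8)} ms for 8 microbatches")
# balanced:   40 + 7*10 = 110 ms
# unbalanced: 40 + 7*25 = 215 ms  (the slow stage throttles everything)
```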
🧠 Step 4: Assumptions or Key Ideas
- GPUs communicate via high-speed interconnects (NVLink, InfiniBand) — bandwidth matters as much as compute.
- You can combine parallelism types (hybrid parallelism) for massive models (tensor + pipeline + ZeRO).
- Load balancing across GPUs is critical — uneven work nullifies scaling gains.
- Frameworks like vLLM optimize KV cache memory sharing for LLM inference efficiency.
- Quantization (int8/fp16/fp8) or offloading (CPU/SSD) are valid fallback strategies when splitting isn’t enough.
⚖️ Step 5: Strengths, Limitations & Trade-offs
Strengths:
- Enables inference for models exceeding single-GPU memory.
- Improves throughput via concurrency.
- Leverages commodity hardware collectively instead of requiring exotic single nodes.
Limitations:
- Communication overhead can dominate at scale.
- Complex synchronization logic → higher engineering effort.
- Failures on one GPU can stall the entire pipeline.
Trade-offs:
- Memory vs. Communication: More splitting saves memory but increases latency.
- Simplicity vs. Scale: Single-node inference is simpler but capped by GPU RAM; distributed adds complexity but scales.
- Precision vs. Fit: Quantization reduces memory but can affect accuracy; balance depends on model tolerance.
🚧 Step 6: Common Misunderstandings
- “Just use more GPUs, it’ll scale linearly.” → Not true. Network bandwidth and communication patterns limit gains.
- “Pipeline = Parallel.” → Pipelining overlaps work across stages rather than duplicating it; it improves throughput but doesn’t reduce latency for a single request.
- “Quantization solves everything.” → It reduces memory, yes, but may degrade precision-sensitive tasks (like code generation or reasoning).
🧩 Step 7: Mini Summary
🧠 What You Learned: Distributed inference slices giant models across GPUs using tensor, pipeline, and ZeRO partitioning — enabling deployment beyond single-device limits.
⚙️ How It Works: Split weights or layers intelligently, minimize communication, and leverage frameworks like vLLM and DeepSpeed for optimized orchestration.
🎯 Why It Matters: It’s the foundation of modern LLM and multimodal system deployment — turning impossible models into scalable, production-ready systems.