Part 9: Tensor Parallelism

What Problem Does This Solve?

Models are getting larger than GPUs. Llama 3.1 405B at FP16 needs 810GB of memory — no single GPU exists with that much. Even a 70B model at FP16 needs 140GB, exceeding the 80GB of an H100.

The solution: split the model across multiple GPUs. Tensor parallelism (TP) splits individual weight matrices across GPUs. Each GPU holds a slice of every layer, and they cooperate to produce the same output as a single large GPU would.

Single GPU (impossible — doesn't fit):
┌──────────────────────────────────────────────┐
│ GPU 0: Full 70B model (140GB)                │  ✗ Doesn't fit in 80GB
└──────────────────────────────────────────────┘

Tensor Parallelism (TP=2):
┌───────────────────────┐  ┌───────────────────────┐
│ GPU 0: Left half of   │  │ GPU 1: Right half of  │
│ every weight matrix   │  │ every weight matrix   │
│ (70GB)                │  │ (70GB)                │
└───────────────────────┘  └───────────────────────┘
         │                          │
         └────── AllReduce ─────────┘
            (sync after each layer)

TP=2 cuts memory per GPU in half. TP=4 cuts it to a quarter. TP=8 puts a 405B model across 8 GPUs.


The Core Idea: Column-Parallel + Row-Parallel

The Megatron-LM paper introduced a pattern for parallelizing transformer layers using two types of linear layers: column-parallel and row-parallel. They’re always used in pairs, and the pair requires only ONE AllReduce communication.

Column-Parallel Linear

Split the weight matrix by columns (output dimension). Each GPU computes a slice of the output independently — no communication needed.

Standard:              Y = X @ W           W: [in, out]

Column-parallel:       Y_0 = X @ W_0      W_0: [in, out/2]   GPU 0
                       Y_1 = X @ W_1      W_1: [in, out/2]   GPU 1

Full W = [W_0 | W_1]  (concatenated by columns)
Full Y = [Y_0 | Y_1]  (each GPU has a different slice of Y)

Example with W: [4, 8], TP=2:
  GPU 0 gets W[:, 0:4] → computes Y[:, 0:4]
  GPU 1 gets W[:, 4:8] → computes Y[:, 4:8]
  No communication — each GPU works independently!
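
You can check this equivalence in a few lines of plain PyTorch; here is a minimal single-process sketch that simulates the two GPUs with tensor slices:

import torch

# single-process sketch of column parallelism (simulated TP=2):
# each "GPU" computes an independent slice of Y
X = torch.randn(3, 4)            # [batch, in]
W = torch.randn(4, 8)            # [in, out], math convention: Y = X @ W

W_0, W_1 = W[:, :4], W[:, 4:]    # column shards for GPU 0 / GPU 1
Y_0, Y_1 = X @ W_0, X @ W_1      # no communication needed

assert torch.allclose(torch.cat([Y_0, Y_1], dim=-1), X @ W, atol=1e-6)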

Row-Parallel Linear

Split the weight matrix by rows (input dimension). Each GPU multiplies its slice of the input by its slice of the weight, producing a partial output. An AllReduce sums the partials to get the final result.

Standard:              Y = X @ W           W: [in, out]

Row-parallel:          Y_0 = X_0 @ W_0    W_0: [in/2, out]   GPU 0
                       Y_1 = X_1 @ W_1    W_1: [in/2, out]   GPU 1
                       Y = Y_0 + Y_1      ← AllReduce(SUM)!

Full W = [W_0]   (stacked by rows)
         [W_1]

Example with W: [8, 4], TP=2:
  GPU 0 gets W[0:4, :] and input X[:, 0:4] → partial Y
  GPU 1 gets W[4:8, :] and input X[:, 4:8] → partial Y
  AllReduce(SUM) → full Y on all GPUs
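
The same kind of single-process sketch for the row-parallel case, checking that the two partial products sum to the full matmul:

import torch

# single-process sketch of row parallelism (simulated TP=2):
# partial outputs from the input/weight shards sum to the full result
X = torch.randn(3, 8)            # [batch, in]
W = torch.randn(8, 4)            # [in, out]

partial_0 = X[:, :4] @ W[:4, :]  # GPU 0: its input slice times its row shard
partial_1 = X[:, 4:] @ W[4:, :]  # GPU 1

assert torch.allclose(partial_0 + partial_1, X @ W, atol=1e-5)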

The Megatron-LM Pair

In a transformer FFN, there are two linear layers in sequence. The first is column-parallel, the second is row-parallel:

Tensor-Parallel FFN (TP=4):

  Input X ──► [Column-Parallel W1] ──► GeLU ──► [Row-Parallel W2] ──► AllReduce ──► Output

  GPU 0:  X @ W1_0 → GeLU → hidden_0 @ W2_0 → partial_0 ─┐
  GPU 1:  X @ W1_1 → GeLU → hidden_1 @ W2_1 → partial_1 ─┼─► AllReduce → Y
  GPU 2:  X @ W1_2 → GeLU → hidden_2 @ W2_2 → partial_2 ─┤
  GPU 3:  X @ W1_3 → GeLU → hidden_3 @ W2_3 → partial_3 ─┘

  Communication: ZERO between W1 and W2 (the slices line up!)
                 ONE AllReduce after W2 (sum the partials)

The key insight: the output of column-parallel W1 on each GPU is exactly the input that row-parallel W2 needs on that GPU. The dimensions line up, so no communication is needed between the two layers. The only communication is the AllReduce at the end.
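
Here is a quick single-process check of that claim (TP=2, random weights): the column shard of W1 produces exactly the activation slice that the matching row shard of W2 consumes, and one final sum stands in for the AllReduce:

import torch
import torch.nn.functional as F

# single-process sketch of the Megatron-LM pair (simulated TP=2)
X  = torch.randn(3, 16)          # [batch, hidden]
W1 = torch.randn(16, 32)         # [hidden, intermediate], split by columns
W2 = torch.randn(32, 16)         # [intermediate, hidden], split by rows

def one_gpu(rank):
    W1_r = W1[:, rank * 16:(rank + 1) * 16]   # column shard of W1
    W2_r = W2[rank * 16:(rank + 1) * 16, :]   # matching row shard of W2
    h = F.gelu(X @ W1_r)         # [3, 16]: exactly the slice W2_r needs
    return h @ W2_r              # partial output

Y = one_gpu(0) + one_gpu(1)      # stands in for AllReduce(SUM)
assert torch.allclose(Y, F.gelu(X @ W1) @ W2, atol=1e-4)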


How It Works

AllReduce: The Cost of TP

AllReduce is the collective operation where every GPU contributes data, all data is summed element-wise, and every GPU gets the full result:

Before AllReduce:
  GPU 0: [1, 2, 3, 4]
  GPU 1: [5, 6, 7, 8]

AllReduce(SUM):
  GPU 0: [6, 8, 10, 12]   (element-wise sum)
  GPU 1: [6, 8, 10, 12]   (same result on both!)
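
In PyTorch this is a single collective call. A minimal sketch (hypothetical file name, launched with torchrun on 2 GPUs) that reproduces the numbers above:

# allreduce_demo.py (hypothetical file); run with:
#   torchrun --nproc_per_node=2 allreduce_demo.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# rank 0 holds [1, 2, 3, 4]; rank 1 holds [5, 6, 7, 8]
x = torch.arange(1.0, 5.0, device="cuda") + 4 * rank
dist.all_reduce(x, op=dist.ReduceOp.SUM)
print(f"rank {rank}: {x.tolist()}")  # both print [6.0, 8.0, 10.0, 12.0]

dist.destroy_process_group()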

AllReduce is typically implemented as a ring — each GPU sends and receives from its neighbors in a pipeline:

Ring AllReduce (4 GPUs):

  Step 1: Reduce-Scatter                Step 2: All-Gather
  (each GPU gets 1/4 of final sum)      (broadcast the partial sums)

       GPU 0                                 GPU 0
      ╱     ╲                               ╱     ╲
   GPU 3   GPU 1          →             GPU 3   GPU 1
      ╲     ╱                               ╲     ╱
       GPU 2                                 GPU 2

  GPU 0 → sends chunk to GPU 1         GPU 0 → sends result to GPU 1
  GPU 1 → sends chunk to GPU 2         GPU 1 → sends result to GPU 2
  GPU 2 → sends chunk to GPU 3         GPU 2 → sends result to GPU 3
  GPU 3 → sends chunk to GPU 0         GPU 3 → sends result to GPU 0

  Total data sent per GPU: 2 × (N-1)/N × data_size
  With NVLink (900 GB/s): microseconds for typical hidden sizes
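
For intuition, here is a toy single-process simulation of the two phases in plain Python (scalar chunks, simultaneous sends; real implementations like NCCL pipeline this over actual links):

# toy simulation: data[g][c] is chunk c held by GPU g
def ring_allreduce(data):
    n = len(data)
    # reduce-scatter: after n-1 steps, GPU g owns the full sum of chunk (g+1) % n
    for step in range(n - 1):
        sends = [(g, (g - step) % n) for g in range(n)]
        vals = [data[g][c] for g, c in sends]   # snapshot: sends are simultaneous
        for (g, c), v in zip(sends, vals):
            data[(g + 1) % n][c] += v           # neighbor accumulates the chunk
    # all-gather: circulate each completed chunk once around the ring
    for step in range(n - 1):
        sends = [(g, (g + 1 - step) % n) for g in range(n)]
        vals = [data[g][c] for g, c in sends]
        for (g, c), v in zip(sends, vals):
            data[(g + 1) % n][c] = v
    return data

print(ring_allreduce([[1, 2], [3, 4]]))         # [[4, 6], [4, 6]]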

For each transformer layer, TP requires 2 AllReduces: one after the attention output projection, one after the FFN down projection. A 32-layer model does 64 AllReduces per forward pass.

AllReduce communication cost: each GPU sends 2 × (N-1)/N × data_size bytes, where N is the TP degree. For a hidden size of 4096 with batch×seq = 2048 tokens at FP16:

data_size = 2048 × 4096 × 2 bytes = 16 MB
AllReduce cost (TP=8): 2 × 7/8 × 16 MB = 28 MB sent per AllReduce
64 AllReduces × 28 MB = 1.8 GB total communication per forward pass

H100 NVLink: 900 GB/s → 1.8 GB / 900 = 2 ms of communication
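
The same arithmetic as a short script, so you can plug in your own model shape:

# back-of-envelope AllReduce cost (same numbers as above)
tokens, hidden, layers, tp = 2048, 4096, 32, 8
data_size = tokens * hidden * 2                     # FP16: 2 bytes/element
per_allreduce = 2 * (tp - 1) / tp * data_size       # ring AllReduce bytes per GPU
total = 2 * layers * per_allreduce                  # 2 AllReduces per layer

print(f"{per_allreduce / 2**20:.0f} MiB per AllReduce")    # 28 MiB
print(f"{total / 1e9:.1f} GB per forward pass")            # ~1.9 GB
print(f"{total / 900e9 * 1e3:.1f} ms at 900 GB/s NVLink")  # ~2 ms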

Weight Matrix Slicing: Visual

Here’s exactly how the weight matrices are sliced for TP=2:

Column-Parallel (gate_proj, up_proj, Q, K, V projections):

  Full weight: [in=4, out=8]           GPU 0 gets:        GPU 1 gets:
  ┌─────────────┬─────────────┐        ┌────────────┐     ┌────────────┐
  │ a  b  c  d  │  e  f  g  h │        │ a  b  c  d │     │ e  f  g  h │
  │ i  j  k  l  │  m  n  o  p │   →    │ i  j  k  l │     │ m  n  o  p │
  │ q  r  s  t  │  u  v  w  x │        │ q  r  s  t │     │ u  v  w  x │
  │ 1  2  3  4  │  5  6  7  8 │        │ 1  2  3  4 │     │ 5  6  7  8 │
  └─────────────┴─────────────┘        └────────────┘     └────────────┘
   columns 0-3    columns 4-7           [4, 4]             [4, 4]

  X @ W_0 = Y_0 (output slice)         X @ W_1 = Y_1 (output slice)
  No communication needed — Y_0 and Y_1 are independent slices of Y

Row-Parallel (down_proj, output projection):

  Full weight: [in=8, out=4]
  ┌────────────┐
  │ a  b  c  d │                   GPU 0 gets rows 0-3:  GPU 1 gets rows 4-7:
  │ e  f  g  h │  rows 0-3         ┌────────────┐        ┌────────────┐
  │ i  j  k  l │                   │ a  b  c  d │        │ q  r  s  t │
  │ m  n  o  p │                   │ e  f  g  h │        │ u  v  w  x │
  ├────────────┤         →         │ i  j  k  l │        │ 1  2  3  4 │
  │ q  r  s  t │                   │ m  n  o  p │        │ 5  6  7  8 │
  │ u  v  w  x │  rows 4-7         └────────────┘        └────────────┘
  │ 1  2  3  4 │                    [4, 4]                [4, 4]
  │ 5  6  7  8 │
  └────────────┘

  X_0 @ W_0 = partial_0             X_1 @ W_1 = partial_1
  AllReduce(partial_0 + partial_1) = Y   ← communication here

TP for Attention

Attention heads are naturally parallel — each head is independent. TP splits heads across GPUs:

Llama with 32 attention heads, TP=4:
  GPU 0: heads 0-7    (QKV projection: column-parallel)
  GPU 1: heads 8-15
  GPU 2: heads 16-23
  GPU 3: heads 24-31

Output projection: row-parallel (AllReduce after)
→ 1 AllReduce per attention sublayer

For GQA (grouped-query attention) with fewer KV heads, TP requires the number of KV heads to be divisible by the TP degree. Llama 3.1 70B has 8 KV heads, so TP≤8.
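
A small sketch of the head bookkeeping, using Llama 3.1 70B’s published configuration (64 query heads, 8 KV heads):

# head bookkeeping for TP attention (Llama 3.1 70B: 64 query heads, 8 KV heads)
num_heads, num_kv_heads, tp = 64, 8, 4
assert num_heads % tp == 0 and num_kv_heads % tp == 0, "TP must divide head counts"

heads_per_rank = num_heads // tp          # 16 query heads per GPU
kv_per_rank = num_kv_heads // tp          # 2 KV heads per GPU
for rank in range(tp):
    lo = rank * heads_per_rank
    print(f"GPU {rank}: query heads {lo}-{lo + heads_per_rank - 1}, "
          f"KV heads {rank * kv_per_rank}-{(rank + 1) * kv_per_rank - 1}")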

TP for FFN (SwiGLU)

Modern LLMs use SwiGLU FFN with gate and up projections:

SwiGLU FFN:   Y = (SiLU(X @ W_gate) ⊙ (X @ W_up)) @ W_down

TP split:
  W_gate, W_up → column-parallel (split output dimension)
  W_down       → row-parallel   (AllReduce after)

GPU 0: gate_0 = X @ W_gate_0,  up_0 = X @ W_up_0
       hidden_0 = SiLU(gate_0) ⊙ up_0
       partial_0 = hidden_0 @ W_down_0
GPU 1: ... same with shard 1 ...
AllReduce(partial_0 + partial_1 + ...) → full output
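
The same split as a runnable single-process sketch (TP=2, random weights). Note that the elementwise SiLU(gate) ⊙ up happens entirely within each shard, so it adds no communication:

import torch
import torch.nn.functional as F

# single-process SwiGLU TP sketch (simulated TP=2)
X = torch.randn(3, 8)
W_gate, W_up = torch.randn(8, 16), torch.randn(8, 16)
W_down = torch.randn(16, 8)

def one_gpu(rank):
    g = X @ W_gate[:, rank * 8:(rank + 1) * 8]   # column shard
    u = X @ W_up[:, rank * 8:(rank + 1) * 8]     # column shard
    h = F.silu(g) * u                            # sharded hidden state
    return h @ W_down[rank * 8:(rank + 1) * 8]   # row shard, partial output

Y = one_gpu(0) + one_gpu(1)                      # stands in for AllReduce(SUM)
assert torch.allclose(Y, (F.silu(X @ W_gate) * (X @ W_up)) @ W_down, atol=1e-4)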

Weight Loading: Sharding Full Weights

When loading a model, the full checkpoint contains the complete weight matrices. Each GPU extracts its shard:

# PyTorch stores linear weights as [out_features, in_features]

# Column-parallel: split the output dimension (dim 0 of the stored weight)
full_weight: [intermediate, hidden]  = [11008, 4096]
gpu_0_shard: full_weight[0:5504, :]  = [5504, 4096]    # first half of outputs
gpu_1_shard: full_weight[5504:, :]   = [5504, 4096]    # second half

# Row-parallel: split the input dimension (dim 1 of the stored weight)
full_weight: [hidden, intermediate]  = [4096, 11008]
gpu_0_shard: full_weight[:, 0:5504]  = [4096, 5504]    # first half of inputs
gpu_1_shard: full_weight[:, 5504:]   = [4096, 5504]    # second half
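
Wrapped up as a loader helper, this might look like the sketch below. The parameter-name matching is an assumption based on Llama-style checkpoint names, not the code from 09_tensor_parallelism.py:

import torch

# hypothetical loader helper; name matching assumes Llama-style checkpoints
COLUMN_PARALLEL = ("gate_proj", "up_proj", "q_proj", "k_proj", "v_proj")
ROW_PARALLEL = ("down_proj", "o_proj")

def shard_weight(name: str, w: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    """Return this rank's shard of a full [out, in] weight tensor."""
    if any(p in name for p in COLUMN_PARALLEL):
        out = w.shape[0] // world                 # column-parallel: slice output dim
        return w[rank * out:(rank + 1) * out, :]
    if any(p in name for p in ROW_PARALLEL):
        inp = w.shape[1] // world                 # row-parallel: slice input dim
        return w[:, rank * inp:(rank + 1) * inp]
    return w                                      # everything else: replicate (simplification)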

How vLLM/SGLang Implements This

Our Code                        Real vLLM                                    Real SGLang
ColumnParallelLinear            ColumnParallelLinear                         ColumnParallelLinear
RowParallelLinear               RowParallelLinear                            RowParallelLinear
AllReduce(SUM) in row-parallel  tensor_model_parallel_all_reduce()           Same (via PyTorch dist)
load_weight_shard()             weight_loader() per layer                    load_weights() with sharding
torchrun --nproc_per_node       Ray or multiprocessing with NCCL             multiprocessing with NCCL
Manual weight slicing           WeightTiledLoader + QuantizedWeightsLoader   Similar weight loading

Key differences:

vLLM’s weight loading: vLLM doesn’t load the full checkpoint and then slice. Instead, each GPU loads only its shard directly from the checkpoint using the weight_loader function attached to each parameter. This avoids the peak memory of having all weights in CPU memory.

Custom AllReduce: vLLM implements a custom AllReduce kernel (custom_ar) that’s faster than NCCL for small message sizes. For small messages, the latency of launching NCCL’s ring algorithm dominates the actual data transfer. vLLM’s custom kernel uses shared memory and direct GPU-to-GPU copies via NVLink, reducing launch overhead.

Quantization + TP: When combining TP with quantization (e.g., INT4 weight-only), the weight shards are stored in quantized format. Each GPU holds quantized weights for its shard, dequantizing on-the-fly during the matmul. This compounds the memory savings: TP=4 with INT4 reduces memory by 16x vs FP16 on one GPU.
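
The compounding is plain multiplication; a quick check for a 70B model:

# TP sharding and quantization compound multiplicatively (70B example)
params = 70e9
print(f"FP16, 1 GPU: {params * 2 / 1e9:.0f} GB")            # 140 GB
print(f"INT4, TP=4:  {params * 0.5 / 4 / 1e9:.2f} GB/GPU")  # 8.75 GB, 16x less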


The Implementation

The complete implementation is in 09_tensor_parallelism.py (~280 lines).

Column-Parallel Linear

import torch.nn as nn

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        self.rank = rank
        self.out_per_rank = out_features // world_size
        # bias=False for clarity; a bias would be column-sharded the same way
        self.linear = nn.Linear(in_features, self.out_per_rank, bias=False)

    def load_weight_shard(self, full_weight):
        # nn.Linear stores weight as [out, in]: the column (output) shard is dim 0
        start = self.rank * self.out_per_rank
        end = start + self.out_per_rank
        self.linear.weight.data.copy_(full_weight[start:end, :])

    def forward(self, x):
        return self.linear(x)  # no communication!

Row-Parallel Linear

import torch.distributed as dist

class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        self.rank = rank
        self.in_per_rank = in_features // world_size
        # bias=False: a bias here would be summed world_size times by the AllReduce
        self.linear = nn.Linear(self.in_per_rank, out_features, bias=False)

    def load_weight_shard(self, full_weight):
        # the row (input) shard is a slice of dim 1 of the stored [out, in] weight
        start = self.rank * self.in_per_rank
        end = start + self.in_per_rank
        self.linear.weight.data.copy_(full_weight[:, start:end])

    def forward(self, x):
        y = self.linear(x)
        dist.all_reduce(y, op=dist.ReduceOp.SUM)  # THE communication
        return y

TP MLP (Megatron-LM Pattern)

class TensorParallelMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, world_size, rank):
        super().__init__()
        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, world_size, rank)  # no comm
        self.up_proj   = ColumnParallelLinear(hidden_size, intermediate_size, world_size, rank)  # no comm
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, world_size, rank)     # AllReduce
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.act(self.gate_proj(x)) * self.up_proj(x)  # SwiGLU; each GPU independent
        return self.down_proj(h)                           # AllReduce inside

Running the Code

Demo with 2 GPUs:

torchrun --nproc_per_node=2 09_tensor_parallelism.py --demo

Demo with 4 GPUs:

torchrun --nproc_per_node=4 09_tensor_parallelism.py --demo

Expected output (TP=2 on H100):

World size (TP degree): 2
GPUs: 2x NVIDIA H100 80GB HBM3

--- Correctness Test ---
  Max absolute difference: 3.91e-03
  Match: YES (fp32 rounding from different op order)

--- Weight Sharding ---
  Full gate_proj weight: [4096, 1024]
    GPU 0: gate_proj shard [2048, 1024] (columns 0:2048)
    GPU 1: gate_proj shard [2048, 1024] (columns 2048:4096)

--- Performance ---
  Config: hidden=4096, inter=11008, batch=16, seq=128
  Single GPU:  7.59 ms/forward
  TP=2:       3.99 ms/forward
  Speedup:     1.90x

  Weight memory per GPU:
    Single GPU: 360.7 MB (full weights)
    TP=2:      180.4 MB (1/2 of weights)

Benchmarks

TP Degree   Weight Memory/GPU   Compute Speedup   Communication Overhead
TP=1        100%                1.0x              None
TP=2        50%                 ~1.9x             2 AllReduces/layer
TP=4        25%                 ~3.5x             2 AllReduces/layer
TP=8        12.5%               ~5-6x             2 AllReduces/layer

Why isn’t TP=8 8x faster? Per-GPU AllReduce traffic grows with the TP degree (the 2 × (N-1)/N factor approaches 2), more GPUs must synchronize, and each GPU’s matmuls shrink, so the kernels run less efficiently. At TP=8, this overhead eats into the compute savings. The sweet spot is usually TP=2 or TP=4 for inference.

Model                FP16 Size   Min GPUs (80GB)   Recommended TP
Llama 3.1 8B         16 GB       1                 TP=1
Llama 3.1 70B        140 GB      2                 TP=2 or TP=4
Llama 3.1 405B       810 GB      11                TP=8
DeepSeek V3 (671B)   1.3 TB      17                TP=8 + EP

Key Takeaways

  1. Tensor parallelism splits weight matrices across GPUs so each GPU holds a fraction of every layer
  2. Column-parallel for the first linear (no communication), row-parallel for the second (AllReduce)
  3. The Megatron-LM pattern requires 2 AllReduces per transformer layer — that’s all the communication
  4. TP reduces both memory and compute per GPU, but AllReduce overhead limits scaling beyond TP=8
  5. Attention heads are split across GPUs — TP degree must divide the number of KV heads
  6. Real systems like vLLM use custom AllReduce kernels that beat NCCL for small messages

Further Reading