torch-tensor-parallelism
Tensor Parallelism Implementation Guide
This skill provides guidance for implementing tensor parallelism patterns in PyTorch, specifically for ColumnParallelLinear and RowParallelLinear layers that distribute computation across multiple devices.
Core Concepts
Tensor Parallelism Overview
Tensor parallelism splits individual layers across multiple devices to parallelize computation within a single forward/backward pass. The two primary patterns are:
-
ColumnParallelLinear: Shards weights along the output dimension (columns). Each device computes a portion of the output features, then results are concatenated via all-gather.
-
RowParallelLinear: Shards weights along the input dimension (rows). Each device computes partial outputs using its shard of the input, then results are summed via all-reduce.
Critical Implementation Requirement
When implementing tensor parallelism (especially in simulation or testing contexts), the forward pass must actually perform the collective operations, not just compute local shards: