大模型分布式训练方法详解
随着大语言模型(LLM)参数量级的不断攀升,单卡显存已无法容纳整个模型权重。当训练数据规模扩大时,单卡训练不仅速度极慢,甚至因 OOM(Out Of Memory)错误而无法启动。为了解决这一问题,利用多 GPU(单机多卡或多机多卡集群)进行分布式并行训练成为标准方案。
本文详细解析四种主流并行策略:数据并行(DP/DDP)、模型张量并行(TP)、流水线并行(PP)以及 ZeRO 优化技术,并对比其适用场景与实现细节。
1. 数据并行(Data Parallelism)
数据并行是最基础的并行方式。在 DP 模式下,每个 GPU 都加载一份完整的模型副本,将输入数据分割成多个批次(Batch),分别送入不同的 GPU 进行前向传播和反向传播计算。
1.1 DP 与 DDP 的区别
虽然概念相似,但实际工程中常用的是 DDP(Distributed Data Parallel)而非传统的 DP。
- 实现机制:传统 DP 基于多线程实现,受 Python GIL(全局解释器锁)限制;DDP 基于多进程实现,每个 GPU 由独立进程控制,互不干扰。
- 通信效率:DP 存在多次数据交换,而 DDP 通过 NCCL(NVIDIA Collective Communications Library)等后端,在梯度计算完成后进行一次全量同步(All-Reduce),显著降低通信开销。
- 扩展性:DP 仅支持单机;DDP 支持单机及多机集群,配合 Gloo 或 NCCL 后端可实现跨节点通信。
1.2 PyTorch DDP 示例
import torch.distributed as dist
import torch.multiprocessing as mp
def train(rank, world_size):
# 初始化进程组
dist.init_process_group("nccl", rank=rank, world_size=world_size)
# 设置设备
torch.cuda.set_device(rank)
model = MyModel().to(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank])
# 训练循环...
dist.destroy_process_group()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
2. ZeRO 深度解析
ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 提出的优化技术,旨在进一步减少显存占用。它属于数据并行的范畴,但通过将模型状态分片到不同 GPU 上,实现了比传统 DP 更高的显存利用率。
2.1 ZeRO 的分片阶段
ZeRO 根据分片粒度的不同分为三个阶段:
- ZeRO-1(优化器状态分片):仅将优化器状态(Optimizer States)分片存储。适用于显存紧张但模型参数较小的场景。
- ZeRO-2(优化器 + 梯度分片):除了优化器状态,还将梯度(Gradients)分片。这是最常用的配置,能大幅减少反向传播时的显存峰值。
- ZeRO-3(优化器 + 梯度 + 参数分片):将模型参数(Parameters)也分片。每个 GPU 只存储部分参数,计算时需要从其他 GPU 拉取所需参数。这允许在有限显存下训练更大的模型,但通信开销最大。


