PyTorch 多卡训练概述
在深度学习模型训练中,随着数据量和模型复杂度的增加,单张 GPU 的计算资源往往成为瓶颈。多卡训练(Multi-GPU Training)通过利用多个计算设备并行处理任务,显著缩短训练时间。PyTorch 提供了多种机制来实现分布式训练,主要包括 torch.nn.DataParallel (DP) 和 torch.nn.parallel.DistributedDataParallel (DDP)。
一、多卡训练基本原理
多卡训练的核心思想是将大任务分解为小任务分配给不同设备执行。通用流程如下:
- 节点指定:确定主机节点及从属节点。
- 数据划分:将 Batch 数据平均分到每个机器或 GPU 上。
- 模型分发:将模型参数从主机拷贝到各个计算节点。
- 前向传播:各节点独立进行前向计算。
- 损失计算:各节点计算局部 Loss。
- 梯度同步:收集所有节点的梯度或 Loss 结果,进行聚合。
- 参数更新:根据聚合后的梯度更新模型参数,并同步回各节点。
二、单机多卡训练:DataParallel
torch.nn.DataParallel 是 PyTorch 早期提供的并行模块,适用于单机多卡场景。
1. 工作原理
使用方式非常简单,只需将模型包裹一层:
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
其内部逻辑是将输入数据在 CPU 端切分,发送给指定的 GPU 分别执行 forward 操作。例如,Batch Size 为 32,4 个 GPU,则每个 GPU 处理 8 条数据。计算完成后,各 GPU 的输出会被收集到主 GPU(通常是 cuda:0)上进行合并。
2. 局限性
尽管使用便捷,但 DP 存在明显缺陷:
- 通信瓶颈:所有梯度汇聚到主 GPU 进行 backward 和参数更新,导致主 GPU 负载过重,其他 GPU 空闲等待。
- Loss 计算位置:默认情况下,loss 计算仅在 cuda:0 上进行,无法并行化。
- GIL 限制:由于 Python 的全局解释器锁(GIL),CPU 端的预处理可能成为瓶颈。
3. 优化方案
为解决 loss 计算不均衡问题,可以在模型的 forward 函数中直接计算 loss,并在返回前对多个 GPU 的 loss 取平均:
class Net(torch.nn.Module):
def __init__(self, ...):
super().__init__()
self.fc = torch.nn.Linear(...)
def forward(self, inputs, labels=None):
outputs = .fc(inputs)
labels :
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(outputs, labels)
loss
:
outputs


