PyTorch 多卡训练概述
在深度学习模型训练中,随着数据量和模型复杂度的增加,单张 GPU 的计算资源往往成为瓶颈。多卡训练(Multi-GPU Training)通过利用多个计算设备并行处理任务,显著缩短训练时间。PyTorch 提供了多种机制来实现分布式训练,主要包括 (DP) 和 (DDP)。
详细阐述了 PyTorch 多卡训练的原理与实现方案。对比了 DataParallel 与 DistributedDataParallel 两种模式的机制差异,重点介绍了 DDP 在多机多卡场景下的进程初始化、梯度同步及数据采样方法。内容涵盖环境配置、模型封装、状态字典保存及常见调试技巧,旨在帮助开发者构建高效的分布式训练系统。

在深度学习模型训练中,随着数据量和模型复杂度的增加,单张 GPU 的计算资源往往成为瓶颈。多卡训练(Multi-GPU Training)通过利用多个计算设备并行处理任务,显著缩短训练时间。PyTorch 提供了多种机制来实现分布式训练,主要包括 (DP) 和 (DDP)。
torch.nn.DataParalleltorch.nn.parallel.DistributedDataParallel多卡训练的核心思想是将大任务分解为小任务分配给不同设备执行。通用流程如下:
torch.nn.DataParallel 是 PyTorch 早期提供的并行模块,适用于单机多卡场景。
使用方式非常简单,只需将模型包裹一层:
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
其内部逻辑是将输入数据在 CPU 端切分,发送给指定的 GPU 分别执行 forward 操作。例如,Batch Size 为 32,4 个 GPU,则每个 GPU 处理 8 条数据。计算完成后,各 GPU 的输出会被收集到主 GPU(通常是 cuda:0)上进行合并。
尽管使用便捷,但 DP 存在明显缺陷:
为解决 loss 计算不均衡问题,可以在模型的 forward 函数中直接计算 loss,并在返回前对多个 GPU 的 loss 取平均:
class Net(torch.nn.Module):
def __init__(self, ...):
super().__init__()
self.fc = torch.nn.Linear(...)
def forward(self, inputs, labels=None):
outputs = self.fc(inputs)
if labels is not None:
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(outputs, labels)
return loss
else:
return outputs
当返回标量时,DataParallel 会将其收集为向量,backward 前需确保逻辑正确。但在生产环境中,更推荐使用 DDP。
DistributedDataParallel (DDP) 是目前推荐的标准分布式训练方式,支持单机多卡和多机多卡。
DistributedSampler 确保不同进程读取不同的数据切片。每个进程必须初始化分布式环境。首先通过命令行参数注入 local_rank:
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py
在代码中解析参数并设置环境变量:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()
# 设置随机种子以保证可复现性
def set_seed(seed):
import random
import numpy as np
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
set_seed(42)
# 设置当前进程使用的 GPU
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
# 初始化分布式进程组
# backend='nccl' 用于 GPU 通信,'gloo' 用于 CPU 通信
torch.distributed.init_process_group(backend='nccl')
torch.distributed.init_process_group 包含以下常用参数:
tcp://ip:port) 或文件系统 (file:///path/to/file)。使用 DistributedSampler 自动划分数据集,避免重复采样:
from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=batch_size
)
注意:这里的 batch_size 是每个 GPU 上的批次大小,实际全局 Batch Size 为 batch_size * world_size。
model = Net().to(device)
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True # 如果模型中有未使用的参数,需设为 True
)
在分布式环境下,保存模型需注意以下几点:
model.module.state_dict() 而非 model.state_dict(),因为 DDP 包裹了一层。if torch.distributed.get_rank() == 0:
model_to_save = model.module if hasattr(model, "module") else model
torch.save(model_to_save.cpu().state_dict(), "model.pth")
加载时同样需要恢复分布式环境后再实例化模型:
param = torch.load("model.pth", map_location=device)
model.load_state_dict(param)
batch_size 或使用梯度累积。world_size 与实际启动的进程数一致。timeout 参数或检查防火墙设置。PyTorch 的多卡训练能力是构建大规模深度学习系统的基础。对于单机场景,DataParallel 简单但效率受限;对于生产级训练,尤其是多机多卡场景,DistributedDataParallel 凭借更好的通信效率和扩展性成为首选。掌握其初始化流程、数据划分策略及模型持久化方法,是高效训练大模型的关键。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online