跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI

提升 PyTorch 训练效率的 9 个实用做法

训练速度慢通常不是模型本身的问题,而是数据加载、批量设置、显存使用和并行方式没配好。这里整理了 9 个常见优化点:用 DataLoader 和合适的 num_workers 提升输入吞吐,尽量把 batch size 调到硬件允许范围,显存不足时用梯度累积,记录 loss 时避免保留计算图,单卡训练减少 CPU/GPU 来回拷贝,开启 16-bit 混合精度,多卡和多节点场景优先考虑 DDP。整体思路很简单:先解决最容易出瓶颈的地方,再看模型和资源规模做更重的并行拆分。

女王发布于 2026/6/300 浏览
提升 PyTorch 训练效率的 9 个实用做法

在深度学习开发里,训练慢通常不是单一原因:有时是数据喂得太慢,有时是显存被浪费了,也有时只是把硬件用得不够满。下面这 9 个方法,基本都围绕 PyTorch 和 PyTorch-Lightning 的常见做法展开,偏实战,不追求花哨,但对大多数项目都能直接见效。

PyTorch-Lightning 只是把训练流程包了一层,核心还是 PyTorch。本身不会替你把模型变快,但它把很多参数和训练策略收拢得更清楚,适合拿来说明这些优化点。

1. 先把数据加载做好

训练速度卡住,最常见的地方其实不是 GPU,而是数据读取。比起把数据整成 h5py 或 numpy 再在主进程里慢慢读,直接用 DataLoader 往往更省事,也更容易把吞吐拉起来。图像任务可以直接上 PyTorch 的数据集和 DataLoader,NLP 场景则可以看看 TorchText。

在 PyTorch-Lightning 里,数据管道配置好之后,训练循环就不用自己一层层写了,框架会按你提供的 DataLoader 去跑。

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

dataset = MNIST(root='./data', train=True, download=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

for batch in loader:
    x, y = batch
    model.training_step(x, y)

这里 shuffle=True 是为了打乱样本顺序,pin_memory=True 则能让 CPU 到 GPU 的拷贝更顺一点。这个参数经常被忽略,但在 GPU 训练里挺实用。

2. 把 num_workers 调起来

DataLoader 默认是主进程读数据,数据集一大,IO 就会拖后腿。num_workers 的作用就是把加载工作拆到多个进程里,减少等待时间。

# 慢:主进程加载
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 快:启用 4 个 worker 进程
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

一般可以从 4 或 8 开始试,接近 CPU 核心数不一定最优,但比完全不开强很多。别一上来就拉满,worker 太多会和别的任务抢内存、抢磁盘,最后反而抖得更厉害。

3. 批量大小尽量用满显存

batch size 往大调,通常是最直接的提速手段。更大的 batch 能让 GPU 更连续地干活,减少一次次小批次切换带来的浪费。很多时候,模型不是算不过来,而是算得太碎。

不过这一步有明显代价:显存会上去,学习率也往往要跟着改。常见做法是按线性缩放思路去调学习率,但具体还是得结合模型和数据看,不是机械套公式就行。

4. 显存不够时,用梯度累积

如果显存撑不起大 batch,梯度累积比硬拆模型要现实得多。它的思路很简单:先分几次算小 batch,把梯度攒起来,再统一更新一次参数。效果上接近大 batch,代价是单步会慢一点,但总比直接 OOM 强。

optimizer.zero_grad()
scaled_loss = 0
accumulated_steps = 4

for i in range(accumulated_steps):
    out = model.forward()
    loss = some_loss(out, y) / accumulated_steps
    loss.backward()
    scaled_loss += loss.item()

optimizer.step()
actual_loss = scaled_loss

PyTorch-Lightning 里就更简单了,直接配 accumulate_grad_batches:

trainer = Trainer(accumulate_grad_batches=4)
trainer.fit(model)

这类优化的好处是不用改模型结构,坏处是训练节奏会变慢一点,但在显存紧张的时候很值。

5. 记录 loss 时别把计算图一起存了

这个坑很隐蔽。很多人为了后面画曲线,直接把 loss append 到列表里,结果把整张计算图也一起留住了,显存回收不了。

# 错误:保留计算图副本
losses.append(loss)

# 正确:仅存储数值
losses.append(loss.item())

.item() 会把标量取成普通 Python 数值,计算图也就断开了。只是记录日志的话,这么做就够了。

6. 单卡训练时,别让 CPU 和 GPU 来回折腾

单 GPU 训练并不等于'把模型丢上去就完了'。模型和输入都要在同一个设备上,传输次数越少越好。Lightning 这类框架会帮你处理一部分,但底层逻辑还是要知道。

model.cuda()
x = x.cuda()
out = model(x)

如果每一步都在频繁搬数据,GPU 很容易空转。torch.cuda.empty_cache() 这种操作只适合你明确知道自己在做什么的时候用,别把它当成常规清理手段,它不会 magically 解决内存问题,反而可能让同步更频繁。

7. 混合精度通常值得开

FP16 的直接收益很明显:显存占用更低,很多新卡上算得也更快。混合精度不是把所有计算都换成 16 位,而是该保留精度的地方保留,能降精度的地方降一点,整体上更平衡。

原生 PyTorch 可以用 AMP:

from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

with autocast():
    output = model(input)
    loss = criterion(output, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

在 PyTorch-Lightning 里,通常直接设 precision=16 就能开起来:

trainer = Trainer(precision=16)
trainer.fit(model)

这类优化一般属于'先开了再看效果',成本低,收益往往不差。

8. 多 GPU 时优先考虑分布式训练

多卡训练里,最容易踩坑的是把概念混在一起。DataParallel 能用,但单机多卡下性能和稳定性都不算理想,更多时候我会优先看 DistributedDataParallel。简单说,前者更像把模型复制几份,后者更像让每张卡各干各的,再同步梯度。

分批次训练(DataParallel)
model = DataParallel(model, device_ids=[0, 1, 2, 3])
out = model(x)

Lightning 里可以直接指定多卡:

# 旧写法示意
trainer = Trainer(gpus=[0, 1, 2, 3])
模型分布训练(Model Parallelism)

当单卡放不下模型时,才会考虑把不同模块拆到不同 GPU 上,比如编码器放一张卡、解码器放另一张卡。

self.encoder.cuda(0)
self.decoder.cuda(1)
out = self.decoder(self.encoder(x))

这类方式的代价是实现复杂,调试也麻烦。除非模型确实大到单卡塞不下,否则一般不会优先走这条路。

组合使用

有些场景会把不同模块拆分,再叠加多卡并行,但这已经不是'提速小技巧'了,更像系统设计问题。

9. 分布式多节点训练适合更大的规模

当单机多卡还不够用,就会碰到多节点训练。每个节点上的每张 GPU 都有自己的模型副本,数据切分后各自训练,再通过梯度同步保持一致。DistributedDataParallel 是这类场景里最常见的选择。

def main_process_entrypoint(gpu_nb):
    dist.init_process_group("nccl", rank=gpu_nb, world_size=world)
    torch.cuda.set_device(gpu_nb)
    model = DistributedDataParallel(model, device_ids=[gpu_nb])
    # 训练逻辑...

if __name__ == '__main__':
    mp.spawn(main_process_entrypoint, nprocs=8)

Lightning 里配置会省很多事:

trainer = Trainer(gpus=8, accelerator='ddp')
trainer.fit(model)

这类方案的收益很大,但前提是你的数据管道、通信和训练逻辑都能跟上。否则卡在同步上,卡再多也不一定快。

结语

如果只挑几个最值得先做的,我会先看数据加载、batch size 和混合精度。这三项通常成本最低,也最容易看见效果。显存紧张就加梯度累积,多卡场景再考虑 DDP。优化训练速度没有统一答案,最后还是要回到模型规模、硬件条件和当前瓶颈上去判断。

目录

  1. 1. 先把数据加载做好
  2. 2. 把 num_workers 调起来
  3. 慢:主进程加载
  4. 快:启用 4 个 worker 进程
  5. 3. 批量大小尽量用满显存
  6. 4. 显存不够时,用梯度累积
  7. 5. 记录 loss 时别把计算图一起存了
  8. 错误:保留计算图副本
  9. 正确:仅存储数值
  10. 6. 单卡训练时,别让 CPU 和 GPU 来回折腾
  11. 7. 混合精度通常值得开
  12. 8. 多 GPU 时优先考虑分布式训练
  13. 分批次训练(DataParallel)
  14. 旧写法示意
  15. 模型分布训练(Model Parallelism)
  16. 组合使用
  17. 9. 分布式多节点训练适合更大的规模
  18. 结语
  • 免费图片AI生成工具免费生成了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 免费图片视频在线生成30秒,将你的创意变成现实开始设计
  • X/Twitter免费视频下载器免登陆无限额度免费视频解析下载了解详情
  • 100+免费在线小游戏爽一把
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Seedance 2.0 实测:AI 视频从“能看”走向“能用”
  • Open3D.Art 生成模型到拓竹打印的实用流程
  • Python 3.11 新特性:性能、异常与类型系统的变化
  • IntelliJ IDEA 2026.1 EAP:Java 26、Spring Boot 4 与 Gradle 9 适配
  • NWPU VHR-10 遥感目标检测与 YOLO 实践
  • 文心一言 4.5:中文能力实测与本地部署记录
  • 在 WSL2 上部署 OpenClaw 的实操记录
  • Vue 3 常用编程技巧整理
  • 在 Ubuntu 22.04 上部署 llama.cpp 和 llama-server
  • Pencil.dev 安装与实战:在 VS Code 里做设计
  • PaddleNLP 3.0:大模型训推一体与多硬件适配实践
  • Unreal Engine 集成 VRM4U 的实战方案
  • Kali Linux 2025.4 发布:Wayland 默认、桌面与工具链更新
  • 小米 9 改复古掌机:天马 G 前端实战
  • Linux 下安装 libwebkit2gtk-4.1-0 的方法与作用
  • CASIC MOTOR 14.8V 无刷减速电机拆解记录
  • 用 LLaMA-Factory WebUI 微调 Qwen2.5-VL
  • Win10 里关闭 Microsoft 365 Copilot 弹窗的几种办法
  • Seedance 2.0 双分支扩散 Transformer 解析
  • Java 8 基础知识整理:运算符、控制流与面向对象

相关免费在线工具

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online

  • Base64 字符串编码/解码

    将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online

  • Base64 文件转换器

    将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online