GPU 显存分析
在大模型微调或预训练过程中,GPU 显存占用是决定能否成功运行的关键瓶颈。显存主要被以下四个部分占用:模型参数、参数梯度、优化器状态和中间激活结果。
对于一个 6B(60 亿)参数量的模型,若使用 FP32(单精度浮点)格式存储,其模型参数占用计算如下:
6 × 10^9 × 4 (Bytes) / 1024^3 ≈ 22 GB
将模型参数视为基准,反向传播时产生的参数梯度占用量通常与模型参数相同,即约 22GB。
优化器方面,主流采用 Adam Optimizer。其核心计算公式包含动量项 m 和二阶矩估计 v:
m_t = β1 * m_{t-1} + (1 - β1) * g_t
v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
由于需要保存 m 和 v 两个状态变量,且它们的规模与参数梯度相同,因此优化器状态通常需要两倍于参数梯度的显存容量,即约 44GB。
此外,在计算中得到的中间结果(激活值)需要保存在显存中,以便反向传播时计算梯度。对于每一个中间结果,其数据形状通常为 [Batch, SeqLen, Dim]。在深层网络中,这部分显存占用可能非常巨大,甚至超过参数本身。
Collective Operations
为了节省显存并提升计算效率,可以将模型或者数据分配到不同的显卡上。多卡之间通过集合通信操作(Collective Operations)进行数据交换,常见的操作包括:
Broadcast
广播操作将一张显卡(Root Rank)上的 N 元素缓冲区复制到所有其他显卡(Ranks)。这常用于初始化参数或分发配置信息。
AllReduce、Reduce、ReduceScatter
- AllReduce:对所有设备上的数据进行聚合(如求和、取最大值),并将结果写入所有设备的接收缓冲区。这是数据并行中同步梯度的核心操作。
- Reduce:执行与 AllReduce 相同的操作,但只将结果写入指定的 Root Rank 的接收缓冲区。
- ReduceScatter:执行 Reduce 操作后,将结果分散到所有设备中,每个设备获得一部分数据块。
AllGather
收集所有 Ranks 的数据,合并成一个大小为 k*N 的输出,并分发给所有 Ranks。这在模型并行或 ZeRO 同步参数时常用。
数据并行
数据并行(Data Parallelism, DP)是最基础的并行策略。它将训练数据分成若干份,装载到不同节点上进行计算,每个节点维护一份完整的模型副本。
分布式数据并行流程
- 每个设备(Replica)都保存完整的模型参数。
- 每个设备处理一部分数据批次,独立进行前向传播和反向传播。
- 每个设备得到局部梯度后,通过 AllReduce 操作将所有设备的梯度聚合。
- 聚合后的梯度被同步回所有设备,每个设备根据更新规则更新本地参数。
- 在后向传播时,每计算完一层的梯度,就可以进行 Reduce 操作,提高并行性。
在分布式数据并行中,每个设备显存占用情况依然较高,因为每个设备仍需保存完整的模型参数、梯度和优化器参数。当模型参数量超过单卡显存时,此方法不可用。
模型并行
由于模型越来越大,单个设备保存模型参数、梯度和优化器越来越难。深度学习主要是矩阵计算,而矩阵计算可以分块计算,因此可以将模型参数拆成若干份,每份单独计算,以减少显存占用。
以矩阵乘法为例,假设 $Y = W \times X$,其中 $W$ 为 $A \times B$ 的矩阵。可以将 $W$ 沿行方向切分为 $n$ 个子矩阵 $W^{(1)}, W^{(2)}, ..., W^{(n)}$,分别存储在 $n$ 个设备上。
计算流程
- 将参数矩阵分成若干子矩阵,分发到不同设备中。
- 每个设备计算不同矩阵对应的输出部分。
- 最后通过 AllGather 等操作将结果收集起来。
模型并行后,显存占用显著降低,因为每个设备只需存储部分权重。但由于每个设备处理所有数据,中间激活结果(Activation)仍会保存在所有设备中,这限制了显存的节省比例。
ZeRO (Zero Redundancy Optimizer)
在分布式数据并行中,多个设备中参数相同,梯度相同,优化器状态相同,存在大量冗余。ZeRO 旨在消除这些冗余,进一步降低显存占用。
ZeRO-1
ZeRO-1 对优化器状态进行分片。在计算梯度后,通过 ReduceScatter 将梯度分片,每个 replica 只负责更新对应分片的参数,从而减少优化器状态的显存占用。
ZeRO-2
ZeRO-2 在 ZeRO-1 的基础上,进一步对梯度进行分片。在后向传播时,每计算一层梯度,就可以使用 ReduceScatter 进行同步。这样每个 replica 只需要保存部分梯度即可,无需保留完整梯度。
ZeRO-3
ZeRO-3 在 ZeRO-2 的基础上,将模型参数也进行分片。这是最激进的优化策略。
- 每个 replica 处理一部分输入。
- 前向传播时,当需要别的层参数,使用 AllGather 获取。
- 反向传播时,当需要别的层参数时,使用 AllGather 获取,同时计算出每一层梯度时,使用 ReduceScatter 分发到对应 replica。
- 每个 replica 用于部分优化器参数和梯度,进行对应参数更新。
不同 ZeRO 级别对应的显存占用呈阶梯式下降,ZeRO-3 理论上可将优化器和参数显存占用降至原来的 1/N(N 为卡数),但通信开销也会相应增加。
流水线并行
将模型一层一层分开,不同层放入不同 GPU 进行计算。与模型并行不同的是,模型并行保留从头到尾每一层的部分参数,输入可以计算出结果;流水线并行需要等前一层计算完毕才能进行计算。
显存分析与气泡
流水线并行显存占用较低,因为每块 GPU 只负责部分层。然而,它引入了**气泡(Bubble)**问题,即某些 GPU 在前向传播完成后需等待后续层完成反向传播才能继续工作,导致算力闲置。可以通过微批次(Micro-batching)和流水线调度算法来缓解这一问题。
混合精度
FP16(半精度)相较于 FP32 计算更快,同时占用更少的显存。但同时 FP16 表示的范围小,可能产生溢出错误。
特别的,在权重更新时 gradient * lr 可能导致下溢出(Underflow),使得梯度变为 0。
混合精度训练的思路是在优化器中保留一份 FP32 格式的参数副本(Master Weights),而模型权重、梯度等数据在训练中都是用 FP16 来存储。优化器中参数更新在 FP32 格式下保证精度,之后转换为 FP16 格式。
为了防止 FP16 梯度下溢出,通常需要使用 Loss Scaling 技术,即在反向传播前放大 Loss 值,反向传播后再缩小梯度。
Checkpointing (重计算)
由于模型反向传播需要中间结果计算梯度,大量中间结果占用大量显存。Checkpointing 思路是保存部分隐藏层的结果(作为检查点),其余的中间结果直接释放。
当反向传播需要计算梯度时,从检查点开始重新前向传播计算中间结果,得到梯度后再次释放。这种方法以计算时间为代价换取显存空间的节省,通常可节省约 50% 的激活显存。
总结与选型建议
在实际的大模型训练中,通常会组合使用上述策略:
- 数据并行 + ZeRO:适用于大多数场景,尤其是参数规模在数十亿至数百亿级别时。
- 流水线并行:适用于超大规模模型(千亿级以上),需配合 ZeRO 使用。
- 混合精度:几乎必选,能显著提升速度并降低显存。
- Checkpointing:当显存极度紧张时使用,需权衡训练时间成本。
选择合适的并行策略需要综合考虑硬件资源(GPU 数量、显存大小)、网络带宽(影响通信开销)以及模型架构特性。