GPU 显存分析
在大模型微调或预训练过程中,GPU 显存占用是决定能否成功运行的关键瓶颈。显存主要被以下四个部分占用:模型参数、参数梯度、优化器状态和中间激活结果。
对于一个 6B(60 亿)参数量的模型,若使用 FP32(单精度浮点)格式存储,其模型参数占用计算如下:
6 × 10^9 × 4 (Bytes) / 1024^3 ≈ 22 GB
将模型参数视为基准,反向传播时产生的参数梯度占用量通常与模型参数相同,即约 22GB。
优化器方面,主流采用 Adam Optimizer。其核心计算公式包含动量项 m 和二阶矩估计 v:
m_t = β1 * m_{t-1} + (1 - β1) * g_t
v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
由于需要保存 m 和 v 两个状态变量,且它们的规模与参数梯度相同,因此优化器状态通常需要两倍于参数梯度的显存容量,即约 44GB。
此外,在计算中得到的中间结果(激活值)需要保存在显存中,以便反向传播时计算梯度。对于每一个中间结果,其数据形状通常为 [Batch, SeqLen, Dim]。在深层网络中,这部分显存占用可能非常巨大,甚至超过参数本身。
Collective Operations
为了节省显存并提升计算效率,可以将模型或者数据分配到不同的显卡上。多卡之间通过集合通信操作(Collective Operations)进行数据交换,常见的操作包括:
Broadcast
广播操作将一张显卡(Root Rank)上的 N 元素缓冲区复制到所有其他显卡(Ranks)。这常用于初始化参数或分发配置信息。
AllReduce、Reduce、ReduceScatter
- AllReduce:对所有设备上的数据进行聚合(如求和、取最大值),并将结果写入所有设备的接收缓冲区。这是数据并行中同步梯度的核心操作。
- Reduce:执行与 AllReduce 相同的操作,但只将结果写入指定的 Root Rank 的接收缓冲区。
- ReduceScatter:执行 Reduce 操作后,将结果分散到所有设备中,每个设备获得一部分数据块。
AllGather
收集所有 Ranks 的数据,合并成一个大小为 k*N 的输出,并分发给所有 Ranks。这在模型并行或 ZeRO 同步参数时常用。
数据并行
数据并行(Data Parallelism, DP)是最基础的并行策略。它将训练数据分成若干份,装载到不同节点上进行计算,每个节点维护一份完整的模型副本。


