01 大模型训练总体架构
如何利用计算中心成千上百的 AI 加速芯片集群,训练参数量超过百亿的大规模模型?并行计算是一种行之有效的方法。除了分布式并行计算相关的技术之外,在训练大模型的过程中还会融合更多的技术,如新的算法模型架构和内存/计算优化技术等。
本文梳理在大模型训练中使用到的相关技术点,主要分为三个方面来回顾现阶段使用多 AI 加速芯片训练大模型的主流方法:
- 分布式并行加速:并行训练主要分为数据并行(Data Parallel)、模型并行(Model Parallel)、流水线并行(Pipeline Parallel)、张量并行(Tensor Parallel)四种并行方式,通过上述四种主要的分布式并行策略作为大模型训练并行的主要策略。
- 算法模型架构:大模型训练离不开 Transformer 网络模型结构的提出,后来到了万亿级稀疏场景中经常遇到专家混合模型(MoE),都是大模型离不开的新算法模型结构。
- 内存和计算优化:关于内存优化技术主要由激活(Activation)重计算、内存高效的优化器、模型压缩组成;而计算优化则集中体现在混合精度训练、算子融合、梯度累加等技术上。
02 大模型训练的目标公式
超大模型训练的总体目标就是提升总的训练速度,减少大模型的训练时间。训练一个大模型基本上从按下回车的那一刻开始要 1 到 2 个月,是非常耗时的。下面看一下在大模型训练中的总训练速度的公式:
$$ \text{Total Time} = \frac{\text{Total Compute}}{\text{Single GPU Speed} \times \text{GPU Count} \times \text{Scaling Efficiency}} $$
上面公式当中,单卡速度主要由单块 AI 加速芯片的运算速度、数据 IO 来决定;而加速芯片数量这个很清楚,数量越多增加训练速度;而多卡加速比则是由计算和通讯效率决定。
我们再把使用到的技术与这个公式关联在一起:
- 单卡速度:单卡速度既然是运算速度和数据 IO 的快慢来决定,那么就需要对单卡训练进行优化,于是主要的技术手段有精度训练、算子融合、梯度累加来加快单卡的训练性能。
- 加速芯片数量:理论上,AI 芯片数量越多,模型训练越快。但是,随着训练数据集规模的进一步增长,加速比的增长并不明显。如数据并行就会出现局限性,当训练资源扩大到一定规模时,由于通信瓶颈的存在,增加计算资源的边际效应并不明显,甚至增加资源也没办法进行加速。这时候需要通讯拓扑进行优化,例如通过 ring-all-reduce 的通讯方式来优化训练模式。
- 多卡加速比:多卡加速比既然由计算、通讯效率决定,那么就需要结合算法和集群中的网络拓扑一起优化,于是有了数据并行 DP、模型并行 MP、流水线并行 PP 相互结合的多维度混合并行策略,来增加多卡训练的效率。
总的来说呢,超大模型训练的目标就是优化上面的公式,提升总训练速度。核心思想是将数据和计算有关的图/算子切分到不同设备上,同时尽可能降低设备间通信所需的代价,合理使用多台设备的计算资源,实现高效的并发调度训练,最大化提升训练速度。
03 大模型训练的集群架构
这里的集群架构是为了解决机器学习模型的分布式训练问题。深度学习的大模型目前主要是在集群中才能训练出来,而集群的架构也需要根据分布式并行、深度学习、大模型训练的技术来进行合理安排。
在 2012 年左右 Spark 采取了简单直观的数据并行的方法解决模型并行训练的问题,但由于 Spark 的并行梯度下降方法是同步阻断式的,且模型参数需通过全局广播的形式发送到各节点,因此 Spark 的并行梯度下降是相对低效的。
2014 年李沐提出了分布式可扩展的 Parameter Server 架构,很好地解决了机器学习模型的分布式训练问题。Parameter Server 不仅被直接应用在各大公司的机器学习平台上,而且也被集成在 TensorFlow、PyTorch、MindSpore、PaddlePaddle 等主流的深度框架中,作为机器学习分布式训练最重要的解决方案之一。
目前最流行的模式有两种:
- 参数服务器模式(Parameter Server, PS)
- 集合通讯模式(Collective Communication, CC)
其中参数服务器主要是有一个或者多个中心节点,这些节点称为 PS 节点,用于聚合参数和管理模型参数。而集合通信则没有管理模型参数的中心节点,每个节点都是 Worker,每个 Worker 负责模型训练的同时,还需要掌握当前最新的全局梯度信息。
参数服务器模式
参数服务器架构 Parameter Server,PS 架构包括两个部分:
- 把计算资源分为两个部分:参数服务器节点和工作节点。
- 参数服务器节点用来存储参数。


