Stacking Your Transformer:通过堆叠加快 LLM 预训练

Stacking Your Transformer:通过堆叠加快 LLM 预训练

Stacking Your Transformer:通过堆叠加快 LLM 预训练

原创 AI闲谈  2024年07月15日 20:00 北京

一、背景

我们之前的文章中介绍了几种模型增长的方案,然而针对 LLM 场景却又缺乏足够的数据支撑以及最佳实践。比如说不知道在 LLM 场景中这些方案的差异有多大,是否和模型规模、数据量、以及训练 FLOPs 有关?本文中我们介绍来自香港大学、清华大学和香港科技大学的 Stacking Your Transformer,作者做了大量的实验来尝试回答上述问题。

对应的论文为:[2405.15319] Stacking Your Transformers: A Closer Look at Model Growth for Efficient LLM Pre-Training

对应的 Blog:https://llm-stacking.github.io/

对应的代码库:GitHub - tongxuluo/prts

PS:需要说明的是,论文中的实验很多,这里也只列了一部分。当然,也有待完善的地方,比如论文中对比了 4 个抽象的增长算子的影响,但这并不代表就完全等效于之前的工作,如果能有一些具体的对比会更清晰;此外,论文中大小模型训练数据量也和之前的方法很不同,比如 [2309.03852] FLM-101B: An Open LLM and How to Train It with $100K Budget 中先在 16B 模型训练 245.37B Token,然后在 51B 模型训练 39.64B Token,最后在 101B 模型训练 26.54B Token;而本文中的方案基本是在小规模模型训练很少的 Token,比如 10B 规模。

二、摘要

论文中,作者进一步探索了模型扩展在 LLM 预训练中的可行性。首先,作者确定了 3 个关键障碍:

O1:缺乏全面的评估。

O2:未经测试的扩展可行性。

O3:缺乏经验指南。

为了解决 O1 问题,作者将现有的方案总结为 4 个原子增长算子,并在标准的 LLM 预训练中对其进行了系统的评估。结果表明,与 Baseline 相比,深度堆叠算子 Gstack 表现出了显著的加速,从而提升了在 8 个 NLP 基准的整体性能。基于此,作者深入的研究了 Gstack,以便解决 O2 和 O3。对于 O2,作者实验表明,Gstack 是可扩展的,并且始终表现良好,例如,与直接使用 300B Token 训练的 7B 模型相比,Gstack 只使用 194B Token 就可以达到相同损失,加速 54.6%。对于 O3,作者通过建模确定 Gstack 的增长规划(Growth Timing)和增长因子(Growth Factor),使其在常见的 LLM 预训练中更实用。

三、方法

3.1 O1:4 种增长算子

如下图 Figure 2 所示,作者将之前方案中的生长方案总结为 4 个生长算子:

(a):Gdirect 通过拷贝、切分和堆叠的方式实现,分为宽度方向和深度方向。

(b):Glearn 通过学习映射函数的方式实现。

(c):Gzero 通过扩充 0 值的方式实现。

(d):Grandom 通过随机初始化然后增加 Mask 的方式实现。(PS:是否也可以不使用随机初始化,比如拷贝后添加 Mask?)

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

为了对比不同方法的效果,作者制定了一个统一的方案:总共两个训练阶段,增长前的小模型训练,增长后的大模型训练。其小模型训练的 Token 数 d,大模型训练的 Token 数 D 以及模型增长因子 g(对应非 Embedding 参数) 作为超参。

如下图 Figure 3 所示,作者先用 d=10B Token 预训练了一个 400M 参数的模型,然后扩展为 1.1B 参数,对应增长因子 g=4,并继续使用额外的 D=97.5B Token 进行训练,以此来验证不同方案的效果。可以看出深度堆叠 Gdirect 获得了最好的效果,与直接训练 100B Token大模型相比可以加速 49.1%,同时宽度扩展基本都是负优化。

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

PS:实际上采用不同的 d 来训练小模型得到的结果很不一样,比如作者实际分别测试了使用 d=10B 和 d=50B Token 来训练小模型的结果,可以发现在 d=50B 的时候 Gzero 在深度上获得了更好的结果, Grandom 在宽度上获得更好的效果。

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

3.2 O2:Gstack 扩展可行性

从上面可以看出,在作者 400M 的实验中,模型深度堆叠的方案能获得不错的收益,因此作者也聚焦在模型深度堆叠场景。如下所示,作者将这种方式称为 Gstack:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

扩展模型规模:如下图 Figure 4 和 Figure 5 所示,作者在 3B 模型和 7B 模型上进行了验证,其中 g=4,d=10B,具体来说,小模型的层数分别是 3B 和 7B 模型的 1/4,小模型训练 10B Token:

Figure 4:3B 模型从头开始训练

训练 180B Token 达到 loss 与 Gstack 花费 48.6% FLOPs 的 loss 相当。

训练 240B Token 达到 loss 与 Gstack 花费 54.5% FLOPs 的 loss 相当。

Figure 5:7B 从头开始训练,160B,220B 和 280B Token 对应 Gstack 的 FLOPs 为 40.8%, 55.3% 和 53.8%。

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

从上可以看出,当对比 1B、3B 和 7B 模型时,Gstack 带来的优势并没有随着模型尺寸的增加而减少,这意味着,即使更大的模型中也可以利用 Gstack 来加速。

扩展数据规模:如下图 Figure 6 所示,作者进一步探索了不断扩展数据规模的时候 Gstack 是否还有优势。从 Figure 6a 可以看出,对于 410M 的模型,训练远超缩放法则确定的 Token 数(8B),模型的 Loss 一直在下降,并且 Gstack 一直能获得更低的 Loss。如下图 Figure 6b 所示,作者进一步预估当训练 Token 数达到 8T 时(1000倍),Gstack 对应的 Loss 依然更低,表明扩大数据规模 Gstack 依然会有加速的效果。

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

Scaling Laws:如下图 Figure 7 所示,作者根据 410M、1.1B、3B、7B 模型的相关实验拟合了缩放法则曲线。可以看出,Gstack 在基于此预估出来的 13B 和 70B 模型模型上依然能够获得加速:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

3.3 O3:Growth Timing d 和 Growth Factor g

在上述的实验中作者直接采用了 d=10B 和 g=4,那么这个是否是最佳组合呢,在小模型上训练更多数据是否能带来更高的加速比?从 1 B 模型直接扩展到 30B 模型是否可行?为了回答这些问题,作者通过建模来确定 d 和 g 的影响。

如下图 Figure 8 所示,作者在 410M、1.1B 和 3B 模型上探索了 Growth Timing d 的影响。可以看出,对于给定的 FLOP 预算,可以确定一个最优的 Growth Timing d,比如说,对于 410M 模型,最优的 d 为 5-10B Token,对于 1.1B 模型,最优的 d 为 10-20B Token:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

如下图 Figure 9 所示,作者进一步拟合了对于给定大模型预训练 Token 数 C 和参数量 N 的情况下来预测 d 的曲线:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

如下图 Figure 10 所示,作者在 1.1B(24 层) 和 3B(32 层) 模型上探索了 Growth Factor g 的影响。可以看出,对于 1.1 B 的模型,最优的 g 为 2-3 左右,对于 3B 模型,最优的 g 为 4-5 左右。对于 3B 模型来说,即使 g=16 时(对应 small model 为 2 层) Gstack 依然能有加速。

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

四、附录

4.1 如何堆叠?

在 [2011.13635] Progressively Stacking 2.0: A Multi-stage Layerwise Training Method for BERT Training Speedup 中,作者对 StackingBert-v1 进行了扩展。具体来说,将一个 N 层 Encoder 的模型分 K+1 次训练,第一次训练一个 N/k 层的 Bert 模型,然后每次扩展 N/k 层并且进行训练。其中绿色为冻结的层,红色为训练的层。也就是每次扩展后只训练新扩展的层,全部扩展完之后再解冻所有层继续训练:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

本文作者采用了稍有不同的方案,具体来说,不是逐层堆叠 N/k 个 Layer,而是分两次堆叠。比如对于 6 层到 24 层的模型,第一次从 6 层到 12 层,第二次直接从 12 层到 24 层。作者也对两种方案进行了消融实验,如下图 Figure 33 所示,可以看出本文 Gstack 的方案会更优一些:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

此外,如下图 Table 5 所示,作者也探索了不同的堆叠方式,比如只堆叠中间层,或只堆叠首/尾层,最终发现还是全部堆叠的方案最优:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

4.2 为什么没保证 FPI?

Function Preserving Initialization(FPI):目标是给定一个源模型,用它初始化目标模型,能保证给定相同的输入,目标模型和源模型有相同的输出。

在之前的很多工作中都在尝试保证 FPI,那是否一定要满足这个要求呢?针对这个问题作者也做了一些实验,具体来说,在 Gdirect 中通过添加噪声(加 20% 噪声),不加噪声,以及从头训练的方式进行对比,如下图 Figure 39 可以看出,初始阶段加噪的效果确实更差,但是随着 FLOPs 的增加,加噪的方式反而更好:

www.zeeklog.com  - Stacking Your Transformer:通过堆叠加快 LLM 预训练

五、参考链接

https://arxiv.org/abs/2405.15319

https://llm-stacking.github.io/

https://github.com/tongxuluo/prts

https://arxiv.org/abs/2309.03852

https://arxiv.org/abs/2011.13635

LLM104

训练38

优化34

模型56

数据19

LLM · 目录

上一篇混合模型:HybridLLM、RouterLLM 等优化 LLM 推理成本的新思路下一篇万字综述:全面梳理 FP8 训练和推理技术

Read more

UML学习总结(1)——UML学习入门

UML学习总结(1)——UML学习入门

随着亲手接触的项目越来越多,项目的复杂度越来越大,项目的理解程度也变的很难,尤其是在接收一个别人已经做好的项目时,你迫切先想到的就是“有没有文档啊”,当然是各种文档,概要设计文档,详细设计文档,数据库设计文档,第三方接口等等各种,但往往得到的答案就是“这个现在没有文档啊”,而且刚好作为经理考研你是否看懂熟悉代码的依据-让你自己写个文档,流程图等等。 下面还是先说说UML里面的图吧等等 。UML总共有用例图、类图、包图、对象图、协作图和序列图、活动图、构件图和部署图。关系主要有依赖关系(Dependency)、关联关系(Associate)(又分为组合和聚合)、泛华关系(Generalization)、实现关系(Realization)。 关联关系   聚合是部分与整体的关系(has a),体现在类成员变量。   组成则是一个比聚合更强形式的关联,在组合中,成员对象的生命周期取决于聚合的生命周期。   依赖体现在方法变量,返回值,局部变量等。    聚合和组成是结构上的关系,而依赖关系则强调的是语义上的关系 1、用例图 意义:有参与者(Actor)

By Ne0inhk
UML学习总结(2)——StartUML 各种类图的例子

UML学习总结(2)——StartUML 各种类图的例子

1.UML分为: 1)静态建模:系统基础和系统固定框架结构,这些图形往往是“静态”的。 * 类图(Class Diagram):常用来分析业务概念 * 用例图(Use Case Diagram):常用 * 对象图(Object Diagram):不常用 * 构件图(Component Diagram):偶尔用 * 部署图(Deployment Diagram):偶尔用 * 包图(Package Diagram):不常用 2)动态建模:描述的是某种行为,是“动态”的。 * 活动图(Activity Diagram):偶尔用 * 状态机图(State Machine Diagram):同上 * 时序图(Sequence

By Ne0inhk
UML学习总结(3)——StarUML指导手册

UML学习总结(3)——StarUML指导手册

StarUML使用说明-指导手册 原著:Stephen Wong            翻译:火猴 StarUML是一种生成类图和其他类型的统一建模语言(UML)图表的工具。这是一个用Java语言描述的创建类图的简明手册。 StarUML(简称SU),是一种创建UML类图,并能够自动生成Java的“stub code” 的工具。SU也可以做JAVA逆向工程,以产生相应的UML图表。 在本教程中,我们将使用SU设计一个pizza饼。执行下列步骤,可以创建如下面所示的UML图。SU可以生成反映类结构的代码,而不是任何对象的具体行动。因此,在使用SU创建图表后,你会为此stub code添加剩余的功能性代码,填写每种方法本来应该做的事。 2.安装 首先,我们必须先安装将要使用的软件。StarUML ,是一个开放源码软件, 遵循许可 ,并免费提供下载。 3.启动 安装以后就可以启动该程序。 4.添加新工程 然后,一个名叫:New Project By Approach的对话框会弹出。选择“

By Ne0inhk
一个java程序员的年终总结

一个java程序员的年终总结

第一. Java程序员需要不断的学习; 貌似这一点适应的行业最广,但是我可以很肯定的说:当你从事web开发一年后,重新找工作时,才会真实的感受到这句话。 工作第一年,往往是什么都充满新鲜感,什么都学习,冲劲十足的一年;WEB行业知识更新特别快,今天一个框架的新版本,明天又是另一个新框架,有时往往根据项目的需要来不断学习新东西;所有,很多时候感觉,自己用过的东西真多呀!但是真正深入研究的东西却不多。 面试,是跳槽后第一个需要面对的问题;而且不同公司面试的着重点不同;但是却有一个共同点:Java基础是必考的。工作第一年,可能问你String对象创建的理解,常用的框架是什么等等;工作第二年,就问你Java内存分配机制是什么,类是如何加载的等等;第三年,就问你常用的设计模式是什么,你在工作中充当什么角色,怎么独立完成一个模块等等; 可以看出——这是一个典型的程序员的成长过程: 使用Java—->深入理解Java积累经验——>独立设计分析能力——>独当一面的多面手! 因此,必须学习: 1. Java基础的深入理解; 不多作解释,推荐书目《

By Ne0inhk