Pre Norm 和 Post Norm 各自的优缺点?
"Pre Norm 和 Post Norm 各自的优缺点?" 学妹这么回答
原创 看图学 2024年07月08日 07:55 山西
题目:
Pre Norm 和 Post Norm 各自的优缺点?
答案
这个问题其实还蛮难回答的,因为目前并没有特别好的理论来解释清楚。
我们先按时间顺序来梳理一下关于 Pre-Norm 和 Post-Norm 的研究。
Pre Norm 和 Post Norm 的区别 Layer Norm 和 Residual connections 组合方式的不同。
2017 Attention is All your Need
在原始的 Transformers 论文中,使用的是 Post Norm,如下所示。
Post Norm 用公式可以表示为:
每一层的输入先与 Attention 相加,然后才计算 Layer Norm。早期的很多模型都用的是 Post Norm,比如著名的 Bert。
Post Norm 之所以这么设计,是把 Normalization 放在一个模块的最后,这样下一个模块接收到的总是归一化后的结果。这比较符合 Normalization 的初衷,就是为了降低梯度的方差。但是层层堆叠起来,从上图可以看出,深度学习的基建 ResNet 的结构其实被破坏了。
这就导致大家在训练 Transformers 的时候,发现并不是那么容易的训练,learning rate warm up, 初始化等各种招都用上,训练的时候还得小心翼翼。
下一篇文章讨论了 ResNet 中 Identity mappings 的重要性,并且以此为基础提出了 Pre Norm。
2019 Transformers without Tears: Improving the Normalization of Self-Attention
假设输入 X、Attention(X)、FNN(X) 的均值为0,方差都为1,且相互独立。事实上可能没有那么理想,因为权重矩阵的分布在学习过程中并不一定能保持理想的分布,这里为了说明问题对建模进行了简化。
对于两个均值为0,方差为1且相互独立的分布,LayerNorm 可以简化为
,所以 Post Norm 可以简化表示为:
公式后面比较复杂的计算就用 来代替。从上面式子可以看出,输入X每经过一层,输出就变为 。
那么多层的结果如下:
原始的 ResNet 求导后会变成 ,这里的 1 起到了很好的防止梯度消失的作用。但是 Post Norm 之后求导变成了 。
这个式子的第一项会随着层数递减,而且是指数的变小。如果层数较低还好,如果像是现在的大模型一样堆叠32层,那
几乎和0没什么区别了,也就丧失了 ResNet 的意义。
Post Norm 就像在何凯明修的高速公路上,每一层都加了一个收费站。
没了 ResNet 的架构,就导致 Transforemrs 在训练的时候,需要小心翼翼。都要加一个 learning rate warm up 的过程,先让模型在小学习率上适应一段时间,然后再正常训练。warm up 的过程虽然在 Transformers 的论文里就提了一嘴,但是真正训练的时候会发现真的很重要。
然后这篇文章将 Layer Norm 的位置改了一下,变成了 Pre Norm
如图所示:
从实验结果发现,Pre Norm 基本上可以不用 warm up。
2020 On Layer Normalization in the Transformer Architecture
这一篇论文更是对上面的结论提出了理论上的证明。
作者先通过实验证明了对于 Post Norm 来说,Lr warming up 是必须的
上图可以看出
没有 warm up 效果很差,没有 warm up BLEU 只能到8,加上了可以到 34.
即使加了 warm up,对warm up 参数的设置也很敏感,比如 warm up step 在500 步时,不同的lrmax 的 BLEU 一个31,一个还不到3。
这就导致了大量的调参工作,还有warm up 阶段也拖延了训练速度。
然后论文的附录的第F 部分证明了,随着层数的增大,Post-Norm 期望的梯度会随着层数的变大而变大,而 Pre Norm 则几乎保持不变。如下图。
所以当层数深的时候,大的梯度加上大的学习率训练很容易就崩盘了,也是刚开始得 warm up 一下。
所以总结来看,Pre Norm 的训练更快,且更加稳定,所以之后的模型架构大多都是 Pre Norm 了,比如 GPT,MPT,Falcon,llama 等。
后面论文的结论和前面大差不大,所以就写的简略点了。
2020 Understanding the difficulty of training transformers
但是 Pre Norm 也并不是都是好的,这篇论文指出,Pre Norm 有潜在的(表示塌陷) representation collapse 问题,具体来说就是靠近输出位置的层会变得非常相似,从而对模型的贡献会变小。
所以在2023年有一篇《ResiDual: Transformer with Dual Residual Connections》试图融合 Pre Norm 和 Post Norm 的优点。
这也就暗示着 Post Norm 虽然不好训练,但是潜力似乎比 Pre Norm 要好。
然后这篇论文中提到的在 LayerNorm 的时候,调整 x 和 f(x) 的比重,其思路被 DeepNorm 借鉴。只不过这里是可学习的权重,而 DeepNorm 则是超参数。
2021 Catformer: Designing stable transformers via sensitivity analysis
Percy Liang 团队的,算是一个小的汇总。给出了 不同 skip connect 的梯度方差和敏感性的指标。
从图中也可以看出,相比于 Post Norm,Pre Norm 对层数更加不敏感。
汇总
通过上面的几篇论文,汇总如下:
Pre Norm 在训练稳定和收敛性方面有明显的优势,所以大模型时代基本都无脑使用 Pre Norm 了。但是其可能有潜在的(表示塌陷) representation collapse 问题,也就是上限可能不如 Post Norm。
Post Norm 则对训练不稳定,梯度容易爆炸,学习率敏感,初始化权重敏感,收敛困难。好处是有潜在效果上的优势,到底有没有呢?也不好说,因为现在大模型训练太费钱了,Post Norm 在效果上带来的提升很可能不如多扔点数据让 Pre Norm 更快的训练出来。
— END —
推荐阅读: