LLM 常见归一化方法解析
在大语言模型(LLM)的架构中,归一化(Normalization)技术对于模型的训练稳定性、收敛速度以及最终性能起着至关重要的作用。Transformer 及其变体广泛采用了不同的归一化策略。本文将详细解析 LayerNorm、RMSNorm 和 DeepNorm 的原理、代码实现及适用场景,并对比 PreLN 与 PostLN 在 Transformer 中的位置差异。
1. Layer Norm 的计算公式与实现
Layer Normalization (LayerNorm) 是 Transformer 架构中最基础的归一化方法之一。它对单个样本的特征维度进行归一化,使其均值为 0,方差为 1,然后通过可学习的缩放参数 γ 和平移参数 β 进行调整。
数学原理
给定输入特征 $x$,LayerNorm 的计算过程如下:
$$\mu = \frac{1}{N} \sum_{i=1}^{N} x_i$$
$$\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$
$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
$$y_i = \gamma \cdot \hat{x}_i + \beta$$
其中:
- $\mu$ 为 $x$ 的均值。
- $\sigma$ 为 $x$ 的标准差。
- $\gamma$ 和 $\beta$ 是可训练的模型参数,分别控制新分布的方差和均值。
- $\epsilon$ 是一个极小值(如 $1e-6$),添加到方差上以避免分母为 0。
PyTorch 代码实现
import torch
import torch.nn as nn
def layer_norm(feature):
size = feature.shape
alpha = nn.Parameter(torch.ones(size[-1]))
beta = nn.Parameter(torch.ones(size[-1]))
input_dtype = feature.dtype
feature = feature.to(torch.float32)
mean = feature.mean(-1, keepdim=True)
std = feature.std(-1, keepdim=True, unbiased=False)
normalized = (feature - mean) / (std + 1e-6)
output = alpha * normalized + beta
return output.to(input_dtype)
2. RMS Norm 的计算公式与实现
RMSNorm (Root Mean Square Layer Normalization) 是 LayerNorm 的一种简化变体,主要应用于大语言模型(如 LLaMA 系列)。它去除了均值减去的步骤,仅保留缩放部分。
数学原理
RMSNorm 的核心思想是假设数据的均值已经接近 0,或者减去均值带来的收益不如直接对 RMS 进行归一化显著。其公式如下:
$$\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2}$$
$$y_i = \frac{x_i}{\text{RMS}(x)} \cdot \gamma$$
相比于 LayerNorm,RMSNorm 只保留了缩放部分,去除了平移部分(即不需要减去均值,也不需要 $\beta$ 参数)。
PyTorch 代码实现
def rms_norm(feature):
size = feature.shape
weight = nn.Parameter(torch.ones(size[-1]))
input_dtype = feature.dtype
feature = feature.to(torch.float32)
variance = feature.pow(2).mean(-1, keepdim=True)
feature = feature * torch.rsqrt(variance + 1e-6)
return (weight * feature).to(input_dtype)
3. RMS Norm 相比于 Layer Norm 的特点
RMSNorm 相比一般的 LayerNorm 具有以下特点:
- 计算效率更高:减少了计算均值和平移系数的部分,降低了计算开销。
- 训练速度更快:由于参数量减少且计算图更简单,反向传播速度通常更快。
- 效果相当或提升:在大多数大模型实验中,RMSNorm 的效果与 LayerNorm 基本相当,甚至在某些场景下有所提升,特别是在深层网络中。
- 参数更少:省去了 $\beta$ 参数,减少了模型总参数量。
4. Deep Norm 思路与优点
DeepNorm 是由微软提出的一种针对深层 Transformer 结构的归一化改进方法。它不改变归一化层本身的位置,而是对残差连接(Residual Connection)的权重进行修正。
核心思路
传统的 Transformer 使用固定的残差连接 $x + f(x)$。DeepNorm 引入了缩放因子,将残差路径和主路径的权重调整为:
$$\text{Output} = \alpha \cdot x + \beta \cdot f(x)$$
其中 $\alpha$ 和 $\beta$ 是可学习的标量参数,用于平衡残差分支和前馈分支的贡献。
优点
- 缓解梯度爆炸:通过限制残差连接的更新幅度,DeepNorm 可以缓解模型参数爆炸式更新的问题。
- 参数范围可控:把模型参数更新限制在一个常数域范围内,使得模型训练过程更加稳定。
- 支持更深网络:实验表明,结合 DeepNorm 的模型规模可以达到 1000 层以上,而传统结构难以训练如此深的网络。
- 兼顾稳定与性能:DeepNorm 兼具 PreLN 的训练稳定性和 PostLN 的效果性能。
5. LN 在 LLMs 中的不同位置区别
在 Transformer 结构中,LayerNorm 的位置主要有两种选择:Pre-LN 和 Post-LN。这直接影响梯度的流动和训练的稳定性。
Post-LN (Post-Normalization)
- 定义:LayerNorm 放置在残差连接之后。即先计算 $x + f(x)$,再进行归一化。
- 原始结构:Transformer 原始论文采用此结构。
- 缺点:在 LLM 训练过程中发现,Post-LN 的输出层附近的梯度过大会造成训练的不稳定性,尤其是在深层网络中,容易导致梯度消失或爆炸。
- 现状:LLM 很少单独使用 Post-LN。例如 GLM-130B 中采用 Post-LN 与 Pre-LN 结合的方式。
Pre-LN (Pre-Normalization)
- 定义:LayerNorm 放置在残差连接之前。即先对输入进行归一化,再送入子层,最后加残差。
- 优势:Pre-LN 在每层的梯度范数近似相等,有利于提升训练稳定性。相比 Post-LN,使用 Pre-LN 的深层 Transformer 的训练更稳定。
- 劣势:早期研究表明,单纯使用 Pre-LN 可能会在一定程度上损害最终的性能表现(尽管这种差距在现代优化器下已缩小)。
- 现状:为了提升训练稳定性,许多现代大模型(如 BERT 后续版本、GPT-NeoX 等)都采用了 Pre-LN 结构。
结构对比总结
| 特性 | Post-LN | Pre-LN |
|---|
| 位置 | 残差连接后 | 残差连接前 |
| 梯度稳定性 | 较差,深层易震荡 | 较好,梯度范数均衡 |
| 收敛速度 | 较慢 | 较快 |
| 最终性能 | 理论上限高 | 略低但更稳定 |
| 适用场景 | 浅层网络 | 深层大模型 |
6. 实践建议与总结
在选择归一化方案时,建议遵循以下原则:
- 默认选择 RMSNorm:对于新的 LLM 项目,RMSNorm 通常是首选,因为它在性能和效率之间取得了最佳平衡。
- 优先使用 Pre-LN:除非有特定理由需要 Post-LN,否则 Pre-LN 能提供更稳健的训练体验,减少调参成本。
- 关注显存占用:虽然 RMSNorm 计算量小,但在某些框架下,LayerNorm 的实现可能经过高度优化,需根据实际硬件环境测试。
- DeepNorm 的适用性:如果计划训练超深层网络(超过 100 层),考虑引入 DeepNorm 机制来增强稳定性。
综上所述,理解这些归一化方法的底层逻辑有助于开发者更好地设计模型架构,优化训练流程,并在资源受限的情况下做出合理的权衡。