主要外推方法概述
长度外推(Length Generalization)研究的是如何在预训练时使用较短的序列长度,但在推理时能够泛化到更大的长度。该问题依然是 Transformer 架构亟待解决但尚未完全攻克的技术难点。优秀的长文本外推性能表现为:在泛化到超长序列时,相关指标不会出现大幅下降,模型表现依然稳健。
目前两年内研究长度泛化的经典思路主要包括以下几种:
1. ALiBi 直接外推
ALiBi(Attention with Linear Biases)主要是在计算 attention score 后添加一个不可学习的 bias。其核心公式为在注意力分数上减去距离矩阵与斜率系数的乘积。
假设有 8 个 heads,m 是一个预先定义好的值,可以选择不同的取值来调整每个 head 在注意力计算中的权重分配。具体来说,m 的取值可以是如下:1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128, 1/256。这些值表示在每个头的注意力计算中,不同头的权重分配逐渐变小,通常是为了增强模型的表达能力或控制注意力的焦点。
ALiBi 由于其使用线性偏差的设计,无法在单层注意力机制中有效捕捉远距离的依赖关系。与标准的自注意力机制不同,ALiBi 的设计通过引入线性偏差来在局部范围内调整注意力分配,从而更适应于捕捉局部信息。然而,它能够进行外推(即捕捉更长距离的信息),这是因为它通过多层的注意力机制逐步扩展感知的范围。换句话说,ALiBi 的远距离信息感知能力依赖于网络的深度,随着层数的增加,模型能够通过多个注意力层逐步捕获更远的依赖信息。但这种能力是有限的,因为它的感知范围随着层数的增加呈线性增长。
2. 位置内插法(Position Interpolation, PI)
将预测的长文本的位置编码乘上因子 Ltrain / Ltest,缩放到训练长度范围内。流程如下:
- 训练阶段:(1, 2, 3, 4, …, n)
- 测试阶段:(1, 2, 3, 4, …, n, …, 2n) -> (0.5, 1, …, n) [通过内插的方式来实现]
尽管位置内插(PI)方法有效避免了远距离位置越界的问题,但它也同时压缩了相邻 Token 之间的距离,这可能会严重影响模型的局部分辨率,导致困惑度(PPL)增大。不过,研究表明,经过常规文本微调之后,PI 方法依然能够取得较好的效果。从整体上来看,这种做法实际上是对位置编码中的 sin(m/base^{-2i/d}) 中的 m 进行了缩放处理。
def _compute_linear_scaling_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with linear scaling.
Credits to the Reddit user /u/kaiokendev
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies.
"""
config (rope_kwargs) > :
ValueError(
)
(rope_kwargs) > :
factor = rope_kwargs[]
config :
factor = config.rope_scaling[]
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
inv_freq /= factor
inv_freq, attention_factor


