Transformer 位置编码详解:绝对、相对与旋转位置编码
Transformer 模型的位置编码机制是区分序列顺序的关键。本文详细解析了三种主流编码方法:绝对位置编码利用正弦余弦函数为每个位置生成唯一向量,简单但泛化性弱;相对位置编码关注元素间距离,适合长距离依赖;旋转位置编码(RoPE)通过向量旋转嵌入位置信息,支持长序列外推。文章提供了 Python 和 PyTorch 代码实现及对比分析,指出绝对编码适用于短文本,相对编码适合翻译任务,而 RoPE 已成为大语言模型的首选方案。

Transformer 模型的位置编码机制是区分序列顺序的关键。本文详细解析了三种主流编码方法:绝对位置编码利用正弦余弦函数为每个位置生成唯一向量,简单但泛化性弱;相对位置编码关注元素间距离,适合长距离依赖;旋转位置编码(RoPE)通过向量旋转嵌入位置信息,支持长序列外推。文章提供了 Python 和 PyTorch 代码实现及对比分析,指出绝对编码适用于短文本,相对编码适合翻译任务,而 RoPE 已成为大语言模型的首选方案。

Transformer 模型自 2017 年提出以来,凭借其在序列到序列任务中的优异表现,迅速成为自然语言处理(NLP)领域的主流模型。与传统的循环神经网络(RNN)不同,Transformer 模型完全基于自注意力机制,因此在处理长距离依赖关系方面有显著优势。然而,由于 Transformer 模型缺乏内置的序列顺序信息,必须通过位置编码(Positional Encoding)显式引入位置信息,以便模型能够区分序列中的不同位置。
位置编码是 Transformer 模型中一个至关重要的部分,直接影响到模型对序列信息的处理能力。本文将系统地介绍 Transformer 模型中的三种主要位置编码方法:绝对位置编码、相对位置编码和旋转位置编码。通过对这些方法的详细剖析,并结合具体代码和案例,深入探讨它们在实际应用中的表现和适用场景。
绝对位置编码(Absolute Positional Encoding)是最常见的一种位置编码方法,其思想是在每个输入序列的元素上添加一个位置向量,以表示该元素在序列中的具体位置。这个位置向量通常通过固定的函数生成,与输入数据无关。通常使用的是正弦和余弦函数,这样生成的编码具有很强的周期性,能够捕捉序列中的相对位置信息。
具体来说,对于序列中的第 pos 个位置,绝对位置编码向量的第 i 个维度的值定义如下:
$$ PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) $$ $$ PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}}) $$
其中,$pos$ 表示位置,$i$ 表示维度,$d_{model}$ 表示模型的隐藏层维度。
接下来,我们将展示如何在代码中实现绝对位置编码,并以'我爱你,中国。'为例,展示位置编码后的向量表示。
import numpy as np
import matplotlib.pyplot as plt
def get_absolute_positional_encoding(seq_len, d_model):
position = np.arange(seq_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe = np.zeros((seq_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return pe
# 假设句子长度为 8,d_model 为 32
sentence = "我爱你,中国。"
seq_len = len(sentence)
d_model = 32
absolute_positional_encoding = get_absolute_positional_encoding(seq_len, d_model)
# 展示绝对位置编码的效果
plt.figure(figsize=(12, 8))
plt.imshow(absolute_positional_encoding, cmap='viridis')
plt.colorbar()
plt.title("Absolute Positional Encoding")
plt.xlabel("d_model dimensions")
plt.ylabel("Position in Sentence")
plt.show()
在上面的代码中,我们为长度为 8 的句子生成了一个绝对位置编码矩阵。该矩阵的维度为(8, 32),每一行表示句子中一个位置的编码。通过热图可以看到,不同位置的编码在特定维度上具有不同的模式,这些模式帮助 Transformer 区分序列中不同位置的元素。
具体到'我爱你,中国。'这句话,每个字符都有一个 32 维的编码向量,这个向量的数值是基于该字符的位置计算出来的。这样,Transformer 模型在处理这个句子时,就可以感知到每个字符在句子中的位置。
优势: 简单且具有良好的可解释性。它能够有效地为序列中的每个位置分配独特的编码,从而帮助模型捕捉序列的顺序信息。
局限性: 尤其是在处理变长序列或长距离依赖时,绝对位置编码可能无法充分表达复杂的位置信息。此外,它难以泛化到训练时未见过的序列长度。
相对位置编码最早在 Transformer-XL 和 T5 等模型中引入,以解决绝对位置编码在捕捉长距离依赖关系时的不足。
与绝对位置编码不同,相对位置编码(Relative Positional Encoding)并不直接为每个位置分配一个唯一的编码,而是关注序列中各元素之间的相对位置。相对位置编码的核心思想是通过计算序列中元素之间的距离,来表示它们之间的相对关系。这种方法尤其适合处理需要捕捉长距离依赖关系的任务,因为它能够更加灵活地表示序列中的结构信息。
相对位置编码可以通过多种方式实现,其中最常用的方法之一是将位置差值与注意力权重相结合,即在计算自注意力时,不仅考虑内容,还考虑位置差异。这样,模型能够根据元素之间的距离调整它们之间的交互强度。
在标准的 Self-Attention 中,Query (Q) 和 Key (K) 的点积计算相似度。引入相对位置后,Score 变为:
$$ Score(q_i, k_j) = q_i^T k_j + b_{r(i,j)} $$
其中 $r(i,j)$ 表示位置 $i$ 和 $j$ 之间的相对距离,$b$ 是可学习的偏置项。
下面是一个简单的相对位置编码的实现示例,仍然以'我爱你,中国。'为例。
import torch
import torch.nn.functional as F
class RelativePositionalEncoding(torch.nn.Module):
def __init__(self, d_model, max_len=5000):
super(RelativePositionalEncoding, self).__init__()
self.d_model = d_model
self.max_len = max_len
# 生成相对位置编码
self.relative_positions_matrix = self.generate_relative_positions_matrix(max_len)
self.embeddings_table = self.create_embeddings_table(max_len, d_model)
def generate_relative_positions_matrix(self, length):
range_vec = torch.arange(length)
distance_mat = range_vec[None, :] - range_vec[:, None]
return distance_mat
def create_embeddings_table(self, max_len, d_model):
table = torch.zeros(max_len, max_len, d_model)
for pos in range(-max_len+1, max_len):
table[:, pos] = self.get_relative_positional_encoding(pos, d_model)
return table
def get_relative_positional_encoding(self, pos, d_model):
pos_encoding = torch.zeros(d_model)
for i in range(0, d_model, 2):
pos_encoding[i] = torch.sin(pos / ( ** (( * i)/d_model)))
i + < d_model:
pos_encoding[i + ] = torch.cos(pos / ( ** (( * i)/d_model)))
pos_encoding
():
positions_matrix = .relative_positions_matrix[:length, :length]
F.embedding(positions_matrix, .embeddings_table)
sentence_length =
d_model =
relative_positional_encoding = RelativePositionalEncoding(d_model, max_len=sentence_length)
relative_positional_encodings = relative_positional_encoding(sentence_length)
(relative_positional_encodings.shape)
优势: 对序列长度和相对位置信息的良好适应性,特别适合处理长文本或存在复杂依赖关系的任务。
劣势: 实现相对复杂,且在某些情况下可能增加计算成本。
旋转位置编码(Rotary Positional Encoding, RoPE)是近年来提出的一种新型位置编码方法,主要应用于大语言模型(如 LLaMA)。RoPE 的核心思想是通过对输入向量进行旋转变换,将位置信息嵌入到向量中。具体来说,RoPE 通过旋转每个维度对中的向量,实现对序列中位置信息的编码。
RoPE 具有很强的表达能力,尤其是在处理具有对称性或周期性的任务时,能够更加自然地捕捉序列中的位置信息。它允许模型在推理时外推到比训练更长的序列。
假设查询向量 $q$ 和键向量 $k$ 被分解为二维平面上的向量对 $(q_{2i}, q_{2i+1})$ 和 $(k_{2i}, k_{2i+1})$。对于位置 $m$ 和 $n$,RoPE 定义为:
$$ f_q(m) = [\cos(m\theta)x - \sin(m\theta)y, \sin(m\theta)x + \cos(m\theta)y] $$ $$ f_k(n) = [\cos(n\theta)a - \sin(n\theta)b, \sin(n\theta)a + \cos(n\theta)b] $$
其中 $\theta$ 是频率参数。经过变换后,点积 $f_q(m)^T f_k(n)$ 仅依赖于相对位置 $m-n$。
下面的代码展示了如何在 NLP 任务中实现旋转位置编码。
import torch
import math
def precompute_freqs_cis(dim, end, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
def apply_rotary_pos_emb(x, cos, sin):
x1 = x[..., ::2]
x2 = x[..., 1::2]
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
out = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).flatten(-2)
return out
# 模拟 Q 和 K 矩阵
batch_size = 1
seq_len = 8
d_model = 32
head_dim = d_model // 4 # 假设 4 heads
q = torch.randn(batch_size, seq_len, head_dim)
k = torch.randn(batch_size, seq_len, head_dim)
# 预计算 RoPE
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len)
# 应用 RoPE
q_rope = apply_rotary_pos_emb(q, freqs_cos, freqs_sin)
k_rope = apply_rotary_pos_emb(k, freqs_cos, freqs_sin)
print(f"Original Q shape: {q.shape}")
print(f"RoPE Q shape: {q_rope.shape}")
优势: 强大的表达能力,特别是在处理具有对称性或周期性特征的数据时表现优异。支持动态长度外推。
劣势: 实现复杂度较高,且对硬件加速器的兼容性需要仔细验证。
| 特性 | 绝对位置编码 | 相对位置编码 | 旋转位置编码 (RoPE) |
|---|---|---|---|
| 原理 | 固定函数映射位置 | 基于位置差值的偏置 | 向量空间旋转 |
| 序列长度泛化 | 较差 | 较好 | 优秀 |
| 计算复杂度 | 低 | 中 | 中 |
| 适用场景 | 短文本、标准 NLP | 长文本、翻译 | 大语言模型、长上下文 |
| 实现难度 | 简单 | 中等 | 较难 |
位置编码是 Transformer 模型中至关重要的一部分,不同的编码方式适用于不同的任务和数据类型。本文详细介绍了绝对位置编码、相对位置编码和旋转位置编码的原理、实现及应用,通过具体的案例分析展示了它们在实际任务中的表现。
随着 NLP 领域的不断发展,新的位置编码方法可能会不断涌现,进一步提升 Transformer 模型在复杂任务中的表现。了解并掌握这些位置编码方法,将有助于研究人员和工程师更好地应用 Transformer 模型,处理各种序列数据,提升模型的性能和应用效果。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online