Transformer 模型源自论文《Attention Is All You Need》,旨在解决传统 RNN 在文本生成任务中的局限性。RNN 存在两个主要缺点:一是计算顺序进行,无法并行化;二是长距离依赖问题,信息容易衰减。Transformer 通过多头注意力机制、位置编码、层归一化和残差连接等组件,实现了高效的并行计算和长序列建模。
Transformer 整体架构
Transformer 由 Encoder(编码器)和 Decoder(解码器)两大部分组成。Encoder 负责处理输入序列,Decoder 负责生成输出序列。两者均由多层堆叠的模块构成,主要包含多头自注意力机制和前馈神经网络。
编码 (Encoder) 部分
1. 多头注意力 (Multi-Head Attention)
多头注意力是自注意力机制的扩展。它允许模型在不同表示子空间上关注不同位置的信息。具体流程如下:
- 线性变换:将输入 Q、K、V 分别通过不同的线性投影矩阵,划分为多个头(Heads)。通常使用 8 个头。
- 计算注意力分数:对每个头独立计算缩放点积注意力。
- 拼接与输出:将所有头的输出拼接,再通过一个线性层得到最终结果。
缩放点积注意力的计算公式为: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ 其中 $d_k$ 是键向量的维度,除以 $\sqrt{d_k}$ 是为了防止点积过大导致 softmax 梯度消失。
2. 位置编码 (Positional Encoding)
由于 Transformer 不包含循环或卷积结构,无法感知序列中词的位置信息,因此需要引入位置编码。原始论文采用正弦和余弦函数生成位置编码,公式如下:
$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$ $$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
位置编码与词嵌入向量相加后作为下一层的输入,确保模型能区分不同位置的词。
3. 层归一化与残差连接
为了加速训练并稳定梯度,Transformer 在每个子层(如注意力层、前馈网络层)周围都添加了残差连接和层归一化。
- 残差连接:将子层的输入直接加到输出上,即 $LayerNorm(x + Sublayer(x))$。
- 层归一化:对单个样本的所有特征进行归一化,使其均值为 0,方差为 1。
解码 (Decoder) 部分
Decoder 的结构与 Encoder 类似,但增加了 Masked Multi-Head Attention 层,以防止解码时看到未来的信息。
1. 掩码多头注意力 (Masked Multi-Head Attention)
在训练过程中,Decoder 的输入是目标序列的右移版本。为了防止模型在预测第 t 个词时看到第 t+1 及之后的词,需要在计算注意力分数时,将未来位置的值设为负无穷大(-INF),经过 softmax 后概率接近 0。
2. 交叉注意力 (Cross-Attention)
Decoder 的第二层注意力机制接收 Encoder 的输出作为 Key 和 Value,而 Query 来自 Decoder 的前一层。这使得 Decoder 能够根据当前生成的词,关注 Encoder 中对应的输入信息。
3. 前馈神经网络与输出
Decoder 最后通过一个全连接层(Linear)和一个 Softmax 层,将输出映射到词汇表大小,得到每个词的概率分布,选择概率最大的词作为输出。
代码实现示例
以下是一个简化的 PyTorch 风格的位置编码实现,用于展示如何生成位置向量:
import torch
import math
class PositionalEncoding(torch.nn.Module):
():
().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(, max_len).unsqueeze()
div_term = torch.exp(torch.arange(, d_model, ) * -(math.log() / d_model))
pe[:, ::] = torch.sin(position * div_term)
pe[:, ::] = torch.cos(position * div_term)
pe = pe.unsqueeze()
.register_buffer(, pe)
():
x + .pe[:, :x.size()]


