跳到主要内容Transformer 算法详解:架构、注意力机制与核心组件 | 极客日志编程语言AI算法
Transformer 算法详解:架构、注意力机制与核心组件
Transformer 是一种基于注意力机制的深度学习模型,由 Vaswani 等人于 2017 年提出。其核心架构包含编码器和解码器,利用多头自注意力机制处理序列数据,解决了 RNN 和 LSTM 在长距离依赖上的问题。文章详细阐述了输入编码、位置编码、Self-Attention 计算流程(Query、Key、Value)、残差连接、层标准化及前馈网络等关键组件,并补充了解码器的掩码注意力机制,为理解 ChatGPT 等大语言模型奠定基础。
CoderByte3 浏览 Transformer 模型是深度学习中一种基于注意力机制的模型,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成和问答系统。
它由 Vaswani 等人在 2017 年的论文《Attention Is All You Need》中提出,突破了传统序列模型(如 RNN 和 LSTM)的限制,特别是在长距离依赖问题上表现出色。它是 ChatGPT 和所有其他大语言模型(LLM)的支柱。
模型架构
Transformer 模型由编码器(Encoder)和解码器(Decoder)组成。编码器和解码器各由 N 层相同的子层堆叠而成。
编码器(Encoder)
- 多头自注意力机制(Multi-Head Self-Attention)
- 前馈神经网络(Feed-Forward Neural Network)
解码器(Decoder)
- 掩码多头自注意力机制(Masked Multi-Head Self-Attention)
- 编码器 - 解码器注意力机制(Encoder-Decoder Attention)
- 前馈神经网络(Feed-Forward Neural Network)
Transformer 核心组件
下面,让我们来看看 Transformer 如何将输入文本序列转换为向量表示,以及如何逐层处理这些向量表示得到最终的输出。
1. 输入编码
和常见的 NLP 任务一样,我们首先会使用词嵌入算法(Word Embedding),将输入文本序列的每个词转换为一个词向量。实际应用中的向量一般是 256 或者 512 维。但为了简化起见,这里使用 4 维的词向量来进行讲解。
假设我们的输入文本序列包含了 3 个词,那么每个词可以通过词嵌入算法得到一个 4 维向量,于是整个输入被转化成为一个向量序列。
2. 位置编码
由于 Transformer 模型依赖于自注意力机制,而自注意力机制本质上是无序的,即它不区分输入序列中各个词的位置顺序,因此需要显式地引入位置信息来帮助模型理解序列的顺序关系。
位置编码的具体实现方式有多种,Transformer 模型中采用了一种基于正弦和余弦函数的方式。
对于输入序列中的每个位置 pos 和每个维度 i,位置编码向量的计算公式如下:
$$ PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}}) $$
$$ PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}}) $$
- pos 表示序列中词语的位置。
- i 表示位置编码向量的维度索引。
- $d_{model}$ 是词嵌入向量的维度。
在模型中,位置编码向量会与输入嵌入向量相加,将位置信息显式地引入到输入数据中。
3. 编码器 Encoder
输入文本序列经过输入处理之后得到了一个向量序列,这个向量序列将被送入第 1 层编码器,第 1 层编码器输出的同样是一个向量序列,再接着送入下一层编码器:第 1 层编码器的输入是融合位置向量的词向量,更上层编码器的输入则是上一层编码器的输出。
下图展示了向量序列在单层 encoder 中的流动,融合位置信息的词向量进入 self-attention 层,self-attention 输出每个位置的向量再输入 FFN 神经网络得到每个位置的新向量。
Self-Attention
注意力机制是神经网络中一个非常吸引人的概念,尤其是在 NLP 等任务中。它就像给模型打了一盏聚光灯,让它专注于输入序列的某些部分,而忽略其他部分,就像我们人类在理解句子时会注意特定的单词或短语一样。
现在,让我们深入研究一种特殊的注意力机制,称为 Self-Attention(自注意力)。想象一下,你正在阅读一个句子,你的大脑会自动突出显示重要的单词或短语以理解其含义。这本质上就是自注意力在神经网络中的作用。它使序列中的每个单词能够'注意'其他单词(包括它自己),以更好地理解上下文。
给定输入序列 X,通过三个不同的线性变换得到 Query(查询向量)、Key(键向量)和 Value(值向量)。
$$ Q = XW^Q, \quad K = XW^K, \quad V = XW^V $$
其中,$W^Q, W^K, W^V$ 是可训练的权重矩阵。
通过点积计算 Query 和 Key 之间的相似度,再通过缩放(scaling)以稳定梯度。
$$ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $$
下面,我们通过一个具体的案例进行说明。假设,我们需要对词组 Thinking Machines 进行翻译,其中 Thinking 对应的 embedding 向量用 x1 表示,Machines 对应的 embedding 向量用 x2 表示。
当我们处理 Thinking 这个词时,我们需要计算句子中所有词与它的 Attention Score。首先将当前词作为搜索的 query,去和句子中所有词(包含该词本身)的 key 去匹配,看看相关度有多高。我们用 q1 代表 Thinking 对应的 query 向量,k1 及 k2 分别代表 Thinking 以及 Machines 对应的 key 向量,则计算 Thinking 的 attention score 的时候我们需要计算 q1 与 k1、k2 的点乘,同理,我们计算 Machines 的 attention score 的时候需要计算 q2 与 k1、k2 的点乘。
如下图中所示我们分别得到了 q1 与 k1、k2 的点乘积,然后我们进行尺度缩放与 softmax 归一化。
显然,当前单词与自身的注意力得分一般最大,其他单词根据当前单词重要程度有相应的分数。随后我们用这些注意力得分与 Value 向量相乘,得到加权的向量。
上图中 z1 表示对第一个位置词向量(Thinking)计算 Self Attention 的全过程。最终得到的当前位置(这里的例子是第一个位置)词向量会继续输入到前馈神经网络。
在实际的代码实现中,Self Attention 的计算过程是使用矩阵快速计算的,一次就得到所有位置的输出向量。
其中 $W^O$ 是我们模型训练过程中学习到的合适的参数。
而多头注意力(Multi-Head Attention)就是我们可以有不同的 Q,K,V 表示,之后再将其结果合并起来,如下图所示。
残差连接和层标准化
经过 Multi-head Attention 后会进入 Add & Norm 层,这一层是指残差连接和层标准化。
前一层的输出 Sublayer(x) 会与原输入 x 相加 (残差连接),以减缓梯度消失的问题,然后再做层标准化。
其中,Sublayer(x) 表示 Multi-head Attention 的输出。
下面,我们来看一下什么是层标准化,和批量标准化有什么区别。
批量标准化是对一个批次中的所有样本进行标准化处理,它是对一个批次中的所有样本的每一个特征进行归一化。而层标准化是对每个样本的所有特征进行标准化处理,独立于同一批次中的其他样本。
层标准化的优点是不受批量大小的影响,可以在小批量甚至单个样本上工作。更适合序列数据。
前馈网络(FFN)
接着进入到 FFN 层,由下列公式可以看到输入 x 先做线性运算后,然后送入 ReLU,之后再做一次线性运算。
$$ FFN(x) = max(0, xW_1 + b_1)W_2 + b_2 $$
在 FFN 后面,也会接一个 Add & Norm 层,这里就不再赘述。
到目前位置,我们已经把 Transformer 中的 Encoder 部分聊完了。接下来补充 Decoder 部分的细节。
4. 解码器 Decoder
解码器的结构与编码器类似,但增加了额外的注意力层以利用编码器的输出。
- 掩码多头自注意力(Masked Multi-Head Attention):在训练时,为了防止模型偷看未来的目标词,会对后续位置进行掩码(Mask),确保预测第 t 个词时只能看到前 t-1 个词。
- 编码器 - 解码器注意力(Encoder-Decoder Attention):这是解码器独有的层,Query 来自解码器的上一层输出,Key 和 Value 来自编码器的最终输出。这使得解码器可以关注输入序列中与当前输出最相关的部分。
- 前馈网络(FFN):与编码器相同,用于非线性变换。
总结
Transformer 通过并行化的注意力机制取代了传统的循环结构,极大地提升了训练效率和长序列处理能力。其核心在于位置编码解决了顺序问题,多头注意力捕捉了全局依赖,而残差连接与层标准化保证了深层网络的稳定性。理解这些组件是掌握现代大语言模型的基础。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
- Markdown转HTML
将 Markdown(GFM)转为 HTML 片段,浏览器内 marked 解析;与 HTML转Markdown 互为补充。 在线工具,Markdown转HTML在线工具,online