跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

Transformer 核心机制与架构详解:注意力、自注意力及模型结构

综述由AI生成详细解析了 Transformer 模型的核心机制与架构设计。首先介绍了注意力机制的基本原理,包括 Query、Key、Value 的概念及 Scaled Dot Product Attention 计算方法。接着阐述了自注意力机制如何捕捉序列内部依赖,以及多头注意力如何通过并行计算提升模型表达能力。文章还深入剖析了 Transformer 的 Encoder-Decoder 结构,涵盖位置编码、残差连接、层归一化及解码器的掩码注意力机制。最后提供了基于 MindSpore 框架的完整实践代码,包括模型构建、训练、评估及推理流程,展示了 Transformer 在机器翻译任务中的应用。

t ag发布于 2025/2/6更新于 2026/6/323 浏览
Transformer 核心机制与架构详解:注意力、自注意力及模型结构

Transformer 核心机制与架构详解

1. 学习总结

1.1 注意力机制(Attention Mechanism)

注意力机制是深度学习中一种模拟人类视觉或听觉系统工作方式的技术。其灵感来源于人类的感知过程,即根据输入的信息,有选择性地关注或聚焦于不同部分,以便更有效地处理信息。在深度学习中,注意力机制被广泛应用于序列数据、图像处理等任务。

注意力机制的基本思想是在处理输入序列时,不同位置的信息被赋予不同的权重,以便网络更集中地关注对当前任务有用的部分。这样可以提高模型对长序列或大型数据的处理能力,同时降低处理的复杂性。

在自然语言处理中,注意力机制常常被用于机器翻译、文本摘要等任务。在图像处理中,它可以用于图像分类、图像生成等任务。在注意力机制的基础上,出现了不同的变种,如自注意力机制(Self-Attention)等,用于更好地捕捉序列内部的依赖关系。

具体来说,注意力机制允许模型在处理输入时,对不同位置的信息分配不同的权重,以便网络更有针对性地处理输入序列。这种权重的分配是动态的,可以根据当前输入的情况调整。这种能力使得模型能够更灵活地处理各种输入,并且在处理长序列时不容易出现信息丢失的问题。

在自然语言任务中,通过注意力分数来表达某个词在句子中的重要性,分数越高,说明该词对完成该任务的重要性越大。

计算注意力分数时,我们主要参考三个因素:Query、Key 和 Value。

  • Query:任务内容,代表当前需要查询的信息。
  • Key:索引/标签,帮助定位到答案。
  • Value:答案,实际包含的信息内容。

注意力机制示意图

在文本翻译中,我们希望翻译后的句子的意思和原始句子相似,所以进行注意力分数计算时,Query 一般和目标序列(即翻译后的句子)有关,Key 则与源序列(即翻译前的原始句子)有关。

常用的计算注意力分数的方式有两种:Additive Attention 和 Scaled Dot Product Attention。这里主要介绍第二种方法:Scaled Dot Product Attention。

在几何角度,点积(Dot Product)表示一个向量在另一个向量方向上的投影。换句话说,从几何角度上解读,点积代表了某个向量中的多少是和另一个向量相似的。

点积几何解释

图片来源:Understanding the Dot Product from BetterExplained

将这个概念运用到当前的情境中,我们想要求 Query 和 Key 之间有多少是相似的,则需要计算 Query 和 Key 的点积。

下面是注意力机制的公式:

注意力计算公式

1.2 自注意力机制(Self-Attention)

自注意力机制中,我们关注句子本身,查看每个单词对于周边单词的重要性。这样可以很好地理清句子中的逻辑关系,如代词指代。

举个例子,在"The animal didn't cross the street because it was too tired"这句话中,"it"指代句中的"The animal",所以自注意力会赋予"The"、"animal"更高的注意力分值。

自注意力分数的计算还是遵循着上述的公式,只不过这里的 Query、Key 和 Value 都变成了句子本身点乘各自权重。

给定序列 $X$,序列长度为 n,维度为 d_model。在计算自注意力时 $Q = XW^Q, K = XW^K, V = XW^V$。

其中,序列中位置 i 的词与位置 j 的词之间的自注意力分数为:

自注意力分数公式

1.3 多头注意力(Multi-Head Attention)

多头注意力是注意力机制的扩展,它可以使模型通过不同的方式关注输入序列的不同部分,从而提升模型的训练效果。

不同于之前一次计算整体输入的注意力分数,多头注意力是多次计算,每次计算输入序列中某一部分的注意力分数,最后再将结果进行整合。

多头注意力结构

图片来源:Ashish Vaswani et al., Attention is all you need, 2017.

多头注意力通过对输入的 embedding 乘以不同的权重参数 $W_i^Q, W_i^K, W_i^V$,将其映射到多个小维度空间中,我们称之为'头'(head),每个头部会并行计算自己的自注意力分数。

多头计算过程

拼接前输出

$W_i^O$ 为可学习的权重参数。一般为了平衡计算成本,我们会取 $d_k = d_{model} / h$。

在获得多组自注意力分数后,我们将结果拼接到一起,得到多头注意力的最终输出。

拼接操作

$W^O$ 为可学习的权重参数,用于将拼接后的多头注意力输出映射回原来的维度。

最终输出映射

简单来说,在多头注意力中,每个头部可以'解读'输入内容的不同方面,比如:捕捉全局依赖关系、关注特定语境下的词元、识别词和词之间的语法关系等。

1.4 Transformer 结构

Transformer 是 encoder-decoder 的结构,这里的'encoder'和'decoder'是由无数个同样结构的 encoder 层和 decoder 层堆叠组成。

比如在进行机器翻译时,encoder 解读源语句(被翻译的句子)的信息,并传输给 decoder。decoder 接收源语句信息后,结合当前输入(目前翻译的情况),预测下一个单词,直到生成完整的句子。

Transformer 整体架构

1.4.1 位置编码(Positional Encoding)

Transformer 模型不包含 RNN,所以无法在模型中记录时序信息,这样会导致模型无法识别由顺序改变而产生的句子含义的改变,如'我爱我的小猫'和'我的小猫爱我'。

为了弥补这个缺陷,我们选择在输入数据中额外添加表示位置信息的位置编码。

位置编码 PE 的形状与经过 word embedding 后的输出 X 相同,对于索引为 [pos, 2i] 的元素,以及索引为 [pos, 2i+1] 的元素,位置编码的计算如下:

位置编码公式

1.4.2 编码器(Encoder)

Transformer 的 Encoder 负责处理输入的源序列,并将输入信息整合为一系列的上下文向量(context vector)输出。

每个 encoder 层中存在两个子层:多头自注意力(multi-head self-attention)和基于位置的前馈神经网络(position-wise feed-forward network)。

子层之间使用了残差连接(residual connection),并使用了层规范化(layer normalization)。二者统称为'Add & Norm'。

Encoder 层结构

1.4.3 Add & Norm

Add&Norm 层本质上是残差连接后紧接了一个 LayerNorm 层。

Add & Norm 结构

  • Add:残差连接,帮助缓解网络退化问题,注意需要满足 x 与 SubLayer(x) 的形状一致。
  • Norm:Layer Norm,层归一化,帮助模型更快地进行收敛。
1.4.4 解码器(Decoder)

Decoder 层结构

解码器将编码器输出的上下文序列转换为目标序列的预测结果。

Decoder 输出

该输出将在模型训练中与真实目标输出 Y 进行比较,计算损失。

不同于编码器,每个 Decoder 层中包含两层多头注意力机制,并在最后多出一个线性层,输出对目标序列的预测结果。

  1. 第一层:计算目标序列的注意力分数的掩码多头自注意力(Masked Multi-Head Attention)。这一步是为了防止 Decoder 看到未来的信息,保证自回归生成的特性。
  2. 第二层:用于计算上下文序列与目标序列对应关系,其中 Decoder 掩码多头注意力的输出作为 Query,Encoder 的输出(上下文序列)作为 Key 和 Value。

2. 课程实践

通过 Transformer 实现文本机器翻译

全流程

  1. 数据预处理:将图像、文本等数据处理为可以计算的 Tensor。
  2. 模型构建:使用框架 API,搭建模型。
  3. 模型训练:定义模型训练逻辑,遍历训练集进行训练。
  4. 模型评估:使用训练好的模型,在测试集评估效果。
  5. 模型推理:将训练好的模型部署,输入新数据获得预测结果。

这里实验代码比较多,重点从模型构建开始讲。

模型构建

定义超参数,实例化模型。

src_vocab_size = len(de_vocab)  
trg_vocab_size = len(en_vocab)  
src_pad_idx = de_vocab.pad_idx  
trg_pad_idx = en_vocab.pad_idx  
  
d_model = 512  
d_ff = 2048  
n_layers = 6  
n_heads = 8  
  
encoder = Encoder(src_vocab_size, d_model, n_heads, d_ff, n_layers, dropout_p=0.1)  
decoder = Decoder(trg_vocab_size, d_model, n_heads, d_ff, n_layers, dropout_p=0.1)  
model = Transformer(encoder, decoder)  

模型训练 & 模型评估

模型训练逻辑

昇思 MindSpore 在模型训练部分使用了函数式编程(FP)。 构造函数 → 函数变换 → 函数调用。

def train(iterator, epoch=0):  
    model.set_train(True)  
    num_batches = len(iterator)  
    total_loss = 0  
    total_steps = 0  
  
    with tqdm(total=num_batches) as t:  
        t.set_description(f'Epoch: {epoch}')  
        for src, src_len, trg in iterator():  
            loss = train_step(src, trg)  
            total_loss += loss.asnumpy()  
            total_steps += 1  
            curr_loss = total_loss / total_steps  
            t.set_postfix({'loss': f'{curr_loss:.2f}'})  
            t.update(1)  
  
    return total_loss / total_steps  

定义模型评估逻辑。

def evaluate(iterator):  
    model.set_train(False)  
    num_batches = len(iterator)  
    total_loss = 0  
    total_steps = 0  
  
    with tqdm(total=num_batches) as t:  
        for src, _, trg in iterator():  
            loss = forward(src, trg)  
            total_loss += loss.asnumpy()  
            total_steps += 1  
            curr_loss = total_loss / total_steps  
            t.set_postfix({'loss': f'{curr_loss:.2f}'})  
            t.update(1)  
  
    return total_loss / total_steps

模型训练

数据集遍历迭代,一次完整的数据集遍历成为一个 epoch。我们逐个 epoch 打印训练的损失值和评估精度,并通过 save_checkpoint 保存评估精度最高的 ckpt 文件(transformer.ckpt)到 home_path/.mindspore_examples/transformer.ckpt。

from mindspore import save_checkpoint  
import os

num_epochs = 10  
best_valid_loss = float('inf')  
ckpt_file_name = os.path.join(cache_dir, 'transformer.ckpt')  
  
for i in range(num_epochs):  
    train_loss = train(train_iterator, i)  
    valid_loss = evaluate(valid_iterator)  
  
    if valid_loss < best_valid_loss:  
        best_valid_loss = valid_loss  
        save_checkpoint(model, ckpt_file_name)  

模型推理

def inference(sentence, max_len=32):  
    """模型推理:输入一个德语句子,输出翻译后的英文句子  
    enc_inputs: [batch_size(1), src_len]  
    """  
    new_model.set_train(False)  
  
    if isinstance(sentence, str):  
        tokens = [tok.lower() for tok in re.findall(r'\w+|[^\w\s]', sentence.rstrip())]  
    else:  
        tokens = [token.lower() for token in sentence]  
  
    if len(tokens) > max_len - 2:  
        src_len = max_len  
        tokens = ['<bos>'] + tokens[:max_len - 2] + ['<eos>']  
    else:  
        src_len = len(tokens) + 2  
        tokens = ['<bos>'] + tokens + ['<eos>'] + ['<pad>'] * (max_len - src_len)  
  
    indexes = de_vocab.encode(tokens)  
    enc_inputs = Tensor(indexes, mstype.float32).expand_dims(0)  
  
    enc_outputs, _ = new_model.encoder(enc_inputs, src_pad_idx)  
  
    dec_inputs = Tensor([[en_vocab.bos_idx]], mstype.float32)  
  
    max_len = enc_inputs.shape[1]  
    for _ in range(max_len):  
        dec_outputs, _, _ = new_model.decoder(dec_inputs, enc_inputs, enc_outputs, src_pad_idx, trg_pad_idx)  
        dec_logits = dec_outputs.view((-1, dec_outputs.shape[-1]))  
  
        dec_logits = dec_logits[-1, :]  
        pred = dec_logits.argmax(axis=0).expand_dims(0).expand_dims(0)  
        pred = pred.astype(mstype.float32)  
  
        dec_inputs = ops.concat((dec_inputs, pred), axis=1)  
  
        if int(pred.asnumpy()[0]) == en_vocab.eos_idx:  
            break  
  
    trg_indexes = [int(i) for i in dec_inputs.view(-1).asnumpy()]  
    eos_idx = trg_indexes.index(en_vocab.eos_idx) if en_vocab.eos_idx in trg_indexes else -1  
    trg_tokens = en_vocab.decode(trg_indexes[1:eos_idx])  
  
    return trg_tokens  

以测试数据集中的第一组语句为例,进行测试。

example_idx = 0  
  
src = test_dataset[example_idx][0]  
trg = test_dataset[example_idx][1]  
pred_trg = inference(src)  
  
print(f'src = {src}')  
print(f'trg = {trg}')  
print(f"predicted trg = {pred_trg}")

3. 总结与展望

Transformer 引入的注意力机制是一次巨大的创新。相较于传统的循环神经网络(RNN)或长短时记忆网络(LSTM),Transformer 采用了自注意力机制,使模型能够在处理长序列时保持更好的信息传递。通过对输入序列中不同位置的信息分配不同的权重,Transformer 在语言建模、机器翻译等任务上表现出色。

其次,多头注意力机制使得模型能够同时关注不同的子空间,进一步提高了模型的学习能力。这种并行性的设计使得 Transformer 在处理大规模数据时表现出色,且能够更好地捕捉序列中的局部和全局依赖关系。

另外,位置编码的引入解决了 Transformer 无法处理序列顺序信息的问题。通过将位置信息嵌入到输入数据中,Transformer 能够更好地理解序列中元素的相对位置,这对于语言等需要考虑单词顺序的任务至关重要。

此外,课程中还深入探讨了 Transformer 的训练过程,包括 Add&Norm、残差连接等技术。这些技术的运用使得 Transformer 更易于训练,且对于不同类型的数据和任务都具有广泛的适用性。

总的来说,Transformer 是深度学习领域的一次重要革新。这一模型的创新性设计和在实际任务中的卓越表现让我们深感神经网络领域的不断进步。Transformer 不仅是一种模型,更为探索复杂任务提供了全新的思路和工具,特别是在 NLP 和计算机视觉领域展现出了强大的迁移能力和扩展性。

目录

  1. Transformer 核心机制与架构详解
  2. 1. 学习总结
  3. 1.1 注意力机制(Attention Mechanism)
  4. 1.2 自注意力机制(Self-Attention)
  5. 1.3 多头注意力(Multi-Head Attention)
  6. 1.4 Transformer 结构
  7. 1.4.1 位置编码(Positional Encoding)
  8. 1.4.2 编码器(Encoder)
  9. 1.4.3 Add & Norm
  10. 1.4.4 解码器(Decoder)
  11. 2. 课程实践
  12. 全流程
  13. 模型构建
  14. 模型训练 & 模型评估
  15. 模型训练
  16. 模型推理
  17. 3. 总结与展望
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Python 数据分析神器:ydata_profiling 模块详解
  • GraphRAG 与 RAG 的比较分析
  • C++ string 常用函数详解(三)
  • OpenClaw 本地 AI 助手部署与飞书对接指南
  • Dify + Skill 本地部署大模型智能体:企业级 AI Agent 构建指南
  • GraalVM for JDK 快速上手指南
  • 低空经济驱动下的无人机光伏巡检技术应用
  • AIGC 自动化编程实战:Python、Java、JavaScript 与 VBA
  • OpenClaw 飞书机器人配置教程:聊天下达 AI 指令
  • SLAM 在无人机导航中的落地实践:从算法到部署
  • 2024 主流 AI 绘图工具深度解析:Midjourney 与 Stable Diffusion 对比
  • JVS-APS:算法驱动与低代码融合的智能排产方案
  • 使用 Z-Image-Turbo 进行本地 AI 绘画:16GB 显存支持与中英提示词
  • JavaScript 基础语法与 jQuery 快速入门
  • AI 开发必备:4 个 Skills 组合掌控全流程与灵活控制
  • OpenClaw 配置指南:接入第三方 API 使用大模型
  • Boost C++ 库实战:构建高性能即时通讯服务器
  • 数组与哈希表(Map/Object)核心区别及实战
  • 算法练习:多重背包、贪心差分、DFS 及路径 DP 题解
  • LangChain 中 RAG 模型的应用步骤与技巧

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online