跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

从零开始手写 Vision Transformer 实现图像分类任务

综述由AI生成基于 PyTorch 从零搭建 Vision Transformer (ViT) 模型的完整流程。内容涵盖图像 Patch 序列化与线性映射、分类 Token 的引入、正弦余弦位置编码的实现、以及 Transformer Encoder 中多头自注意力机制、层归一化和残差连接的构建。文章提供了关键代码片段并解释了各模块的数学原理与维度变化,最终通过分类头完成图像分类任务,适合希望深入理解 ViT 架构原理的开发者参考。

全栈工匠发布于 2025/2/7更新于 2026/6/1219 浏览
从零开始手写 Vision Transformer 实现图像分类任务

从零开始手写 Vision Transformer 实现图像分类任务

本文旨在通过 PyTorch 框架,从零开始逐步搭建 Vision Transformer (ViT) 模型,并应用于图像分类任务。Transformer 最初是为自然语言处理(NLP)设计的序列化数据处理模型,而 ViT 的核心思想是将图像视为序列化的 Patch(图块),从而利用 Transformer 强大的自注意力机制进行特征提取。

1. 图像序列化与 Patch Embedding

Transformer 无法直接处理二维图像数据,因此第一步是将图像转换为序列。我们将输入图像切割成大小相等的 Patch 子图片,每个子图片被视为一个 Token。

假设输入图像尺寸为 $H \times W$,Patch 大小为 $P \times P$。若图像能被整除,则得到 $N = (H/P) \times (W/P)$ 个 Patch。每个 Patch 被展平为一维向量,并通过线性映射(Linear Projection)投影到隐藏维度 $D$。

代码实现:Patch 分割与线性映射

import torch
import torch.nn as nn

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):    
        super(MyViT, self).__init__()
        
        # Input and patches sizes
        self.input_shape = input_shape  # 例如 (1, 28, 28)
        self.n_patches = n_patches      # 每边切分数量,如 7
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        self.out_d = out_d
        
        # 验证输入尺寸是否能被整除
        assert input_shape[1] % n_patches == 0, "Input height not divisible by n_patches"
        assert input_shape[2] % n_patches == 0, "Input width not divisible by n_patches"
        
        # 计算 Patch 的实际尺寸
        self.patch_size = (input_shape[] // n_patches, input_shape[] // n_patches)
        
        
        .input_d = (input_shape[] * .patch_size[] * .patch_size[])
        
        
        .linear_mapper = nn.Linear(.input_d, .hidden_d)
1
2
# 计算每个 Patch 展平后的维度 (通道数 * patch_h * patch_w)
self
int
0
self
0
self
1
# 1) Linear mapper: 将 Patch 展平向量映射到隐藏维度 D
self
self
self

在此步骤中,我们不仅完成了空间维度的转换,还通过全连接层实现了特征的初步抽象。如果有多个颜色通道(如 RGB),它们也会被展平并包含在输入向量中。

2. 添加分类标记 (Class Token)

为了完成分类任务,我们需要一个特殊的标记来聚合整个图像的信息。类似于 NLP 中的 [CLS] 标记,我们在序列的开头添加一个可学习的参数向量 class_token。

经过线性映射后,我们的张量形状为 $(N, 49, 8)$。添加 Class Token 后,形状变为 $(N, 50, 8)$,其中第一个位置是分类标记,后续 49 个位置是图像 Patch 的特征。

        # 2) Classification token: 可学习的参数
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

在 forward 过程中,我们将这个 Class Token 堆叠到每个样本的序列头部。

3. 位置编码 (Positional Encoding)

Transformer 本身不具备处理序列顺序的能力(因为它没有循环或卷积结构)。因此,必须注入位置信息。ViT 通常使用正弦和余弦波函数生成的固定位置编码,或者可学习的位置嵌入。

这里我们采用经典的正弦/余弦位置编码公式: $$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$$ $$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$$

这确保了每个位置都有独特的编码,且不同位置之间具有相对距离的可预测性。

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            if j % 2 == 0:
                result[i][j] = torch.sin(torch.tensor(i) / (10000 ** (j / d)))
            else:
                result[i][j] = torch.cos(torch.tensor(i) / (10000 ** ((j - 1) / d)))
    return result

在模型前向传播中,我们将位置编码加到 Token 上:

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

注意:Class Token 通常不添加位置编码,但在简化实现中,为了方便矩阵运算,有时也会统一加上。严谨的做法是只对 Patch Token 加位置编码,但此处为了保持维度一致性和教学简洁性,对整体序列进行了叠加。

4. Transformer Encoder 核心组件

Transformer Encoder 由多层相同的 Block 组成,每个 Block 包含两个主要部分:多头自注意力机制 (Multi-Head Self Attention, MSA) 和前馈神经网络 (Feed Forward Network, FFN),两者之间都采用了残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。

4.1 层归一化 (LayerNorm)

LayerNorm 用于稳定训练过程。它通过对每个样本的特定维度进行标准化(均值为 0,方差为 1)来实现。在 ViT 中,我们通常在注意力机制之前应用 LayerNorm。

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

4.2 多头自注意力 (MSA)

自注意力机制允许序列中的每个 Token 与其他所有 Token 进行交互。多头机制意味着我们将特征维度切分为多个'头',每个头独立学习不同的子空间特征,最后拼接起来。

对于单个头,计算流程如下:

  1. Query, Key, Value: 通过线性变换生成 Q, K, V。
  2. Attention Scores: 计算 $QK^T / \sqrt{d_k}$。
  3. Softmax: 归一化为概率分布。
  4. Weighted Sum: 用权重乘以 V。
class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):    
        super(MyMSA, self).__init__()        
        self.d = d
        self.n_heads = n_heads
        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        
        self.d_head = int(d / n_heads)
        
        # 每个头独立的 Q, K, V 线性层
        self.q_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])        
        self.k_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])        
        self.v_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])        
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, sequences):    
        # sequences shape: (N, seq_length, token_dim)
        result = []        
        for sequence in sequences:         # 遍历 Batch 中的每个样本
            seq_result = []            
            for head in range(self.n_heads):            
                q_mapping = self.q_mappings[head]                
                k_mapping = self.k_mappings[head]                
                v_mapping = self.v_mappings[head]
                
                # 提取当前头的特征部分
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]                
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                
                # 计算注意力分数
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))                
                seq_result.append(attention @ v)            
           result.append(torch.hstack(seq_result))        
       return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

注:上述代码为了清晰展示了逐样本处理逻辑,实际工程中常使用 batched operations 以提升效率。

4.3 残差连接与前馈网络 (MLP)

在 MSA 之后,再次应用 LayerNorm 和 MLP。MLP 通常包含两个线性层和一个激活函数(如 GELU 或 ReLU),用于增加模型的非线性表达能力。

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
        
        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(            
            nn.Linear(self.hidden_d, self.hidden_d),            
            nn.ReLU()        
        )

完整的 Encoder Block 前向传播如下:

        # TRANSFORMER ENCODER BEGINS
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))
        
        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS

5. 分类头 (Classification Head)

经过 Encoder 处理后,我们只需要关注 Class Token 的输出。将其取出,通过一个线性层映射到类别数量,并使用 Softmax 得到概率分布。

        # 6) Classification MLP
        self.mlp = nn.Sequential(       
            nn.Linear(self.hidden_d, out_d),            
            nn.Softmax(dim=-1)        
        )

完整的前向传播逻辑整合:

    def forward(self, images):    
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)
        
        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)
        
        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
        
        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
        
        # Transformer Encoder Processing
        out = tokens + self.msa(self.ln1(tokens))
        out = out + self.enc_mlp(self.ln2(out))
        
        # Getting the classification token only
        out = out[:, 0]
        
        return self.mlp(out)

6. 测试与验证

我们可以构建一个简单的测试脚本来验证维度变化是否符合预期。

# Define input shape (batch size, channels, width, height)
input_shape = (1, 28, 28)
input_image_shape = (1, 1, 28, 28)

# Create model
model = MyViT(input_shape=input_shape)

# Create random input tensor with shape (1, 1, 28, 28)
images = torch.rand(input_image_shape)

# Forward pass
output = model(images)
print(f"Output shape: {output.shape}")  # Expected: (1, 10)

运行结果应显示输出张量形状为 (N, 10),代表每张图像属于 10 个类别的概率分布。

总结

Vision Transformer 通过将图像分解为 Patch 序列,成功地将 Transformer 架构引入计算机视觉领域。其核心优势在于全局感受野,能够捕捉长距离依赖关系。尽管 CNN 在局部特征提取上依然高效,但 ViT 在大规模数据集上的表现证明了其强大的泛化能力。本教程详细拆解了从 Patch Embedding、位置编码、多头注意力到分类头的全过程,为深入理解现代视觉模型奠定了坚实基础。

目录

  1. 从零开始手写 Vision Transformer 实现图像分类任务
  2. 1. 图像序列化与 Patch Embedding
  3. 代码实现:Patch 分割与线性映射
  4. 2. 添加分类标记 (Class Token)
  5. 3. 位置编码 (Positional Encoding)
  6. 4. Transformer Encoder 核心组件
  7. 4.1 层归一化 (LayerNorm)
  8. 4.2 多头自注意力 (MSA)
  9. 4.3 残差连接与前馈网络 (MLP)
  10. 5. 分类头 (Classification Head)
  11. 6. 测试与验证
  12. Define input shape (batch size, channels, width, height)
  13. Create model
  14. Create random input tensor with shape (1, 1, 28, 28)
  15. Forward pass
  16. 总结
  • 免费图片AI生成工具免费生成了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 免费图片视频在线生成30秒,将你的创意变成现实开始设计
  • X/Twitter免费视频下载器免登陆无限额度免费视频解析下载了解详情
  • 100+免费在线小游戏爽一把
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Web 可访问性最佳实践:构建人人可用的前端界面
  • DeerFlow 2.0 开源超级 Agent 框架技术解析
  • 常见 AIGC 论文降重工具评测与对比
  • 2025 年 AIGC 六大核心趋势与落地应用分析
  • 两个月学习大语言模型(LLM)的详细学习计划与实战指南
  • Spring Cloud Gateway 微服务网关核心解析
  • Django REST Framework 重构智能合同审查系统实战
  • 使用 Windows Machine Learning 加载 ONNX 模型并推理
  • 使用Ollama和Open WebUI部署与管理本地开源大模型
  • Trae 配置 MinGW 编译 C++ 程序指南
  • SkyWalking 集成 Spring Cloud Alibaba 全链路追踪实战
  • Skills 详解:AI Agent 的模块化能力扩展系统
  • Rust 异步 Web 框架 Axum:核心原理与实战进阶
  • RunningHub:AIGC 创作平台深度解析
  • WebRTC 源码解析:应用层 API 功能实现
  • Java 数据结构:栈与队列核心指南
  • Python 生成四位随机数的多种实现方案
  • 前缀和算法实战:和为 K 的子数组与和可被 K 整除的子数组
  • SRC 漏洞挖掘思路与手法详解:从信息搜集到批量测试
  • nanobot 通过 webhook 对接钉钉和飞书实现跨平台消息同步

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online