跳到主要内容PythonAI算法
从零开始手写 Vision Transformer 实现图像分类任务
综述由AI生成基于 PyTorch 从零搭建 Vision Transformer (ViT) 模型的完整流程。内容涵盖图像 Patch 序列化与线性映射、分类 Token 的引入、正弦余弦位置编码的实现、以及 Transformer Encoder 中多头自注意力机制、层归一化和残差连接的构建。文章提供了关键代码片段并解释了各模块的数学原理与维度变化,最终通过分类头完成图像分类任务,适合希望深入理解 ViT 架构原理的开发者参考。
全栈工匠19 浏览 从零开始手写 Vision Transformer 实现图像分类任务
本文旨在通过 PyTorch 框架,从零开始逐步搭建 Vision Transformer (ViT) 模型,并应用于图像分类任务。Transformer 最初是为自然语言处理(NLP)设计的序列化数据处理模型,而 ViT 的核心思想是将图像视为序列化的 Patch(图块),从而利用 Transformer 强大的自注意力机制进行特征提取。
1. 图像序列化与 Patch Embedding
Transformer 无法直接处理二维图像数据,因此第一步是将图像转换为序列。我们将输入图像切割成大小相等的 Patch 子图片,每个子图片被视为一个 Token。
假设输入图像尺寸为 $H \times W$,Patch 大小为 $P \times P$。若图像能被整除,则得到 $N = (H/P) \times (W/P)$ 个 Patch。每个 Patch 被展平为一维向量,并通过线性映射(Linear Projection)投影到隐藏维度 $D$。
代码实现:Patch 分割与线性映射
import torch
import torch.nn as nn
class MyViT(nn.Module):
def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
super(MyViT, self).__init__()
self.input_shape = input_shape
self.n_patches = n_patches
self.n_heads = n_heads
self.hidden_d = hidden_d
self.out_d = out_d
assert input_shape[1] % n_patches == 0, "Input height not divisible by n_patches"
assert input_shape[2] % n_patches == 0, "Input width not divisible by n_patches"
self.patch_size = (input_shape[] // n_patches, input_shape[] // n_patches)
.input_d = (input_shape[] * .patch_size[] * .patch_size[])
.linear_mapper = nn.Linear(.input_d, .hidden_d)
1
2
self
int
0
self
0
self
1
self
self
self
在此步骤中,我们不仅完成了空间维度的转换,还通过全连接层实现了特征的初步抽象。如果有多个颜色通道(如 RGB),它们也会被展平并包含在输入向量中。
2. 添加分类标记 (Class Token)
为了完成分类任务,我们需要一个特殊的标记来聚合整个图像的信息。类似于 NLP 中的 [CLS] 标记,我们在序列的开头添加一个可学习的参数向量 class_token。
经过线性映射后,我们的张量形状为 $(N, 49, 8)$。添加 Class Token 后,形状变为 $(N, 50, 8)$,其中第一个位置是分类标记,后续 49 个位置是图像 Patch 的特征。
self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
在 forward 过程中,我们将这个 Class Token 堆叠到每个样本的序列头部。
3. 位置编码 (Positional Encoding)
Transformer 本身不具备处理序列顺序的能力(因为它没有循环或卷积结构)。因此,必须注入位置信息。ViT 通常使用正弦和余弦波函数生成的固定位置编码,或者可学习的位置嵌入。
这里我们采用经典的正弦/余弦位置编码公式:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$$
这确保了每个位置都有独特的编码,且不同位置之间具有相对距离的可预测性。
def get_positional_embeddings(sequence_length, d):
result = torch.ones(sequence_length, d)
for i in range(sequence_length):
for j in range(d):
if j % 2 == 0:
result[i][j] = torch.sin(torch.tensor(i) / (10000 ** (j / d)))
else:
result[i][j] = torch.cos(torch.tensor(i) / (10000 ** ((j - 1) / d)))
return result
在模型前向传播中,我们将位置编码加到 Token 上:
tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
注意:Class Token 通常不添加位置编码,但在简化实现中,为了方便矩阵运算,有时也会统一加上。严谨的做法是只对 Patch Token 加位置编码,但此处为了保持维度一致性和教学简洁性,对整体序列进行了叠加。
4. Transformer Encoder 核心组件
Transformer Encoder 由多层相同的 Block 组成,每个 Block 包含两个主要部分:多头自注意力机制 (Multi-Head Self Attention, MSA) 和前馈神经网络 (Feed Forward Network, FFN),两者之间都采用了残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。
4.1 层归一化 (LayerNorm)
LayerNorm 用于稳定训练过程。它通过对每个样本的特定维度进行标准化(均值为 0,方差为 1)来实现。在 ViT 中,我们通常在注意力机制之前应用 LayerNorm。
self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
4.2 多头自注意力 (MSA)
自注意力机制允许序列中的每个 Token 与其他所有 Token 进行交互。多头机制意味着我们将特征维度切分为多个'头',每个头独立学习不同的子空间特征,最后拼接起来。
- Query, Key, Value: 通过线性变换生成 Q, K, V。
- Attention Scores: 计算 $QK^T / \sqrt{d_k}$。
- Softmax: 归一化为概率分布。
- Weighted Sum: 用权重乘以 V。
class MyMSA(nn.Module):
def __init__(self, d, n_heads=2):
super(MyMSA, self).__init__()
self.d = d
self.n_heads = n_heads
assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
self.d_head = int(d / n_heads)
self.q_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.k_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.v_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.softmax = nn.Softmax(dim=-1)
def forward(self, sequences):
result = []
for sequence in sequences:
seq_result = []
for head in range(self.n_heads):
q_mapping = self.q_mappings[head]
k_mapping = self.k_mappings[head]
v_mapping = self.v_mappings[head]
seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
seq_result.append(attention @ v)
result.append(torch.hstack(seq_result))
return torch.cat([torch.unsqueeze(r, dim=0) for r in result])
注:上述代码为了清晰展示了逐样本处理逻辑,实际工程中常使用 batched operations 以提升效率。
4.3 残差连接与前馈网络 (MLP)
在 MSA 之后,再次应用 LayerNorm 和 MLP。MLP 通常包含两个线性层和一个激活函数(如 GELU 或 ReLU),用于增加模型的非线性表达能力。
self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
self.enc_mlp = nn.Sequential(
nn.Linear(self.hidden_d, self.hidden_d),
nn.ReLU()
)
完整的 Encoder Block 前向传播如下:
out = tokens + self.msa(self.ln1(tokens))
out = out + self.enc_mlp(self.ln2(out))
5. 分类头 (Classification Head)
经过 Encoder 处理后,我们只需要关注 Class Token 的输出。将其取出,通过一个线性层映射到类别数量,并使用 Softmax 得到概率分布。
self.mlp = nn.Sequential(
nn.Linear(self.hidden_d, out_d),
nn.Softmax(dim=-1)
)
def forward(self, images):
n, c, w, h = images.shape
patches = images.reshape(n, self.n_patches ** 2, self.input_d)
tokens = self.linear_mapper(patches)
tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
out = tokens + self.msa(self.ln1(tokens))
out = out + self.enc_mlp(self.ln2(out))
out = out[:, 0]
return self.mlp(out)
6. 测试与验证
我们可以构建一个简单的测试脚本来验证维度变化是否符合预期。
input_shape = (1, 28, 28)
input_image_shape = (1, 1, 28, 28)
model = MyViT(input_shape=input_shape)
images = torch.rand(input_image_shape)
output = model(images)
print(f"Output shape: {output.shape}")
运行结果应显示输出张量形状为 (N, 10),代表每张图像属于 10 个类别的概率分布。
总结
Vision Transformer 通过将图像分解为 Patch 序列,成功地将 Transformer 架构引入计算机视觉领域。其核心优势在于全局感受野,能够捕捉长距离依赖关系。尽管 CNN 在局部特征提取上依然高效,但 ViT 在大规模数据集上的表现证明了其强大的泛化能力。本教程详细拆解了从 Patch Embedding、位置编码、多头注意力到分类头的全过程,为深入理解现代视觉模型奠定了坚实基础。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online