从零开始手写 Vision Transformer 实现图像分类任务
本文旨在通过 PyTorch 框架,从零开始逐步搭建 Vision Transformer (ViT) 模型,并应用于图像分类任务。Transformer 最初是为自然语言处理(NLP)设计的序列化数据处理模型,而 ViT 的核心思想是将图像视为序列化的 Patch(图块),从而利用 Transformer 强大的自注意力机制进行特征提取。
1. 图像序列化与 Patch Embedding
Transformer 无法直接处理二维图像数据,因此第一步是将图像转换为序列。我们将输入图像切割成大小相等的 Patch 子图片,每个子图片被视为一个 Token。
假设输入图像尺寸为 $H \times W$,Patch 大小为 $P \times P$。若图像能被整除,则得到 $N = (H/P) \times (W/P)$ 个 Patch。每个 Patch 被展平为一维向量,并通过线性映射(Linear Projection)投影到隐藏维度 $D$。
代码实现:Patch 分割与线性映射
import torch
import torch.nn as nn
class MyViT(nn.Module):
def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
super(MyViT, self).__init__()
self.input_shape = input_shape
self.n_patches = n_patches
self.n_heads = n_heads
self.hidden_d = hidden_d
self.out_d = out_d
assert input_shape[1] % n_patches == 0, "Input height not divisible by n_patches"
assert input_shape[2] % n_patches == 0, "Input width not divisible by n_patches"
self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
在此步骤中,我们不仅完成了空间维度的转换,还通过全连接层实现了特征的初步抽象。如果有多个颜色通道(如 RGB),它们也会被展平并包含在输入向量中。
2. 添加分类标记 (Class Token)
为了完成分类任务,我们需要一个特殊的标记来聚合整个图像的信息。类似于 NLP 中的 [CLS] 标记,我们在序列的开头添加一个可学习的参数向量 class_token。
经过线性映射后,我们的张量形状为 $(N, 49, 8)$。添加 Class Token 后,形状变为 $(N, 50, 8)$,其中第一个位置是分类标记,后续 49 个位置是图像 Patch 的特征。
self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
在 forward 过程中,我们将这个 Class Token 堆叠到每个样本的序列头部。
3. 位置编码 (Positional Encoding)
Transformer 本身不具备处理序列顺序的能力(因为它没有循环或卷积结构)。因此,必须注入位置信息。ViT 通常使用正弦和余弦波函数生成的固定位置编码,或者可学习的位置嵌入。
这里我们采用经典的正弦/余弦位置编码公式:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$$
这确保了每个位置都有独特的编码,且不同位置之间具有相对距离的可预测性。
def get_positional_embeddings(sequence_length, d):
result = torch.ones(sequence_length, d)
for i in range(sequence_length):
for j in range(d):
if j % 2 == 0:
result[i][j] = torch.sin(torch.tensor(i) / (10000 ** (j / d)))
else:
result[i][j] = torch.cos(torch.tensor(i) / (10000 ** ((j - 1) / d)))
return result
在模型前向传播中,我们将位置编码加到 Token 上:
tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
注意:Class Token 通常不添加位置编码,但在简化实现中,为了方便矩阵运算,有时也会统一加上。严谨的做法是只对 Patch Token 加位置编码,但此处为了保持维度一致性和教学简洁性,对整体序列进行了叠加。
4. Transformer Encoder 核心组件
Transformer Encoder 由多层相同的 Block 组成,每个 Block 包含两个主要部分:多头自注意力机制 (Multi-Head Self Attention, MSA) 和前馈神经网络 (Feed Forward Network, FFN),两者之间都采用了残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。
4.1 层归一化 (LayerNorm)
LayerNorm 用于稳定训练过程。它通过对每个样本的特定维度进行标准化(均值为 0,方差为 1)来实现。在 ViT 中,我们通常在注意力机制之前应用 LayerNorm。
self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
4.2 多头自注意力 (MSA)
自注意力机制允许序列中的每个 Token 与其他所有 Token 进行交互。多头机制意味着我们将特征维度切分为多个'头',每个头独立学习不同的子空间特征,最后拼接起来。
对于单个头,计算流程如下:
- Query, Key, Value: 通过线性变换生成 Q, K, V。
- Attention Scores: 计算 $QK^T / \sqrt{d_k}$。
- Softmax: 归一化为概率分布。
- Weighted Sum: 用权重乘以 V。
class MyMSA(nn.Module):
def __init__(self, d, n_heads=2):
super(MyMSA, self).__init__()
self.d = d
self.n_heads = n_heads
assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
self.d_head = int(d / n_heads)
self.q_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.k_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.v_mappings = nn.ModuleList([nn.Linear(self.d_head, self.d_head) for _ in range(self.n_heads)])
self.softmax = nn.Softmax(dim=-1)
def forward(self, sequences):
result = []
for sequence in sequences:
seq_result = []
head (.n_heads):
q_mapping = .q_mappings[head]
k_mapping = .k_mappings[head]
v_mapping = .v_mappings[head]
seq = sequence[:, head * .d_head: (head + ) * .d_head]
q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
attention = .softmax(q @ k.T / (.d_head ** ))
seq_result.append(attention @ v)
result.append(torch.hstack(seq_result))
torch.cat([torch.unsqueeze(r, dim=) r result])
注:上述代码为了清晰展示了逐样本处理逻辑,实际工程中常使用 batched operations 以提升效率。
4.3 残差连接与前馈网络 (MLP)
在 MSA 之后,再次应用 LayerNorm 和 MLP。MLP 通常包含两个线性层和一个激活函数(如 GELU 或 ReLU),用于增加模型的非线性表达能力。
self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))
self.enc_mlp = nn.Sequential(
nn.Linear(self.hidden_d, self.hidden_d),
nn.ReLU()
)
完整的 Encoder Block 前向传播如下:
out = tokens + self.msa(self.ln1(tokens))
out = out + self.enc_mlp(self.ln2(out))
5. 分类头 (Classification Head)
经过 Encoder 处理后,我们只需要关注 Class Token 的输出。将其取出,通过一个线性层映射到类别数量,并使用 Softmax 得到概率分布。
self.mlp = nn.Sequential(
nn.Linear(self.hidden_d, out_d),
nn.Softmax(dim=-1)
)
完整的前向传播逻辑整合:
def forward(self, images):
n, c, w, h = images.shape
patches = images.reshape(n, self.n_patches ** 2, self.input_d)
tokens = self.linear_mapper(patches)
tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
out = tokens + self.msa(self.ln1(tokens))
out = out + self.enc_mlp(self.ln2(out))
out = out[:, 0]
return self.mlp(out)
6. 测试与验证
我们可以构建一个简单的测试脚本来验证维度变化是否符合预期。
input_shape = (1, 28, 28)
input_image_shape = (1, 1, 28, 28)
model = MyViT(input_shape=input_shape)
images = torch.rand(input_image_shape)
output = model(images)
print(f"Output shape: {output.shape}")
运行结果应显示输出张量形状为 (N, 10),代表每张图像属于 10 个类别的概率分布。
总结
Vision Transformer 通过将图像分解为 Patch 序列,成功地将 Transformer 架构引入计算机视觉领域。其核心优势在于全局感受野,能够捕捉长距离依赖关系。尽管 CNN 在局部特征提取上依然高效,但 ViT 在大规模数据集上的表现证明了其强大的泛化能力。本教程详细拆解了从 Patch Embedding、位置编码、多头注意力到分类头的全过程,为深入理解现代视觉模型奠定了坚实基础。