Vision Transformer 全面代码解析
Vision Transformer 全面代码解析
原创 大厂小僧 2024年08月17日 14:25北京
在这篇文章中,我们将深入探讨Vision Transformer(ViT)的代码实现细节。对于那些对ViT理论基础还不太熟悉的朋友,我推荐你们观看李宏毅教授的相关视频,或者阅读下面推荐的几篇博客。在这里,我将主要专注于代码层面的解析,不再赘述理论部分。

1. 代码解析
先看下VIT的网络结构,如下图:

从上方的架构图中,我们可以清晰地看到Vision Transformer(ViT)由三个核心组件构成:Patch Embedding、Transformer Encoder以及MLP Head。接下来,我们将逐步剖析这些组件的代码实现。首先,我们会理解每个单独模块的功能,随后整合这些模块,深入了解整个ViT架构的工作原理。
1.1 DropPath 模块
def drop_path(x, drop_prob: float = 0., training: bool = False):在介绍Encoder Block中的DropPath之前,让我们先澄清一点关于DropPath的描述。实际上,DropPath并不是“删除”整个分支,而是随机地使特定路径上的特征值变为零,从而在前向传播过程中有效地“丢弃”那些路径。与Dropout不同,Dropout是随机使神经元的输出变为零,而DropPath则是针对网络中的路径进行操作。
DropPath通常只在训练阶段使用,而在验证和测试阶段则不应用。这种做法与Dropout类似,其目的是为了在训练过程中引入噪声,以提高模型的泛化能力。
在训练阶段,DropPath通过随机“丢弃”网络中的某些路径来工作,这意味着在前向传播过程中,这些路径的输出会被设置为零。这样做有助于防止模型过于依赖特定的路径或特征,从而增强了模型的鲁棒性和泛化性能。
然而,在验证和测试阶段,我们希望模型能够利用所有可用的信息来进行预测,因此不应用DropPath。在这些阶段,模型的输出是基于所有路径的贡献,而不是被随机“丢弃”了一些路径的情况。
总结来说,DropPath是一种正则化技术,仅在训练阶段使用,以提高模型的泛化能力。在验证和测试阶段,则不应用DropPath,以确保模型能够充分利用所有路径进行准确的预测。
1.2 Patch Embeding
class PatchEmbed(nn.Module):当我们输入一个大小为(1,3,224,224)的图像矩阵到Patch Embeding模块后,输出的结果大小为[1, 196, 768]。这意味着原始的二维图像矩阵已经被转换成了一系列的一维向量。这样的转换是必要的,因为Transformer模型只能处理一维的序列数据。通过这种方式,我们可以将图像数据建模为序列数据,进而利用Transformer的强大能力进行进一步的处理和分析。
1.3.Multi-Head Attention
class Attention(nn.Module):Vision Transformer (ViT) 的核心确实是 Transformer 架构中的注意力机制。具体来说,ViT 使用了多头自注意力 (Multi-Head Self-Attention, MHSA) 机制,这是 Transformer 模型中最关键的部分之一。
在 ViT 中,图像被分割成一系列的 patches,并且每个 patch 被展平并映射到一个高维空间中。然后,这些 patches 被视为序列输入到 Transformer 中。多头自注意力机制能够捕捉到这些 patches 之间的复杂关系,从而实现对图像的有效建模。
注意力机制允许模型在处理输入序列时,关注到最重要的部分,而多头自注意力则通过多个独立的注意力头来同时关注不同的特征子空间,提高了模型的表达能力。每个注意力头都会计算键 (keys)、查询 (queries) 和值 (values),并根据它们之间的相似性分配权重,从而生成加权和作为输出。
1.4.MLP
除了注意力机制之外,Transformer 还包括了前馈神经网络 (Feed-Forward Networks, FFN) 和层归一化 (Layer Normalization) 等组件,这些组件共同作用,使得 ViT 能够在图像识别和其他计算机视觉任务中表现出色。

class Mlp(nn.Module):1.5. Encoder Block
上面我们简单浏览了MLP的代码实现,由于其相对简单,我们不再做过多解释。接下来,我们将重点分析“Encodef Block”这一模块,它是构建Transformer Encoder的关键单元。ViT Base通过将Block模块重复堆叠12次,我们便能得到完整的Encoder结构。下面是Block模块的详细内容:

class Block(nn.Module):1.6.VisionTransformer
在完成了所有必要模块的创建之后,我们现在要做的就是将它们组合起来,构建我们的VisionTransformer模型。以下是将这些模块整合在一起的代码实现:
class VisionTransformer(nn.Module):2. 构建VIT模型
虽然我们已经完成了VisionTransformer的所有代码分析和搭建过程,但为了让模型更加易于使用和调用,我们还需要对其进行进一步的封装。下面是封装后的代码实现:
def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):从上面的代码可以看到,总共5个模型,从上到下复杂的依次递增,上面介绍了一vit_base_patch16_224_in21k这个模型的创建配置参数,其他模型的参数大同小异。
3. 完整代码
"""至此,我们已经全面分析了VIT框架的代码实现。非常欢迎各位专家提出宝贵的意见和建议。