Vision Transformer 全面代码解析

Vision Transformer 全面代码解析

Vision Transformer 全面代码解析

原创 大厂小僧  2024年08月17日 14:25北京

在这篇文章中,我们将深入探讨Vision Transformer(ViT)的代码实现细节。对于那些对ViT理论基础还不太熟悉的朋友,我推荐你们观看李宏毅教授的相关视频,或者阅读下面推荐的几篇博客。在这里,我将主要专注于代码层面的解析,不再赘述理论部分。

www.zeeklog.com - Vision Transformer 全面代码解析

1. 代码解析

先看下VIT的网络结构,如下图:

www.zeeklog.com - Vision Transformer 全面代码解析

从上方的架构图中,我们可以清晰地看到Vision Transformer(ViT)由三个核心组件构成:Patch Embedding、Transformer Encoder以及MLP Head。接下来,我们将逐步剖析这些组件的代码实现。首先,我们会理解每个单独模块的功能,随后整合这些模块,深入了解整个ViT架构的工作原理。

1.1 DropPath 模块

def drop_path(x, drop_prob: float = 0., training: bool = False):

在介绍Encoder Block中的DropPath之前,让我们先澄清一点关于DropPath的描述。实际上,DropPath并不是“删除”整个分支,而是随机地使特定路径上的特征值变为零从而在前向传播过程中有效地“丢弃”那些路径。与Dropout不同,Dropout是随机使神经元的输出变为零,而DropPath则是针对网络中的路径进行操作。

DropPath通常只在训练阶段使用,而在验证和测试阶段则不应用。这种做法与Dropout类似,其目的是为了在训练过程中引入噪声,以提高模型的泛化能力。

在训练阶段,DropPath通过随机“丢弃”网络中的某些路径来工作,这意味着在前向传播过程中,这些路径的输出会被设置为零。这样做有助于防止模型过于依赖特定的路径或特征,从而增强了模型的鲁棒性和泛化性能。

然而,在验证和测试阶段,我们希望模型能够利用所有可用的信息来进行预测,因此不应用DropPath。在这些阶段,模型的输出是基于所有路径的贡献,而不是被随机“丢弃”了一些路径的情况。

总结来说,DropPath是一种正则化技术,仅在训练阶段使用,以提高模型的泛化能力。在验证和测试阶段,则不应用DropPath,以确保模型能够充分利用所有路径进行准确的预测。

1.2 Patch Embeding

class PatchEmbed(nn.Module):

当我们输入一个大小为(1,3,224,224)的图像矩阵到Patch Embeding模块后,输出的结果大小为[1, 196, 768]。这意味着原始的二维图像矩阵已经被转换成了一系列的一维向量。这样的转换是必要的,因为Transformer模型只能处理一维的序列数据。通过这种方式,我们可以将图像数据建模为序列数据,进而利用Transformer的强大能力进行进一步的处理和分析。

1.3.Multi-Head Attention

class Attention(nn.Module):

Vision Transformer (ViT) 的核心确实是 Transformer 架构中的注意力机制。具体来说,ViT 使用了多头自注意力 (Multi-Head Self-Attention, MHSA) 机制,这是 Transformer 模型中最关键的部分之一。

在 ViT 中,图像被分割成一系列的 patches,并且每个 patch 被展平并映射到一个高维空间中。然后,这些 patches 被视为序列输入到 Transformer 中。多头自注意力机制能够捕捉到这些 patches 之间的复杂关系,从而实现对图像的有效建模。

注意力机制允许模型在处理输入序列时,关注到最重要的部分,而多头自注意力则通过多个独立的注意力头来同时关注不同的特征子空间,提高了模型的表达能力。每个注意力头都会计算键 (keys)、查询 (queries) 和值 (values),并根据它们之间的相似性分配权重,从而生成加权和作为输出。

1.4.MLP

除了注意力机制之外,Transformer 还包括了前馈神经网络 (Feed-Forward Networks, FFN) 和层归一化 (Layer Normalization) 等组件,这些组件共同作用,使得 ViT 能够在图像识别和其他计算机视觉任务中表现出色。

www.zeeklog.com - Vision Transformer 全面代码解析
class Mlp(nn.Module):

1.5. Encoder Block

上面我们简单浏览了MLP的代码实现,由于其相对简单,我们不再做过多解释。接下来,我们将重点分析“Encodef Block”这一模块,它是构建Transformer Encoder的关键单元。ViT Base通过将Block模块重复堆叠12次,我们便能得到完整的Encoder结构。下面是Block模块的详细内容:

www.zeeklog.com - Vision Transformer 全面代码解析
class Block(nn.Module):

1.6.VisionTransformer

在完成了所有必要模块的创建之后,我们现在要做的就是将它们组合起来,构建我们的VisionTransformer模型。以下是将这些模块整合在一起的代码实现:

class VisionTransformer(nn.Module):

2. 构建VIT模型

虽然我们已经完成了VisionTransformer的所有代码分析和搭建过程,但为了让模型更加易于使用和调用,我们还需要对其进行进一步的封装。下面是封装后的代码实现:

def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):

从上面的代码可以看到,总共5个模型,从上到下复杂的依次递增,上面介绍了一vit_base_patch16_224_in21k这个模型的创建配置参数,其他模型的参数大同小异。

3. 完整代码

"""

至此,我们已经全面分析了VIT框架的代码实现。非常欢迎各位专家提出宝贵的意见和建议。

Read more

【OpenClaw从入门到精通】第10篇:OpenClaw生产环境部署全攻略:性能优化+安全加固+监控运维(2026实测版)

【OpenClaw从入门到精通】第10篇:OpenClaw生产环境部署全攻略:性能优化+安全加固+监控运维(2026实测版)

摘要:本文聚焦OpenClaw从测试环境走向生产环境的核心痛点,围绕“性能优化、安全加固、监控运维”三大维度展开实操讲解。先明确生产环境硬件/系统选型标准,再通过硬件层资源管控、模型调度策略、缓存优化等手段提升响应速度(实测响应效率提升50%+);接着从网络、权限、数据三层构建安全防护体系,集成火山引擎安全方案拦截高危操作;最后落地TenacitOS可视化监控与Prometheus告警体系,配套完整故障排查清单和虚拟实战案例。全文所有配置、代码均经实测验证,兼顾新手入门实操性和进阶读者的生产级部署需求,帮助开发者真正实现OpenClaw从“能用”到“放心用”的跨越。 优质专栏欢迎订阅! 【DeepSeek深度应用】【Python高阶开发:AI自动化与数据工程实战】【YOLOv11工业级实战】 【机器视觉:C# + HALCON】【大模型微调实战:平民级微调技术全解】 【人工智能之深度学习】【AI 赋能:Python 人工智能应用实战】【数字孪生与仿真技术实战指南】 【AI工程化落地与YOLOv8/v9实战】【C#工业上位机高级应用:高并发通信+性能优化】 【Java生产级避坑指南:

By Ne0inhk
ARM Linux 驱动开发篇--- Linux 并发与竞争实验(互斥体实现 LED 设备互斥访问)--- Ubuntu20.04互斥体实验

ARM Linux 驱动开发篇--- Linux 并发与竞争实验(互斥体实现 LED 设备互斥访问)--- Ubuntu20.04互斥体实验

🎬 渡水无言:个人主页渡水无言 ❄专栏传送门: 《linux专栏》《嵌入式linux驱动开发》《linux系统移植专栏》 ❄专栏传送门: 《freertos专栏》《STM32 HAL库专栏》 ⭐️流水不争先,争的是滔滔不绝  📚博主简介:第二十届中国研究生电子设计竞赛全国二等奖 |国家奖学金 | 省级三好学生 | 省级优秀毕业生获得者 | ZEEKLOG新星杯TOP18 | 半导纵横专栏博主 | 211在读研究生 在这里主要分享自己学习的linux嵌入式领域知识;有分享错误或者不足的地方欢迎大佬指导,也欢迎各位大佬互相三连 目录 前言  一、实验基础说明 1.1、互斥体简介 1.2 本次实验设计思路 二、硬件原理分析(看过之前博客的可以忽略) 三、实验程序编写 3.1 互斥体 LED 驱动代码(mutex.c) 3.2.1、设备结构体定义(28-39

By Ne0inhk
Flutter for OpenHarmony:swagger_dart_code_generator 接口代码自动化生成的救星(OpenAPI/Swagger) 深度解析与鸿蒙适配指南

Flutter for OpenHarmony:swagger_dart_code_generator 接口代码自动化生成的救星(OpenAPI/Swagger) 深度解析与鸿蒙适配指南

欢迎加入开源鸿蒙跨平台社区:https://openharmonycrossplatform.ZEEKLOG.net 前言 后端工程师扔给你一个 Swagger (OpenAPI) 文档地址,你会怎么做? 1. 对着文档,手写 Dart Model 类(容易写错字段类型)。 2. 手写 Retrofit/Dio 的 API 接口定义(容易拼错 URL)。 3. 当后端修改了字段名,你对着报错修半天。 这是重复劳动的地狱。 swagger_dart_code_generator 可以将 Swagger (JSON/YAML) 文件直接转换为高质量的 Dart 代码,包括: * Model 类:支持 json_serializable,带 fromJson/

By Ne0inhk
Linux 开发别再卡壳!makefile/git/gdb 全流程实操 + 作业解析,新手看完直接用----《Hello Linux!》(5)

Linux 开发别再卡壳!makefile/git/gdb 全流程实操 + 作业解析,新手看完直接用----《Hello Linux!》(5)

文章目录 * 前言 * make/makefile * 文件的三个时间 * Linux第一个小程序-进度条 * 回车和换行 * 缓冲区 * 程序的代码展示 * git指令 * 关于gitee * Linux调试器-gdb使用 * 作业部分 前言 做 Linux 开发时,你是不是也遇到过这些 “卡脖子” 时刻?写 makefile 时,明明语法没错却报错,最后发现是依赖方法行没加 Tab;想提交代码到 gitee,记不清 git add/commit/push 的 “三板斧”,还得反复搜教程;用 gdb 调试程序,输了命令没反应,才想起编译时没加-g生成 debug 版本;甚至连写个进度条,都搞不懂\r和\n的区别,导致进度条乱跳…… 其实这些问题,

By Ne0inhk