一、Swin Transformer 的基础原理
1. Transformer 在视觉任务中的挑战
Transformer 最初是为自然语言处理设计的,通过自注意力机制捕捉序列中的长程依赖关系。然而,将其直接应用于视觉任务时,会遇到以下问题:
- 计算复杂度过高:自注意力机制的计算复杂度为 $O(N^2)$,其中 $N$ 是序列长度。对于高分辨率图像,像素数量巨大,导致计算量不可接受。
- 局部特征提取不足:图像具有天然的局部相关性,而 Transformer 的全局自注意力机制无法像 CNN 那样高效捕捉局部模式。
为了解决这些问题,Swin Transformer 引入了窗口化自注意力和移位窗口机制,在保持 Transformer 优势的同时,显著降低了计算复杂度并增强了局部建模能力。
2. Swin Transformer 的核心思想
Swin Transformer 的核心在于将图像划分为不重叠的局部窗口,并在窗口内计算自注意力,从而将计算复杂度从全局的二次方降低为线性的。同时,通过移位窗口(Shifted Window)机制,在不同层之间建立跨窗口的连接,确保信息在全局范围内的流动。
主要特点包括:
- 层级结构:采用类似 CNN 的层级设计,逐步降低分辨率并增加通道数。
- 线性复杂度:基于窗口划分,计算复杂度随图像尺寸线性增长。
- 通用性:可作为骨干网络用于分类、检测、分割等多种任务。
二、Swin Transformer 的架构
1. Patch Embedding
将输入图像划分为固定大小的非重叠补丁(Patch),并将每个补丁展平为一个向量,通过线性投影映射到嵌入空间。这一步类似于 CNN 中的卷积操作,但使用的是全连接层。
2. 位置编码
由于 Transformer 本身不具备位置感知能力,Swin Transformer 使用相对位置偏置(Relative Position Bias)来编码窗口内的相对位置信息,帮助模型理解空间结构。
3. Swin Transformer Block
每个 Block 包含两个主要组件:窗口多头自注意力(W-MSA)和移位窗口多头自注意力(SW-MSA),中间穿插多层归一化(LayerNorm)和前馈神经网络(FFN)。
4. 窗口化自注意力(W-MSA)
在每个局部窗口内独立计算自注意力。假设窗口大小为 $M \times M$,则计算复杂度为 $O(M^2 \cdot C)$,与图像总尺寸无关,仅取决于窗口大小。
5. 移位窗口自注意力(SW-MSA)
为了打破窗口间的隔离,将窗口进行移位操作,使得相邻窗口的元素能够交互。这允许信息在深层网络中传播,实现全局感受野。
6. Patch Merging
在下采样阶段,将相邻的 $2 \times 2$ 个补丁合并为一个,减少特征图的空间分辨率,同时增加通道维度。这有助于构建层级特征表示。
7. 整体架构
整体架构由多个阶段组成,每个阶段包含若干个 Swin Transformer Block。随着层数加深,特征图分辨率降低,通道数增加,最终通过全局平均池化和全连接层输出分类结果。
三、代码实现详解
1. ShiftWindowAttentionBlock
1.1 初始化
定义窗口大小、注意力头数、隐藏维度等超参数。初始化相对位置偏置表。


