DiT(Diffusion Transformer)详解:架构与核心模块分析
DiT 是基于 Transformer 架构的扩散模型,替代了传统 U-Net。它通过 Patchify 和位置编码处理输入,利用 AdaLN-Zero 模块高效注入条件信息(时间步、类别)。相比 U-ViT,DiT 采用纯 Transformer 堆叠,适合高分辨率潜在空间生成。论文验证了其 Scaling 能力,在 ImageNet 等任务上达到 SOTA。

DiT 是基于 Transformer 架构的扩散模型,替代了传统 U-Net。它通过 Patchify 和位置编码处理输入,利用 AdaLN-Zero 模块高效注入条件信息(时间步、类别)。相比 U-ViT,DiT 采用纯 Transformer 堆叠,适合高分辨率潜在空间生成。论文验证了其 Scaling 能力,在 ImageNet 等任务上达到 SOTA。


Scalable Diffusion Models with Transformers
DiT 是基于 Transformer 架构的扩散模型。用于各种图像(SD3、FLUX 等)和视频(Sora 等)视觉生成任务。
DiT 证明了 Transformer 思想与扩散模型结合的有效性,并且验证了 Transformer 架构在扩散模型上具备较强的 Scaling 能力。在稳步增大 DiT 模型参数量与增强数据质量时,DiT 的生成性能稳步提升。
其中最大的 DiT-XL/2 模型在 ImageNet 256x256 的类别条件生成上达到了当时的 SOTA(State Of The Art)性能(FID 为 2.27)。同时在 SD3 和 FLUX.1 中也说明了较强的 Scaling 能力。
DiT 架构如下所示:

图 3.扩散 Transformer(DiT)架构。左:我们训练条件潜在 DiT 模型。输入的潜在被分解成补丁和处理的几个 DiT 块。右图:DiT 区块的详细信息。我们用标准 Transformer 块的变体进行了实验,这些块通过自适应层归一化、交叉注意和额外输入的令牌(上下文环境)来进行调节。自适应层规范效果最好。
下文将按照这个架构进行阐述,从左到右。
TODO:【也有说在传统的 U-Net 扩散模型(SD)中,所采用的 noise scheduler 是带调优参数后的线性调度器(Linear Scheduler)。】
在图像领域使用 Transformer,首先想到的模型就是 ViT(参考:万字长文解读深度学习——ViT、ViLT),和 ViT 一样,DiT 也需要经过 Patch 和位置编码,如下图红框。

DiT 和 ViT 一样,首先采用一个 Patch Embedding 来将输入图像 Patch 化,主要作用是将 VAE 编码后的二维特征转化为一维序列,从而得到一系列的图像 tokens,ViT 具体如下图所示:

DiT 在这个图像 Patch 化的过程中,设计了 patch size 超参数,它直接决定了图像 tokens 的大小和数量,从而影响 DiT 模型的整体计算量。DiT 论文中共设置了三种 patch size,分别是 2, 4, 8。patch size 为 2*2 是最理想的。(结论来自:视频生成 Sora 的全面解析:从 AI 绘画、ViT 到 ViViT、TECO、DiT、VDT、NaViT 等)
Latent Diffusion Transformer 结构中,输入的图像在经过 VAE 编码器处理后,生成一个 Latent 特征,Patchify 的目的是将 Latent 特征转换成一系列 T 个 token(将 Latent 特征进行 Patch 化),每个 token 的维度为 d。Patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。具体流程如下图所示:

图 4. DiT 的输入规格。给定 patch size 是 p × p,空间表示(来自 VAE 的加噪潜变量),其形状为 I × I × C,会被'划分成补丁'(patchified)为一个长度为 T = (I/p)^2 的序列,隐藏维度为 d。较小的补丁大小 p 会导致序列长度更长,因此需要更多的计算量(以 Gflops 表示)。
在执行 patchify 之后,我们对所有输入 token 应用标准的 ViT 基于频率的位置嵌入(正弦 - 余弦版本)。图像 tokens 后,还要加上 Positional Embeddings 进行位置标记,DiT 中采用经典的非学习 sin&cosine 位置编码技术。
ViT(vision transformer)采用的是 2D Frequency Embeddings(两个 1D Frequency Embeddings 进行 concat 操作),详情请参考:深度学习——3 种常见的 Transformer 位置编码【sin/cos、基于频率的二维位置编码(2D Frequency Embeddings)、RoPE】
DiT 在完成输入图像的预处理后,就要将 Latent 特征输入到 DiT Block 中进行特征的提取了,与 ViT 不同的是,DiT 作为扩散模型还需要在 Backbone(主干)网络中嵌入额外的条件信息(不同模态的条件信息等),这里的条件信息就包括了 Timesteps 以及类别标签(文本信息)。
DiT 中的 Backbone 网络进行了两个主要工作:
额外信息都可以采用一个 Embedding 来进行编码,从而注入 DiT 中。DiT 论文中为了增强特征融合的性能,一共设计了四种【三种】方案来实现两个额外 Embeddings 的嵌入 (说白了,就是怎么加入 conditioning)。实现方式如下图模型架构的后半部分(红框):
【有的文章表示设计了四种,其实是将 AdaLN 和 AdaLN-Zero 分为两种,这里按照论文中图进行解释,分为三种】

Diffusion Transformer 模型架构图中由右到左的顺序分别是:
下面将按顺序详细介绍。
实现机制:
这与 ViT 中的 cls tokens 类似,它允许我们无需修改就使用标准的 ViT 模块。在最后一个模块之后,我们从序列中移除条件化 tokens。这种方法对模型的新 Gflops 增加可以忽略不计。
实现机制:
这种方式是 Stable Diffusion 等文生图大模型常用的方式,交叉注意力对模型的 Gflops 增加最多,大约增加了 15% 的开销。
首先需要了解什么是 Adaptive Layer Normalization(AdaLN),而 AdaLN 之前又要先知道 LN,下面将一步步优化讲解:
首先在理解 AdaLN 之前,我们先简单回顾一下 Layer Normalization。
其他归一化方法参考:深度学习——优化算法、激活函数、归一化、正则化
Layer Normalization 的处理步骤主要分成下面的三步:
LN 的公式:
LN(x) = γ · (x - μ) / σ + β
其中:
在条件生成任务(如扩散模型)中,需要让条件信息(如时间步 (t) 或类别 (c))对生成过程产生影响。为此,传统的 LN 被改进为自适应层归一化(adaLN),可以动态调整归一化的参数以包含条件信息。
在 GANs 和具有 UNet 骨干的扩散模型中广泛使用自适应归一化层之后,探索用自适应层归一化(adaLN)替换 transformer 模块中的标准归一化层。adaLN 并不是直接学习维度规模的缩放和偏移参数 γ 和 β,而是从 t 和 c 的嵌入向量之和中回归得到它们。
AdaLN 的核心思想是根据输入的不同条件信息,自适应地调整 Layer Normalization 的缩放参数 γ 和偏移参数 β,增加的 Gflops 非常少,适合大规模任务。
adaLN 相比 LN 的改进:
AdaLN(x, c) = γ_ada · (x - μ) / (σ + ϵ) + β_ada
AdaLN 的核心步骤包括以下三步【详细的步骤会在下一节 adaLN-Zero 的核心步骤总结】:
γ_ada = f_γ(c), β_ada = f_β(c)AdaLN(x, c) = γ_ada · (x - μ) / (σ + ϵ) + β_ada在我们探索的三种模块设计中,adaLN 增加的 Gflops 最少,因此是最计算高效的。它也是唯一一个限制对所有 tokens 应用相同函数的条件化机制。
AdaLN-Zero 在 AdaLN 的基础上新增了残差缩放参数 α,用于动态控制残差路径的影响。通过将 α 初始化为零,模型的初始状态被设置为恒等函数,从而确保输出在训练初期的稳定性。这种设计显著提升了模型的训练稳定性和收敛速度。
之前的 ResNets 工作发现,将每个残差块初始化为恒等函数是有益的。例如,在监督学习环境中,将每个块中最后的批量归一化缩放因子 γ 零初始化可以加速大规模训练。 扩散 U-Net 模型使用了类似的初始化策略,在任何残差连接之前零初始化每个块中的最终卷积层。 我们探索了对 adaLN DiT 模块的修改,它做了同样的事情。除了回归 γ 和 β,我们还回归了在 DiT 模块内的任何残差连接之前作用的 dimension-wise 的缩放参数 α。初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。与标准的 adaLN 模块一样,adaLNZero 对模型的 Gflops 增加可以忽略不计。
下面将根据下图来阐述:

adaLN-Zero 的核心步骤包括以下三步,和 adaLN 的步骤相似,只不过需要在过程中加入维度缩放参数 α。
需要做的额外工作如下:在第一步中提取回归缩放参数 α 在第二步中生成自适应的缩放参数 α 在第三步中使用 α 残差路径进行控制
提取条件信息、缩放参数 α:从输入的条件(如 Text Embeddings、标签等)中提取信息,一般来说会专门使用一个神经网络模块(比如全连接层等)来处理输入条件,并生成与输入数据相对应的缩放和偏移参数。 在 DiT 的官方实现中,使用了一个全连接层+SiLU 激活函数来实现这样一个输入条件的特征提取网络:
# 输入条件的特征提取网络
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6* hidden_size, bias=True)
)
# c 代表输入的条件信息
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
同时,DiT 在每个残差模块之后还使用了一个回归缩放参数 α 来对权重进行缩放调整,这个 α 参数也是由上述条件特征提取网络提取的。上面的代码示例中和上图(DiT Block with adaLN-Zero)我们可以看到,adaLN_modulation 计算了 6 个变量 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,这 6 个变量分别对应了多头自注意力机制 (MSA) 的 AdaLN 的归一化参数与缩放参数(下图中的 β_1, γ_1, α_1)以及 MLP 模块的 AdaLN 的归一化参数与缩放参数(下图中的 β_2, γ_2, α_2)。
【在 DiT Block 中,MSA(多头自注意力模块)和 MLP(多层感知机模块)都需要分别进行一次 adaLN-Zero 归一化处理。每个模块的 Layer Normalization(LN)都会被替换为 adaLN-Zero,并且两者的归一化参数 (γ, β) 和残差路径缩放参数 (α) 是独立的(需要分别提取),具体如下:】
Transformer 模块由 MSA 和 MLP 两部分组成,而它们在功能上的分工导致必须对每部分单独设计对应的条件化机制。这是因为:MSA 的核心在于捕捉全局依赖关系,因此其动态参数 (β_1, γ_1, α_1) 主要控制全局特征的动态调整。MLP 的核心在于非线性特征提取,增强局部特征表达,因此其动态参数 (β_2, γ_2, α_2) 主要用于控制局部特征的动态变换。
γ_ada = f_γ(c), β_ada = f_β(c)
为残差路径增加一个新的维度缩放参数 α,由条件信息动态生成:
α_ada = f_α(c)
初始化为零:在训练开始时,α_ada = 0,使得模块输出仅为主路径输出,实现恒等初始化。AdaLN-Zero(x, c) = α_ada · (γ_ada · (x - μ) / (σ + ϵ) + β_ada)
初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。adaLNZero 对模型的 Gflops 增加可以忽略不计,与标准的 adaLN 模块一样。
AdaLN-Zero(x, c) 描述的是残差路径的动态调整过程,输出为 Residual Path Output。Output = Main Path Output + AdaLN-Zero(x, c)设置如下所示:
对于其它网络层参数,使用正态分布初始化和 xavier 初始化。

图 5.比较不同的条件反射策略。adaLN-Zero 在训练的各个阶段都优于交叉注意和情境条件反射。
DiT 论文中对四种方案进行了对比试验,发现采用 AdaLN-Zero 效果是最好的,所以 DiT 默认都采用这种方式来嵌入条件 Embeddings。与此同时,AdaLN-Zero 也成为了基于 DiT 架构的 AI 绘画大模型的必备策略。
参考:此文 U-ViT 部分:视频生成 Sora 的全面解析:从 AI 绘画、ViT 到 ViViT、TECO、DiT、VDT、NaViT 等
| 特性 | DiT | U-ViT |
|---|---|---|
| 模型设计灵感 | 基于 ViT 的纯 Transformer 架构 | 结合 U-Net 和 ViT 的混合架构 |
| 网络结构 | 标准 Transformer 堆叠 | Encoder-Transformer-Decoder 框架 |
| 局部特征建模 | 依赖 Patch Embedding 和 MLP,局部建模较弱 | 使用 U-Net 的卷积模块,局部特征建模强 |
| 全局特征建模 | 完全由 Transformer 捕捉全局上下文信息 | 通过嵌入 ViT 增强全局建模能力 |
| 跳跃连接(Skip) | 无跳跃连接 | 具有跳跃连接,保留细粒度信息 |
| 输入表示 | Patch Embedding 序列化输入 | 原始图像直接输入 |
| 适用任务 | 高分辨率潜在空间生成任务 | 低分辨率生成任务 |
| 计算复杂度 | 随序列长度增加计算复杂度显著提升 | U-Net 局部操作高效,整体复杂度较低 |

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online