论文
Scalable Diffusion Models with Transformers
定义
DiT 是基于 Transformer 架构的扩散模型。用于各种图像(SD3、FLUX 等)和视频(Sora 等)视觉生成任务。
DiT 证明了 Transformer 思想与扩散模型结合的有效性,并且还验证了 Transformer 架构在扩散模型上具备较强的 Scaling 能力,在稳步增大 DiT 模型参数量与增强数据质量时,DiT 的生成性能稳步提升。
其中最大的 DiT-XL/2 模型在 ImageNet 256x256 的类别条件生成上达到了当时的 SOTA【最先进的(State Of The Art)】(FID 为 2.27)性能。同时在 SD3 和 FLUX.1 中也说明了较强的 Scaling 能力。
架构
DiT 架构如下所示:
图 3.扩散 Transformer(DiT)架构。左:我们训练条件潜在 DiT 模型。输入的潜在被分解成补丁和处理的几个 DiT 块。右图:DiT 区块的详细信息。我们用标准 Transformer 块的变体进行了实验,这些块通过自适应层归一化、交叉注意和额外输入的令牌(上下文环境)来进行调节。自适应层规范效果最好。
- 左侧主要架构图:训练条件潜在 DiT 模型 (conditional latent DiT models),潜在输入和条件被分解成 patch 并结合条件信息通过几个 DiT blocks 处理。本质就是噪声图片减掉预测的噪声以实现逐步复原。
- DiT blocks 前:比如当输入是一张 256x256x3 的图片,得到 32x32x4 的 Noised Latent,之后进行 Patch 和位置编码,结合当前的 Timestep t、Label y 作为输入。
- DiT blocks 后:经过 N 个 Dit Block(基于 transformer) 通过 MLP 进行输出,在 DiT 模型的最后一个 Transformer 块(DiT block)之后,需要将生成的图像 token 序列解码为以下两项输出:噪声'Noise 预测'以及对应的协方差矩阵,最后经过 T 个 step 采样,得到 32x32x4 的降噪后的 latent。
- 右侧 DiT Block 实现方式:DiT blocks 的细节,作者试验了标准 transformer 块的变体,这些变体通过自适应层归一化、交叉注意和额外输入 token 来加入条件 (incorporate conditioning via adaptive layer norm, cross-attention and extra input tokens,这个 conditioning 相当于就是带条件的去噪),其中自适应层归一化效果最好。
下文将按照这个架构进行阐述,从左到右。
与传统 (U-Net) 扩散模型区别
架构
- DiT 将扩散模型中经典的 U-Net 架构完全替换成了 Transformer 架构。能够高效地捕获数据中的依赖关系并生成高质量的结果。
噪声调度策略
- DiT 扩散过程的采用简单的 Linear scheduler(timesteps=1000,beta_start=0.0001,beta_end=0.02)。在传统的 U-Net 扩散模型(SD)中,所采用的 noise scheduler 通常是 Scaled Linear scheduler。
TODO:【也有说在传统的 U-Net 扩散模型(SD)中,所采用的 noise scheduler 是带调优参数后的线性调度器(Linear Scheduler)。】
与传统扩散的相同
- DiT 的整体框架并没有采用常规的 Pixel Diffusion(像素扩散)架构,而是使用和 Stable Diffusion 相同的 Latent Diffusion(潜变量扩散)架构,使用了和 SD 一样的 VAE 模型将像素级图像压缩到低维 Latent 特征。这极大地降低了扩散模型的计算复杂度(减少 Transformer 的 token 的数量)。
输入图像的 Patch 化(Patchify)和位置编码
在图像领域使用 Transformer,首先想到的模型就是 ViT(参考:万字长文解读深度学习——ViT、ViLT),和 ViT 一样,DiT 也需要经过 Patch 和位置编码。
Patch 化
DiT 和 ViT 一样,首先采用一个 Patch Embedding 来将输入图像 Patch 化,主要作用是将 VAE 编码后的二维特征转化为一维序列,从而得到一系列的图像 tokens,ViT 具体如下图所示:
DiT 在这个图像 Patch 化的过程中,设计了 patch size 超参数,它直接决定了图像 tokens 的大小和数量,从而影响 DiT 模型的整体计算量。DiT 论文中共设置了三种 patch size,分别是 2, 4, 8。patch size 为 2*2 是最理想的。(结论来自:视频生成 Sora 的全面解析:从 AI 绘画、ViT 到 ViViT、TECO、DiT、VDT、NaViT 等)
Latent Diffusion Transformer 结构中,输入的图像在经过 VAE 编码器处理后,生成一个 Latent 特征,Patchify 的目的是将 Latent 特征转换成一系列 T 个 token(将 Latent 特征进行 Patch 化),每个 token 的维度为 d。Patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。
图 4. DiT 的输入规格。给定 patch size 是 p × p,空间表示(来自 VAE 的加噪潜变量),其形状为 I × I × C,会被'划分成补丁'(patchified)为一个长度为 T = (I/p)^2 的序列,隐藏维度为 d。较小的补丁大小 p 会导致序列长度更长,因此需要更多的计算量(以 Gflops 表示)。
位置编码
在执行 patchify 之后,我们对所有输入 token 应用标准的 ViT 基于频率的位置嵌入(正弦 - 余弦版本)。图像 tokens 后,还要加上 Positional Embeddings 进行位置标记,DiT 中采用经典的非学习 sin&cosine 位置编码技术。
ViT(vision transformer)采用的是 2D Frequency Embeddings(两个 1D Frequency Embeddings 进行 concat 操作),详情请参考:深度学习——3 种常见的 Transformer 位置编码【sin/cos、基于频率的二维位置编码(2D Frequency Embeddings)、RoPE】
DiT Block 模块详细信息
DiT 在完成输入图像的预处理后,就要将 Latent 特征输入到 DiT Block 中进行特征的提取了,与 ViT 不同的是,DiT 作为扩散模型还需要在 Backbone(主干)网络中嵌入额外的条件信息(不同模态的条件信息等),这里的条件信息就包括了 Timesteps 以及类别标签(文本信息)。
DiT 中的 Backbone 网络进行了两个主要工作:
- 常规的特征提取抽象;
- 对图像和特征额外的多模态条件特征进行融合。
额外信息都可以采用一个 Embedding 来进行编码,从而注入 DiT 中。DiT 论文中为了增强特征融合的性能,一共设计了四种【三种】方案来实现两个额外 Embeddings 的嵌入 (说白了,就是怎么加入 conditioning)。实现方式如下图模型架构的后半部分(红框):
【有的文章表示设计了四种,其实是将 AdaLN 和 AdaLN-Zero 分为两种,这里按照论文中图进行解释,分为三种】
Diffusion Transformer 模型架构图中由右到左的顺序分别是:
- 上下文条件(In-context conditioning)
- 交叉注意力块(Cross-Attention)
- 自适应层归一化块(Adaptive Layer Normalization, AdaLN)
下面将按顺序详细介绍。
上下文条件化
实现机制:
- 在上下文条件化中,条件信息 (t)(时间步嵌入)和 (c)(其他条件,如类别或文本嵌入)被表示为两个独立的嵌入向量。
- 这些向量被附加到输入图像 token 序列的开头,形成一个扩展后的输入序列。
- Transformer 模块对 (t) 和 (c) 的嵌入与图像 tokens 一视同仁,这些条件化 tokens 通过多头自注意力机制与图像 token 一起参与信息交换。
这与 ViT 中的 cls tokens 类似,它允许我们无需修改就使用标准的 ViT 模块。在最后一个模块之后,我们从序列中移除条件化 tokens。这种方法对模型的新 Gflops 增加可以忽略不计。
交叉注意力模块
实现机制:
- 条件信息 (t)(时间步嵌入)和 (c)(其他条件,如类别或文本嵌入)被拼接(concat)为一个长度为 2 的序列,与图像 token 序列分开。
- Transformer 模块被修改为在多头自注意力模块后添加一个多头交叉注意力层,专门用于让图像 token 与条件 token 进行交互,从而将条件信息显式注入到图像特征中。
- 图像特征作为 Cross Attention 机制的查询(Query)。
- 条件信息的 Embeddings 作为 Cross Attention 机制的键(Key)和值(Value)。
这种方式是 Stable Diffusion 等文生图大模型常用的方式,交叉注意力对模型的 Gflops 增加最多,大约增加了 15% 的开销。
adaLN-Zero 模块
首先需要了解什么是 Adaptive Layer Normalization(AdaLN),而 AdaLN 之前又要先知道 LN,下面将一步步优化讲解:
Layer Normalization(LN)
首先在理解 AdaLN 之前,我们先简单回顾一下 Layer Normalization。
其他归一化方法参考:深度学习——优化算法、激活函数、归一化、正则化
- 层归一化(Layer Normalization, LN) 是 Transformer 中的一个关键组件,其作用是对输入的每个特征维度归一化,从而稳定训练和加速收敛。
Layer Normalization 的处理步骤主要分成下面的三步:
- 计算输入权重的均值和标准差:计算模型每一层输入权重的均值和标准差。
- 对输入权重进行标准化:使用计算得到的均值和标准差将输入权重标准化,使其均值为 0,标准差为 1。
- 对输入权重进行缩放和偏移:使用可学习的缩放参数和偏移参数,对标准化后的输入权重进行线性变换,使模型能够拟合任意的分布。
LN 的公式:
LN ( x ) = γ ⋅ (x − μ) / σ + β
其中:
- x:输入特征。
- μ, σ:输入的均值和标准差(按特征维度计算)。
- γ, β:可学习的缩放和偏移参数,用于调整归一化后的分布。
在条件生成任务(如扩散模型)中,需要让条件信息(如时间步 (t) 或类别 (c))对生成过程产生影响。为此,传统的 LN 被改进为自适应层归一化(adaLN),可以动态调整归一化的参数以包含条件信息。
Adaptive Layer Normalization(AdaLN)
在 GANs 和具有 UNet 骨干的扩散模型中广泛使用自适应归一化层之后,探索用自适应层归一化(adaLN)替换 transformer 模块中的标准归一化层。adaLN 并不是直接学习维度规模的缩放和偏移参数 γ 和 β,而是从 t 和 c 的嵌入向量之和中回归得到它们。
AdaLN 的核心思想是根据输入的不同条件信息,自适应地调整 Layer Normalization 的缩放参数 γ 和偏移参数 β,增加的 Gflops 非常少,适合大规模任务。
adaLN 相比 LN 的改进:
- 不再直接使用固定的可学习参数 γ, β。
- 相反,AdaLN 仅通过动态生成缩放参数 γ_ada 和偏移参数 β_ada,以条件信息 c 为输入,用于调整 Layer Normalization 的行为:
AdaLN ( x , c ) = γ_ada ⋅ (x − μ) / (σ + ϵ) + β_ada
- γ_ada = f_γ(c):由条件信息 c 通过神经网络(如 MLP)生成的缩放参数。
- β_ada = f_β(c):由条件信息 c 通过神经网络生成的偏移参数。
说明:其中 f(c) 是一个小型神经网络(通常是一个多层感知机,MLP),以条件信息(如时间步 t 和类别 c 的嵌入向量)为输入,输出对应的 γ, β。
AdaLN 的核心步骤
AdaLN 的核心步骤包括以下三步【详细的步骤会在下一节 adaLN-Zero 的核心步骤总结】:
- 提取条件信息
从输入的条件(如 Text Embeddings、标签等)中提取信息,一般来说会专门使用一个神经网络模块(比如全连接层等)来处理输入条件,并生成与输入数据相对应的缩放参数 γ 和偏移参数 β。
- 生成自适应的缩放和偏移参数
利用提取的条件信息,生成自适应的缩放和偏移参数。假设输入条件为 c,经过一个神经网络模块(比如全连接层等)生成缩放参数和偏移参数如下:
γ_ada = f_γ(c), β_ada = f_β(c)
- 使用自适应参数
使用这些自适应参数对输入权重进行 Layer Normalization 处理:
AdaLN ( x , c ) = γ_ada ⋅ (x − μ) / (σ + ϵ) + β_ada
在我们探索的三种模块设计中,adaLN 增加的 Gflops 最少,因此是最计算高效的。它也是唯一一个限制对所有 tokens 应用相同函数的条件化机制。
adaLN-Zero
AdaLN-Zero 在 AdaLN 的基础上新增了残差缩放参数 α,用于动态控制残差路径的影响。通过将 α 初始化为零,模型的初始状态被设置为恒等函数,从而确保输出在训练初期的稳定性。这种设计显著提升了模型的训练稳定性和收敛速度。
之前的 ResNets 工作发现,将每个残差块初始化为恒等函数是有益的。例如,在监督学习环境中,将每个块中最后的批量归一化缩放因子 γ 零初始化可以加速大规模训练。
扩散 U-Net 模型使用了类似的初始化策略,在任何残差连接之前零初始化每个块中的最终卷积层。
我们探索了对 adaLN DiT 模块的修改,它做了同样的事情。除了回归 γ 和 β,我们还回归了在 DiT 模块内的任何残差连接之前作用的 dimension-wise 的缩放参数 α。初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。与标准的 adaLN 模块一样,adaLNZero 对模型的 Gflops 增加可以忽略不计。
adaLN-Zero 的核心步骤
下面将根据下图来阐述:
- AdaLN 有 4 个参数:γ_1, β_1, γ_2, β_2,分别用于自注意力和 MLP 模块的归一化操作,没有残差缩放参数。
- AdaLN-Zero 增加了 2 个参数:α_1, α_2,用于控制残差路径的输出,显著提升了训练稳定性和适应性,因此总共 6 个参数【如上图】。
- 如果任务需要更强的稳定性(如深层 Transformer 模型或大规模扩散模型训练),AdaLN-Zero 是更优的选择。
adaLN-Zero 的核心步骤包括以下三步,和 adaLN 的步骤相似,只不过需要在过程中加入维度缩放参数 α。
需要做的额外工作如下:在第一步中提取回归缩放参数 α 在第二步中生成自适应的缩放参数 α 在第三步中使用 α 残差路径进行控制
提取条件信息、缩放参数 α:从输入的条件(如 Text Embeddings、标签等)中提取信息,一般来说会专门使用一个神经网络模块(比如全连接层等)来处理输入条件,并生成与输入数据相对应的缩放和偏移参数。
在 DiT 的官方实现中,使用了一个全连接层+SiLU 激活函数来实现这样一个输入条件的特征提取网络:
nn.SiLU(),
nn.Linear(hidden_size, 6* hidden_size, bias=True))
同时,DiT 在每个残差模块之后还使用了一个回归缩放参数 α 来对权重进行缩放调整,这个 α 参数也是由上述条件特征提取网络提取的。上面的代码示例中和上图(DiT Block with adaLN-Zero)我们可以看到,adaLN_modulation 计算了 6 个变量 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,这 6 个变量分别对应了多头自注意力机制 (MSA) 的 AdaLN 的归一化参数与缩放参数(下图中的 β_1, γ_1, α_1)以及 MLP 模块的 AdaLN 的归一化参数与缩放参数(下图中的 β_2, γ_2, α_2)。
【在 DiT Block 中,MSA(多头自注意力模块)和 MLP(多层感知机模块)都需要分别进行一次 adaLN-Zero 归一化处理。每个模块的 Layer Normalization(LN)都会被替换为 adaLN-Zero,并且两者的归一化参数 (γ, β) 和残差路径缩放参数 (α) 是独立的(需要分别提取),具体如下:】
Transformer 模块由 MSA 和 MLP 两部分组成,而它们在功能上的分工导致必须对每部分单独设计对应的条件化机制。这是因为:MSA 的核心在于捕捉全局依赖关系,因此其动态参数 (β_1, γ_1, α_1) 主要控制全局特征的动态调整。MLP 的核心在于非线性特征提取,增强局部特征表达,因此其动态参数 (β_2, γ_2, α_2) 主要用于控制局部特征的动态变换。
- 生成自适应的缩放和偏移参数、缩放参数 α:
利用提取的条件信息,生成自适应的缩放和偏移参数。假设输入条件为 c,经过一个神经网络模块(比如全连接层等)生成缩放参数和偏移参数如下:
γ_ada = f_γ(c), β_ada = f_β(c)
为残差路径增加一个新的维度缩放参数 α,由条件信息动态生成:
α_ada = f_α(c)
初始化为零:在训练开始时,α_ada = 0,使得模块输出仅为主路径输出,实现恒等初始化。
- 使用自适应参数、缩放参数 α:
- 在 AdaLN 的基础上,加入 α_ada 对残差路径进行缩放控制:
AdaLN-Zero ( x , c ) = α_ada ⋅ (γ_ada ⋅ (x − μ) / (σ + ϵ) + β_ada)
- 残差路径的输出被动态调节:通过 α_ada 的逐步增加,残差路径的影响逐渐加强。
- 当 α_ada = 0 时,整个模块行为等效于恒等函数。
初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。adaLNZero 对模型的 Gflops 增加可以忽略不计,与标准的 adaLN 模块一样。
说明
- 公式 AdaLN-Zero ( x , c ) 描述的是残差路径的动态调整过程,输出为 Residual Path Output。
- 完整的模块输出是路径输出与残差路径输出的加权和(AdaLN 完整的模块输出换成对应公式即可):
Output = Main Path Output + AdaLN-Zero ( x , c )
- 这种联系确保了主路径与残差路径的协同作用,结合条件化调整和归一化机制,使模型更加稳定高效地处理生成任务。
DiT 中具体的初始化
设置如下所示:
- 对 DiT Block 中的 AdaLN 和 Linear 层均采用参数 0 初始化。
对于其它网络层参数,使用正态分布初始化和 xavier 初始化。
图 5.比较不同的条件反射策略。adaLN-Zero 在训练的各个阶段都优于交叉注意和情境条件反射。
DiT 论文中对四种方案进行了对比试验,发现采用 AdaLN-Zero 效果是最好的,所以 DiT 默认都采用这种方式来嵌入条件 Embeddings。与此同时,AdaLN-Zero 也成为了基于 DiT 架构的 AI 绘画大模型的必备策略。
U-ViT(U-Net Vision Transformer)
参考:此文 U-ViT 部分:视频生成 Sora 的全面解析:从 AI 绘画、ViT 到 ViViT、TECO、DiT、VDT、NaViT 等
DiT 和 U-ViT 的对比
| 特性 | DiT | U-ViT |
|---|
| 模型设计灵感 | 基于 ViT 的纯 Transformer 架构 | 结合 U-Net 和 ViT 的混合架构 |
| 网络结构 | 标准 Transformer 堆叠 | Encoder-Transformer-Decoder 框架 |
| 局部特征建模 | 依赖 Patch Embedding 和 MLP,局部建模较弱 | 使用 U-Net 的卷积模块,局部特征建模强 |
| 全局特征建模 | 完全由 Transformer 捕捉全局上下文信息 | 通过嵌入 ViT 增强全局建模能力 |
| 跳跃连接(Skip) | 无跳跃连接 | 具有跳跃连接,保留细粒度信息 |
| 输入表示 | Patch Embedding 序列化输入 | 原始图像直接输入 |
| 适用任务 | 高分辨率潜在空间生成任务 | 低分辨率生成任务 |
| 计算复杂度 | 随序列长度增加计算复杂度显著提升 | U-Net 局部操作高效,整体复杂度较低 |