DiT 是基于 Transformer 架构的扩散模型。用于各种图像(SD3、FLUX 等)和视频(Sora 等)视觉生成任务。
DiT 证明了 Transformer 思想与扩散模型结合的有效性,并且还验证了 Transformer 架构在扩散模型上具备较强的 Scaling 能力,在稳步增大 DiT 模型参数量与增强数据质量时,DiT 的生成性能稳步提升。
其中最大的 DiT-XL/2 模型在 ImageNet 256x256 的类别条件生成上达到了当时的 SOTA【最先进的(State Of The Art)】(FID 为 2.27)性能。同时在 SD3 和 FLUX.1 中也说明了较强的 Scaling 能力。
架构
DiT 架构如下所示:
图 3.扩散 Transformer(DiT)架构。左:我们训练条件潜在 DiT 模型。输入的潜在被分解成补丁和处理的几个 DiT 块。右图:DiT 区块的详细信息。我们用标准 Transformer 块的变体进行了实验,这些块通过自适应层归一化、交叉注意和额外输入的令牌(上下文环境)来进行调节。自适应层规范效果最好。
左侧主要架构图:训练条件潜在 DiT 模型 (conditional latent DiT models),潜在输入和条件被分解成 patch 并结合条件信息通过几个 DiT blocks 处理。本质就是噪声图片减掉预测的噪声以实现逐步复原。
DiT blocks 前:比如当输入是一张 256x256x3 的图片,得到 32x32x4 的 Noised Latent,之后进行 Patch 和位置编码,结合当前的 Timestep t、Label y 作为输入。
DiT blocks 后:经过 N 个 Dit Block(基于 transformer) 通过 MLP 进行输出,在 DiT 模型的最后一个 Transformer 块(DiT block)之后,需要将生成的图像 token 序列解码为以下两项输出:噪声'Noise 预测'以及对应的协方差矩阵,最后经过 T 个 step 采样,得到 32x32x4 的降噪后的 latent。
右侧 DiT Block 实现方式:DiT blocks 的细节,作者试验了标准 transformer 块的变体,这些变体通过自适应层归一化、交叉注意和额外输入 token 来加入条件 (incorporate conditioning via adaptive layer norm, cross-attention and extra input tokens,这个 conditioning 相当于就是带条件的去噪),其中自适应层归一化效果最好。
下文将按照这个架构进行阐述,从左到右。
与传统 (U-Net) 扩散模型区别
架构
DiT 将扩散模型中经典的 U-Net 架构完全替换成了 Transformer 架构。能够高效地捕获数据中的依赖关系并生成高质量的结果。
噪声调度策略
DiT 扩散过程的采用简单的 Linear scheduler(timesteps=1000,beta_start=0.0001,beta_end=0.02)。在传统的 U-Net 扩散模型(SD)中,所采用的 noise scheduler 通常是 Scaled Linear scheduler。
在图像领域使用 Transformer,首先想到的模型就是 ViT(参考:万字长文解读深度学习——ViT、ViLT),和 ViT 一样,DiT 也需要经过 Patch 和位置编码。
Patch 化
DiT 和 ViT 一样,首先采用一个 Patch Embedding 来将输入图像 Patch 化,主要作用是将 VAE 编码后的二维特征转化为一维序列,从而得到一系列的图像 tokens,ViT 具体如下图所示:
DiT 在这个图像 Patch 化的过程中,设计了 patch size 超参数,它直接决定了图像 tokens 的大小和数量,从而影响 DiT 模型的整体计算量。DiT 论文中共设置了三种 patch size,分别是 2, 4, 8。patch size 为 2*2 是最理想的。(结论来自:视频生成 Sora 的全面解析:从 AI 绘画、ViT 到 ViViT、TECO、DiT、VDT、NaViT 等)
Latent Diffusion Transformer 结构中,输入的图像在经过 VAE 编码器处理后,生成一个 Latent 特征,Patchify 的目的是将 Latent 特征转换成一系列 T 个 token(将 Latent 特征进行 Patch 化),每个 token 的维度为 d。Patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。
图 4. DiT 的输入规格。给定 patch size 是 p × p,空间表示(来自 VAE 的加噪潜变量),其形状为 I × I × C,会被'划分成补丁'(patchified)为一个长度为 T = (I/p)^2 的序列,隐藏维度为 d。较小的补丁大小 p 会导致序列长度更长,因此需要更多的计算量(以 Gflops 表示)。
ViT(vision transformer)采用的是 2D Frequency Embeddings(两个 1D Frequency Embeddings 进行 concat 操作),详情请参考:深度学习——3 种常见的 Transformer 位置编码【sin/cos、基于频率的二维位置编码(2D Frequency Embeddings)、RoPE】
DiT Block 模块详细信息
DiT 在完成输入图像的预处理后,就要将 Latent 特征输入到 DiT Block 中进行特征的提取了,与 ViT 不同的是,DiT 作为扩散模型还需要在 Backbone(主干)网络中嵌入额外的条件信息(不同模态的条件信息等),这里的条件信息就包括了 Timesteps 以及类别标签(文本信息)。
DiT 中的 Backbone 网络进行了两个主要工作:
常规的特征提取抽象;
对图像和特征额外的多模态条件特征进行融合。
额外信息都可以采用一个 Embedding 来进行编码,从而注入 DiT 中。DiT 论文中为了增强特征融合的性能,一共设计了四种【三种】方案来实现两个额外 Embeddings 的嵌入 (说白了,就是怎么加入 conditioning)。实现方式如下图模型架构的后半部分(红框):
提取条件信息、缩放参数 α:从输入的条件(如 Text Embeddings、标签等)中提取信息,一般来说会专门使用一个神经网络模块(比如全连接层等)来处理输入条件,并生成与输入数据相对应的缩放和偏移参数。
在 DiT 的官方实现中,使用了一个全连接层+SiLU 激活函数来实现这样一个输入条件的特征提取网络: