跳到主要内容 扩散模型详解:原理与代码实现 | 极客日志
Python AI 算法
扩散模型详解:原理与代码实现 扩散模型基于热力学非平衡过程,通过前向加噪和反向去噪生成数据。DDPM 框架利用 U-Net 预测噪声,训练目标为最小化噪声预测误差。本文介绍了扩散模型的理论基础、数学推导及 PyTorch 代码实现,涵盖数据集加载、模型构建、训练循环及采样流程,展示了在动漫人脸生成任务中的应用效果与优化方向。
1739658202 发布于 2025/2/6 更新于 2026/4/20 1 浏览1. 简介
扩散模型(Diffusion Model)的起源可以追溯到概率图模型和统计物理学领域。它最初的灵感来自于对热扩散和布朗运动等物理现象的研究,这些过程描述了系统如何从一个高能量、不均匀的状态逐步过渡到一个低能量、平衡的状态,后来被引入机器学习和生成模型领域。
扩散模型的早期形式可以追溯到多种研究工作,Sohl-Dickstein 等人(2015 年)在论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》中首次将扩散过程引入深度生成模型的框架,奠定了扩散模型的理论基础。
近年来,扩散模型得到了快速发展。尤其是 Ho 等人(2020 年)在论文《Denoising Diffusion Probabilistic Models (DDPM)》中提出了一种高效的扩散模型框架,这一工作开辟了扩散模型的新方向,使其在图像生成、语音合成和其他生成任务中表现优异。DDPM 将扩散过程分为两个阶段,一是前向过程(Forward Process),向数据中逐步添加噪声,直到数据接近高斯分布。二是反向过程(Reverse Process),学习逐步去噪,最终从随机噪声中生成目标数据。
[图:随机噪声经过去噪模块生成猫图像的示意图]
就像是米开朗基罗说的:'塑像就在石头里,我只是把不需要的部分去掉'。
2. 原理
2.1 前向过程
前向过程也就是我们常说的扩散过程,它模拟真实数据逐渐被噪声污染的过程。其做法通常是从高斯分布里面采样一组噪声添加到正常的图片当中,产生有点噪声的图像,然后从高斯分布中再采样一次,再得到更加噪声的图片,以此类推,最后整张图片就看不出来原来是什么东西,也就是整张图片变成了一个接近高斯分布的噪声。做完这个扩散过程以后,就有去噪模块的训练数据了。
前向过程是一个固定的、不可学习的马尔可夫链。从初始数据分布 $x_0$ 开始,逐步向数据中添加高斯噪声,使数据分布逐渐接近标准高斯分布 $\mathcal{N}(0, I)$。公式表示为:
$$ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) $$
其中 $x_t$ 是当前时间步的数据,$x_{t-1}$ 是前一时间步的数据,$\beta_t$ 是一个预定义的时间步长参数,表示每一步添加噪声的强度。$I$ 是单位矩阵,表示每个数据维度上的噪声是独立且均匀的。$\mathcal{N}$ 是高斯分布,表示 $x_t$ 的条件分布。
添加噪声的过程是一个马尔可夫链,在每一步中,数据根据高斯分布从 $x_{t-1}$ 生成 $x_t$,噪声的均值为 $0$,方差为 $\beta_t$。前向过程将原始数据逐渐加噪,最终在 $T$ 步后,使其分布接近标准高斯分布 $\mathcal{N}(0, I)$。
2.2 反向过程
反向过程的目的学习如何从完全随机的噪声逐步还原出目标数据。反向过程是需要通过神经网络来学习的,该网络的输入是一张有噪声的图,输出是一张滤掉一点噪声的图像,去噪越做越多,最终就能看到一张清晰的图片。
[图:反向去噪过程示意图]
通常,这个去噪的模型里面实际上是一个噪声预测器(noise predictor),它会预测图片里面的噪声。这个噪声预测器的输入是去噪的图片和噪声现在的严重程度(也就是我们现在进行到去噪的第几个步骤的代号)。它预测在这张图片里面噪声应该长什么样子,再在去噪的图片中减去它预测的噪声,就产生去噪以后的结果,即输出一张噪声少了一点的图。
[图:噪声预测器结构示意]
要训练这样的噪声预测器,要用到之前我们在扩散过程中产生的训练数据。即,扩散过程中产生的一张加完噪声的图片跟现在是第几次加噪声,是网络的输入,而加入的这个噪声就是网络应该要预测的输出。
但有些时候我们不仅想要产生图片,还想产生与我们文字描述一样的图片,对于这样的情况,我们只需要在训练数据中增加对图片的描述,同时在去噪的每一个步骤中让噪声预测器多一个额外的输入,也就是描述的这段文字。
[图:条件扩散模型示意]
反向过程假设数据的逆演化仍然是一个马尔可夫过程,模型需要学习如何从 $x_t$ 预测 $x_{t-1}$,逐步还原出无噪声的原始数据:
$$ p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) $$
其中 $x_t$ 是当前时间步的带噪声数据,$x_{t-1}$ 是目标时间步的数据(去噪后)。$\mu_\theta$ 是由神经网络预测的均值。$\Sigma_\theta$ 是由神经网络预测的方差(通常固定为常数以简化计算)。通过训练神经网络,使其能够预测每一步中的噪声成分。
扩散模型的训练目标是通过最大化似然估计来优化反向过程。训练过程等价于一个降噪任务,模型学习在给定加噪数据 $x_t$ 的情况下,如何估计当前时间步的噪声,扩散模型的训练损失函数如下:
$$ L = \mathbb{E}{t, x_0, \epsilon} [ || \epsilon - \epsilon \theta(x_t, t) ||^2 ] $$
其中 $\epsilon$ 是前向过程中实际添加的噪声,$\epsilon_\theta$ 是模型预测的噪声。$L_t$ 是当前时间步的损失,表示实际噪声与预测噪声之间的均方误差。通过最小化这个损失,模型学习如何在每一步准确预测噪声,从而能够反向还原数据。
[图:损失函数计算示意]
3. 代码
下面我们以生成动漫人脸图像为目标来训练 Diffusion Model。
动漫人脸数据集下载链接:
https://www.kaggle.com/datasets/b07202024/diffusion/download?datasetVersionNumber=1
本代码遵循典型的 DDPM(Denoising Diffusion Probabilistic Model)框架,整体分为 U-Net 模型(用于去噪)、GaussianDiffusion 类(提供前向扩散和反向采样逻辑)以及数据集和训练器等部分。U-Net 负责在不同尺度下对图像特征进行编码与解码,以预测在每个时间步中加入的噪声;GaussianDiffusion 封装了核心公式与超参数,包括 beta 调度、采样/训练流程及损失函数;Trainer 则管理训练过程,如数据加载、梯度累积、EMA(指数移动平均)等。这种架构将'加噪'和'去噪'分离并封装在模型和调度器中,使得训练和推理流程更加清晰易懂,也能方便地进行扩展或替换不同的网络结构与超参数策略。
以下是完整代码(引自《李宏毅深度学习》):
import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torchvision
from torchvision import transforms as T, utils
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from PIL import Image
tqdm.auto tqdm
ema_pytorch EMA
accelerate Accelerator
matplotlib.pyplot plt
os
torch.backends.cudnn.benchmark =
torch.manual_seed( )
torch.cuda.is_available():
torch.cuda.manual_seed( )
( ):
scale = / timesteps
beta_start = scale *
beta_end = scale *
torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
( ):
b, *_ = t.shape
out = a.gather(- , t)
out.reshape(b, *(( ,) * ( (x_shape) - )))
( ):
( ):
.folder = folder
.image_size = image_size
.paths = [p p Path( ).glob( )]
.transform = T.Compose([
T.Resize(image_size),
T.ToTensor()
])
( ):
( .paths)
( ):
path = .paths[index]
img = Image. (path)
.transform(img)
( ):
x
( ):
exists(val):
val
d() (d) d
( ):
t
( ):
:
data dl:
data
( ):
(math.sqrt(num) ** ) == num
( ):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
remainder > :
arr.append(remainder)
arr
( ):
img * -
( ):
(t + ) *
(nn.Module):
( ):
().__init__()
.fn = fn
( ):
.fn(x, *args, **kwargs) + x
( ):
nn.Sequential(
nn.Upsample(scale_factor = , mode = ),
nn.Conv2d(dim, default(dim_out, dim), , padding = )
)
( ):
nn.Sequential(
Rearrange( , p1 = , p2 = ),
nn.Conv2d(dim * , default(dim_out, dim), )
)
(nn.Conv2d):
( ):
eps = x.dtype == torch.float32
weight = .weight
mean = reduce(weight, , )
var = reduce(weight, , partial(torch.var, unbiased = ))
normalized_weight = (weight - mean) * (var + eps).rsqrt()
F.conv2d(x, normalized_weight, .bias, .stride, .padding, .dilation, .groups)
(nn.Module):
( ):
().__init__()
.g = nn.Parameter(torch.ones( , dim, , ))
( ):
eps = x.dtype == torch.float32
var = torch.var(x, dim = , unbiased = , keepdim = )
mean = torch.mean(x, dim = , keepdim = )
(x - mean) * (var + eps).rsqrt() * .g
(nn.Module):
( ):
().__init__()
.fn = fn
.norm = LayerNorm(dim)
( ):
x = .norm(x)
.fn(x)
(nn.Module):
( ):
().__init__()
.dim = dim
( ):
device = x.device
half_dim = .dim //
emb = math.log( ) / (half_dim - )
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, ] * emb[ , :]
emb = torch.cat((emb.sin(), emb.cos()), dim=- )
emb
(nn.Module):
following @crowsonkb b -> b d -> d b c -> b c b (h c) x y -> b h c (x y) b h d n, b h e n -> b h d e b h d e, b h d n -> b h e n b h c (x y) -> b (h c) x y b (h c) x y -> b h c (x y) b h d i, b h d j -> b h i j b h i j, b h d j -> b h i d b h (x y) d -> b (h d) x y linea linea linea unknown beta schedule {beta_schedule} linea betas alphas_cumprod alphas_cumprod_prev sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod log_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod posterior_variance posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 loss_weight sampling loop time step none b ... -> b (...) mean height width of image must be {img_size} ./faces/faces linea
accelerate 1.0.1
einops 0.8.0
ema-pytorch 0.7.7
matplotlib 3.5.1
multiprocess 0.70.15
numpy 1.24.4
python 3.8.19
pytorch 2.4.0
pytorch-cuda 12.1
tqdm 4.66.5
下图为模型在完成训练之后生成的动漫人脸图像:
[图:训练完成的动漫人脸生成示例]
从该结果可以看出,模型成功地学习到了二次元人脸的整体特征与色彩分布,生成的人像在发型、五官、配色等方面都有一定的多样性,说明扩散模型在此任务中具备一定的泛化能力。不过图像中仍存在一定程度的模糊、面部细节缺失或扭曲等现象,表明训练规模与网络容量可能还需要进一步优化,以获得更精细、更稳定的生成质量。
总结一下,扩散模型(Diffusion Model)通过在前向过程逐步向图像添加噪声、在反向过程逐步去噪的方式实现图像生成,具有相对稳定的训练过程和良好的生成多样性。它在高分辨率图像生成、条件生成(文本、语音、语义分割等)方面表现不错,且与自回归、GAN 等其他生成方法形成互补。未来发展方向包括更高效的采样策略、更灵活的条件控制、多尺度或多模态的融合,以及在更广泛的数据类型(视频、3 D 等)上的应用和研究。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online
from
import
from
import
from
import
import
as
import
True
4096
if
4096
def
linear_beta_schedule
timesteps
"""
linear schedule, proposed in original ddpm paper
线性 Beta 时间调度函数,用于扩散模型(DDPM)中定义 Beta 参数在每个时间步的取值。
"""
1000
0.0001
0.02
return
def
extract
a, t, x_shape
"""
从向量 a 中取出与时间步 t 对应的值,并 reshape 成 x_shape 的形状。
"""
1
return
1
len
1
class
Dataset
Dataset
"""
自定义数据集,用于加载指定文件夹下的 .jpg 图像文件。
"""
def
__init__
self,
folder,
image_size
self
self
self
for
in
f'{folder} '
f'**/*.jpg'
self
def
__len__
self
return
len
self
def
__getitem__
self, index
self
open
return
self
def
exists
x
return
is
not
None
def
default
val, d
if
return
return
if
callable
else
def
identity
t, *args, **kwargs
return
def
cycle
dl
while
True
for
in
yield
def
has_int_squareroot
num
return
2
def
num_to_groups
num, divisor
if
0
return
def
normalize_to_neg_one_to_one
img
return
2
1
def
unnormalize_to_zero_to_one
t
return
1
0.5
class
Residual
def
__init__
self, fn
super
self
def
forward
self, x, *args, **kwargs
return
self
def
Upsample
dim, dim_out = None
return
2
'nearest'
3
1
def
Downsample
dim, dim_out = None
return
'b c (h p1) (w p2) -> b (c p1 p2) h w'
2
2
4
1
class
WeightStandardizedConv2d
"""
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""
def
forward
self, x
1e-5
if
else
1e-3
self
'o ... -> o 1 1 1'
'mean'
'o ... -> o 1 1 1'
False
return
self
self
self
self
self
class
LayerNorm
def
__init__
self, dim
super
self
1
1
1
def
forward
self, x
1e-5
if
else
1e-3
1
False
True
1
True
return
self
class
PreNorm
def
__init__
self, dim, fn
super
self
self
def
forward
self, x
self
return
self
class
SinusoidalPosEmb
def
__init__
self, dim
super
self
def
forward
self, x
self
2
10000
1
None
None
1
return
class
RandomOrLearnedSinusoidalPosEmb
""
's lead with random (learned optional) sinusoidal pos emb ""
"" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 ""
def __init__(self, dim, is_random = False):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
# 如果 is_random 为 False,则这些频率是可学习的;否则是随机固定
def forward(self, x):
x = rearrange(x, '
1
')
# 将输入 x reshape 成 (batch, 1)
freqs = x * rearrange(self.weights, '
1
') * 2 * math.pi
# 计算随机或可学习的频率 freqs
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
# 拼接正弦和余弦
fouriered = torch.cat((x, fouriered), dim = -1)
# 再将原始 x 与正余弦部分合并
return fouriered
# 返回包含输入和正余弦编码的结果
# building block modules
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
# 一个卷积 -> GroupNorm -> SiLU 激活的基本模块
# 卷积使用 WeightStandardizedConv2d,便于搭配 GroupNorm
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
# 如果从时间嵌入得到 scale_shift,则对特征图进行缩放和偏移
x = self.act(x)
return x
# 输出经过标准化和激活的张量
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
# 如果传入了 time_emb_dim,则对时间嵌入进行线性映射得到 scale 和 shift
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
# 如果 dim != dim_out,就用 1x1 卷积在残差分支中对通道数进行调整
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, '
1
1
')
scale_shift = time_emb.chunk(2, dim = 1)
# 将 time_emb 拆分为 (scale, shift)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
# 最终输出为正常流 (h) + 残差分支
class LinearAttention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
# 一次性生成 q, k, v
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
LayerNorm(dim)
)
# 输出层(卷积 + LayerNorm)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
# 将通道维分成 q, k, v
q, k, v = map(lambda t: rearrange(t, '
', h = self.heads), qkv)
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
# 分别对 q 的通道维 (-2) 和 k 的序列维 (-1) 做 softmax
q = q * self.scale
v = v / (h * w)
# 缩放 q,以及对 v 做归一化
context = torch.einsum('
', k, v)
# 先将 k 和 v 做乘积,得到上下文 context
out = torch.einsum('
', context, q)
# 再和 q 做乘积以得到输出
out = rearrange(out, '
', h = self.heads, x = h, y = w)
# reshape 回原始形状
return self.to_out(out)
# 卷积 + LayerNorm 得到最终结果
class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
# 自注意力机制:先获取 q, k, v,再做注意力加权求和,最后映射回 dim
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, '
', h = self.heads), qkv)
q = q * self.scale
# 缩放 q
sim = torch.einsum('
', q, k)
# 相似度矩阵 sim (b, heads, i, j)
attn = sim.softmax(dim = -1)
# 沿着最后一维做 softmax,得到注意力分布
out = torch.einsum('
', attn, v)
# 加权求和得到输出
out = rearrange(out, '
', x = h, y = w)
# reshape 回原始分辨率
return self.to_out(out)
# 最后再用 1x1 卷积映射回 dim 维度
# model
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
resnet_block_groups = 8,
learned_sinusoidal_cond = False,
random_fourier_features = False,
learned_sinusoidal_dim = 16
):
super().__init__()
# determine dimensions
self.channels = channels
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)
# 输入通道 -> init_dim, 使用 7x7 卷积做初始特征提取
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# 比如 dim=64, dim_mults=(1,2,4,8), 则 dims=[64, 64*1, 64*2, 64*4, 64*8]
# in_out 就是 [(64,64),(64,128),(128,256),(256,512)]
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
# 使用部分函数 partial,将 ResnetBlock 的 groups 参数固定
# time embeddings
time_dim = dim * 4
# 时间嵌入的维度,一般设置为 4 倍 base dim
self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
# 根据需要选择使用随机/可学习的正弦嵌入,或使用经典的正弦嵌入
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# 时间嵌入先经过正弦嵌入,然后用两个全连接层(中间激活为 GELU),维度转为 time_dim
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
# 判断是否是最后一个分辨率
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))
# down 阶段:
# 1) ResnetBlock(dim_in -> dim_in)
# 2) 再一个 ResnetBlock(dim_in -> dim_in)
# 3) Residual(PreNorm(LinearAttention))
# 4) 如果不是最后层,用 Downsample;否则用 3x3 卷积保持分辨率
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
# 中间层(U-Net 最底部):ResnetBlock -> 自注意力 -> ResnetBlock
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
# 倒序遍历 in_out,用于 up 阶段
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
# up 阶段的逻辑与 down 类似,只是要先拼接 skip connection
self.out_dim = default(out_dim, channels)
# 最终输出通道数,默认与输入通道一致
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
# 最后一步和初始输入拼接后,再过一个 ResnetBlock
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
# 通过 1x1 卷积将维度映射到 out_dim
def forward(self, x, time):
x = self.init_conv(x)
# 初始卷积提取特征
r = x.clone()
# 保存初始特征用于最后拼接
t = self.time_mlp(time)
# 将时间步 time 通过 time_mlp 得到时间嵌入 t
h = []
# 用于保存每层的输出,以便在解码器阶段做 skip connection
# downsample
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# 依次执行 block1 -> block2 -> attn -> downsample
# 并存储中间输出 h
# mid
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# U-Net 中间层的处理
# upsample
for block1, block2, attn, upsample in self.ups:
# pop 出下采样时存储的输出,进行 skip connection
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
# final
x = torch.cat((x, r), dim = 1)
# 跟最初的输入特征 r 拼接
x = self.final_res_block(x, t)
return self.final_conv(x)
# 最终输出一个跟输入维度相匹配的特征图
model = Unet(64)
# 实例化一个 U-Net 模型,基本通道数 dim = 64
class GaussianDiffusion(nn.Module):
def __init__(
self,
model, # 传入的 U-Net 等模型,用于预测噪声
*,
image_size, # 图像大小(宽和高)
timesteps = 1000, # 扩散过程的总时间步数
beta_schedule = '
r',# beta 的调度方式;此处仅支持 '
r'
auto_normalize = True # 是否自动将图像 [0,1] 归一化到 [-1,1]
):
super().__init__()
# 继承自 nn.Module
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
# 如果是 GaussianDiffusion 类本身,则要求 model 的输入通道和输出通道一致,否则会出错
assert not model.random_or_learned_sinusoidal_cond
# 在本实现里,不允许网络使用随机或可学习的正弦位置编码
self.model = model
# 保存传入的模型(通常是一个 U-Net)
self.channels = self.model.channels
# 模型的通道数量(图像的通道,默认为 3)
self.image_size = image_size
# 保存图像大小
if beta_schedule == '
r':
beta_schedule_fn = linear_beta_schedule
else:
raise ValueError(f'
')
# 根据传入的 beta_schedule 字符串选择 beta 调度函数
# 目前只支持 '
r',否则抛出异常
# calculate beta and other precalculated parameters
betas = beta_schedule_fn(timesteps)
# 计算在每个时间步上的 beta 值(线性递增)
alphas = 1. - betas
# α_t = 1 - β_t
alphas_cumprod = torch.cumprod(alphas, dim=0)
# 累乘得到 α_1 * α_2 * ... * α_t
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
# 向前偏移一个时间步,便于在计算 q(x_{t-1}|x_t, x_0) 时使用
# 第一个时间步补 1,使 α_cumprod_prev 的长度与 alphas_cumprod 一致
timesteps, = betas.shape
# 获取时间步数(1000)
self.num_timesteps = int(timesteps)
# 将其保存为整型
# sampling related parameters
self.sampling_timesteps = timesteps
# 采样时使用的步数,默认和训练步数相同
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
# 定义一个小函数,用于将各种张量注册为 buffer,并转换为 float32 类型
register_buffer('
', betas)
register_buffer('
', alphas_cumprod)
register_buffer('
', alphas_cumprod_prev)
# 将以上计算好的 beta、alpha 累乘、以及前一个时间步的 alpha 累乘注册为 buffer
# 这些值是训练和推理都会用到,但不会被训练的参数
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer('
', torch.sqrt(alphas_cumprod))
# sqrt(累乘α_t)
register_buffer('
', torch.sqrt(1. - alphas_cumprod))
# sqrt(1 - 累乘α_t)
register_buffer('
', torch.log(1. - alphas_cumprod))
# 记录 log(1 - 累乘α_t)
register_buffer('
', torch.sqrt(1. / alphas_cumprod))
# sqrt(1 / 累乘α_t)
register_buffer('
', torch.sqrt(1. / alphas_cumprod - 1))
# sqrt(1 / 累乘α_t - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# q(x_{t-1} | x_t, x_0) 的后验方差
# 根据公式: posterior_variance_t = β_t * (1 - α_{t-1}累乘) / (1 - α_t累乘)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer('
', posterior_variance)
# 注册后验方差
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer('
', torch.log(posterior_variance.clamp(min =1e-20)))
# 取对数时夹紧最小值防止数值溢出
register_buffer('
', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
# 后验均值系数 1
register_buffer('
', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# 后验均值系数 2
# derive loss weight
# snr - signal noise ratio
snr = alphas_cumprod / (1 - alphas_cumprod)
# SNR = α_t累乘 / (1 - α_t累乘)
# https://arxiv.org/abs/2303.09556
maybe_clipped_snr = snr.clone()
# 这里可以对 snr 做一些裁剪操作,如果需要的话
register_buffer('
', maybe_clipped_snr / snr)
# 用于加权损失的系数
# auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
# 根据 auto_normalize 决定是否对数据进行 [-1,1] <-> [0,1] 的转换
def predict_start_from_noise(self, x_t, t, noise):
"""
通过 x_t 和噪声,反推 x_0 的预测值
x_0 = 1 / sqrt(alpha_cumprod) * x_t - sqrt(1 / alpha_cumprod - 1) * noise
"""
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
"""
通过 x_t 和对 x_0 的预测值,反推噪声的预测值
noise = (1 / sqrt(alpha_cumprod) * x_t - x_0) / sqrt(1 / alpha_cumprod - 1)
"""
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def q_posterior(self, x_start, x_t, t):
"""
计算后验分布 q(x_{t-1} | x_t, x_0) 的均值和方差
"""
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
# 后验分布的均值
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
# 后验分布的方差
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
# 后验分布方差的对数(已做 clip)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def model_predictions(self, x, t, clip_x_start = False, rederive_pred_noise = False):
"""
给定当前噪声图 x 和时间步 t,通过模型预测噪声 pred_noise,并得到对 x_0 的估计 x_start
"""
model_output = self.model(x, t)
# 模型输出,通常是预测噪声
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
# 如果需要对预测出的 x_0 做裁剪,则 partial(torch.clamp);否则恒等函数
pred_noise = model_output
# 这里把模型输出视为噪声预测
x_start = self.predict_start_from_noise(x, t, pred_noise)
x_start = maybe_clip(x_start)
# 对 x_0 进行 [-1,1] 裁剪(可选)
if clip_x_start and rederive_pred_noise:
# 如果 x_0 被裁剪,为了更准确,需要重新计算一次噪声
pred_noise = self.predict_noise_from_start(x, t, x_start)
return pred_noise, x_start
def p_mean_variance(self, x, t, clip_denoised = True):
"""
计算从扩散过程中 p(x_{t-1} | x_t) 的均值和方差,用于反向采样
"""
noise, x_start = self.model_predictions(x, t)
# 模型预测噪声和 x_0
if clip_denoised:
x_start.clamp_(-1., 1.)
# 默认会把 x_0 的范围裁剪到 [-1,1]
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start = x_start,
x_t = x,
t = t
)
# 计算后验分布的均值和方差
# 这里的后验分布相当于 q(x_{t-1}|x_t, x_0)
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t: int):
"""
在反向扩散的某一个时间步 t,从 p(x_{t-1} | x_t) 采样
"""
b, *_, device = *x.shape, x.device
batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
# 构造与批大小相同的时间张量
model_mean, _, model_log_variance, x_start = self.p_mean_variance(
x = x,
t = batched_times,
clip_denoised = True
)
# 根据 x_t 计算后验均值和方差
noise = torch.randn_like(x) if t > 0 else 0.
# 如果 t > 0 则在采样时加噪声;如果 t=0,则不再加噪声
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
# 采样公式: x_{t-1} = 均值 + 标准差 * 噪声
return pred_img, x_start
@torch.no_grad()
def p_sample_loop(self, shape, return_all_timesteps = False):
"""
从纯噪声开始,逐步反向采样还原图像
"""
batch, device = shape[0], self.betas.device
# batch 大小,使用存储在 buffer 中的 betas 的设备
img = torch.randn(shape, device = device)
# 初始从标准正态分布采样
imgs = [img]
# 用于保存采样过程中每个时间步的结果
x_start = None
###########################################
## TODO: plot the sampling process ##
###########################################
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = '
', total = self.num_timesteps):
# 从 T-1 到 0 逐步反向采样
img, x_start = self.p_sample(img, t)
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
# 如果 return_all_timesteps=True, 返回整个采样序列;否则只返回最终生成的图像
ret = self.unnormalize(ret)
# 将图像从 [-1,1] 转回 [0,1]
return ret
@torch.no_grad()
def sample(self, batch_size = 16, return_all_timesteps = False):
"""
对外提供的采样接口
"""
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop
# 默认使用 p_sample_loop 进行逐步采样
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
def q_sample(self, x_start, t, noise=None):
"""
前向扩散:从 x_0 得到 x_t 的采样
x_t = sqrt(α_cumprod) * x_0 + sqrt(1-α_cumprod) * noise
"""
noise = default(noise, lambda: torch.randn_like(x_start))
# 如果不指定噪声,则生成一个和 x_start 形状相同的高斯噪声
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
@property
def loss_fn(self):
return F.mse_loss
# 训练时使用的损失函数,默认是 MSE
def p_losses(self, x_start, t, noise = None):
"""
在给定 x_0 以及随机的时间步 t 时,计算训练时的损失
"""
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# noise sample
x = self.q_sample(x_start = x_start, t = t, noise = noise)
# 前向扩散,将 x_0 添加噪声到 x_t
# predict and take gradient step
model_out = self.model(x, t)
# 模型对 x_t 进行估计噪声
loss = self.loss_fn(model_out, noise, reduction = '
')
# 计算 MSE 损失 (逐元素)
loss = reduce(loss, '
', '
')
# 在除 batch 之外的所有维度取平均 (即每个样本的损失)
loss = loss * extract(self.loss_weight, t, loss.shape)
# 乘以权重 (与 SNR 相关)
return loss.mean()
# 返回对整个 batch 的平均损失
def forward(self, img, *args, **kwargs):
"""
模块的前向调用接口,一般在训练时调用
"""
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
# 解包图像形状、设备以及定义的图像大小
assert h == img_size and w == img_size, f'
and
'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
# 随机采样一个时间步 t 用于训练
img = self.normalize(img)
# 如果开启了 auto_normalize,则把 [0,1] 的图片映射到 [-1,1]
return self.p_losses(img, t, *args, **kwargs)
# 调用 p_losses 计算训练损失
path = '
'
# 数据所在的文件路径,这里假设所有训练图像都在 ./faces/faces 目录中
IMG_SIZE = 64
# 设置图像尺寸为 64x64
batch_size = 16
# 设置训练时的批大小为 16 张图像
train_num_steps = 10000
# 训练的总步数,指优化器更新(iteration)次数
lr = 1e-3
# 学习率 (learning rate),这里设置为 0.001
grad_steps = 1
# 梯度累积步数;若设置大于 1 则表示每累积一定次数的反向传播再进行一次优化更新
ema_decay = 0.995
# 指数移动平均 (EMA) 的衰减率,常用于在训练过程中平滑模型权重
channels = 16
# U-Net 的基础通道数,即第一个卷积层的通道数
dim_mults = (1, 2, 4)
# 用来指定 U-Net 不同下采样 / 上采样阶段的通道扩展倍数,
# 最终网络结构中的通道数将按 (channels, 2*channels, 4*channels, ...) 的形式逐步增加
timesteps = 100
# 扩散过程中加噪声的时间步数 T;比如在 DDPM 中可以是 1000,这里设置为 100
beta_schedule = '
r'
# beta 的调度方式(表示在扩散过程中 beta 的变化),此处设置为线性
model = Unet(
dim = channels,
dim_mults = dim_mults
)
# 实例化一个 U-Net 模型对象,输入的基本通道数为 16,
# 会根据 dim_mults 逐步在网络层中增加通道数
diffusion = GaussianDiffusion(
model,
image_size = IMG_SIZE,
timesteps = timesteps,
beta_schedule = beta_schedule
)
# 将 U-Net 模型封装到 GaussianDiffusion 类中,
# 并设置扩散过程中的一些参数(如图像大小、时间步数等)。
# 该类会负责前向扩散(加噪)和反向扩散(去噪)的具体实现。
class Trainer:
def __init__(
self,
diffusion,
path,
train_batch_size=16,
train_lr=1e-3,
train_num_steps=10000,
gradient_accumulate_every=1,
ema_decay=0.995,
save_and_sample_every=1000,
):
self.accelerator = Accelerator()
self.diffusion = diffusion
self.batch_size = train_batch_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
self.ema = EMA(diffusion, decay=ema_decay)
self.save_and_sample_every = save_and_sample_every
self.path = path
self.optimizer = Adam(diffusion.parameters(), lr=train_lr)
self.dataset = Dataset(path, image_size=diffusion.image_size)
self.dataloader = DataLoader(self.dataset, batch_size=train_batch_size, shuffle=True)
self.dataloader = self.accelerator.prepare(self.dataloader)
self.diffusion = self.accelerator.prepare(self.diffusion)
self.optimizer = self.accelerator.prepare(self.optimizer)
def train(self):
progress_bar = tqdm(initial=0, total=self.train_num_steps)
progress_bar.set_description("Training")
for i in range(self.train_num_steps):
for img, _ in self.dataloader:
loss = self.diffusion(img)
self.accelerator.backward(loss / self.gradient_accumulate_every)
if i % self.gradient_accumulate_every == 0:
self.optimizer.step()
self.optimizer.zero_grad()
if i % self.save_and_sample_every == 0:
self.ema.ema_model.eval()
samples = self.ema.ema_model.sample(batch_size=self.batch_size)
print(f"Step {i}: Saving samples...")
progress_bar.update(1)
self.accelerator.end_training()
trainer = Trainer(
diffusion,
path,
train_batch_size=batch_size,
train_lr=lr,
train_num_steps=train_num_steps,
gradient_accumulate_every=grad_steps,
ema_decay=ema_decay,
save_and_sample_every=1000
)
# 实例化一个 Trainer 类来管理训练流程:
# - 使用 diffusion 模型进行前向与反向传播
# - 每个 batch 的大小为 16
# - 使用学习率 1e-3
# - 总训练步数为 10000
# - 每个 step 都更新梯度(grad_steps=1)
# - EMA 衰减因子为 0.995
# - 每 1000 步保存一次模型并进行一次采样
trainer.train()
# 开始训练,Trainer 内部会执行循环读取数据、前向计算、损失反传、优化器更新等流程。
运行环境: