import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torchvision
from torchvision import transforms as T, utils
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA
from accelerate import Accelerator
import matplotlib.pyplot as plt
import os
# Global runtime setup, executed once at import time.
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Fix the random seed; the CUDA generator is seeded as well when a GPU is visible.
torch.manual_seed(4096)
if torch.cuda.is_available():
    torch.cuda.manual_seed(4096)
def linear_beta_schedule(timesteps):
    """
    Linear beta schedule, as proposed in the original DDPM paper.

    The reference endpoints (1e-4 and 0.02) are multiplied by
    ``1000 / timesteps`` before interpolation.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).

    Returns:
        A 1-D ``torch.float64`` tensor of ``timesteps`` evenly spaced betas.
    """
    factor = 1000 / timesteps
    first_beta, last_beta = factor * 0.0001, factor * 0.02
    return torch.linspace(first_beta, last_beta, timesteps, dtype=torch.float64)
def extract(a, t, x_shape):
    """
    Look up the entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: tensor indexed along its last dimension.
        t: integer index tensor; its first dimension is the batch size.
        x_shape: shape of the tensor the result will be broadcast against.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)
class Dataset(Dataset):
    """
    Image dataset that loads every ``.jpg`` file found recursively under a folder.

    Each item is a single image tensor (no label), produced by resizing,
    centre-cropping to ``image_size`` x ``image_size`` and converting to a tensor.
    """

    def __init__(
        self,
        folder,
        image_size
    ):
        """
        Args:
            folder: directory searched recursively for ``*.jpg`` files.
            image_size: side length of the square images returned.
        """
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Sorted so that the index -> file mapping does not depend on filesystem order.
        self.paths = sorted(Path(f'{folder}').glob('**/*.jpg'))
        self.transform = T.Compose([
            # Resize with an int only fixes the shorter side; the centre crop
            # guarantees square outputs so batches stack and the diffusion
            # model's height/width check passes.
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a transformed tensor."""
        path = self.paths[index]
        # The context manager closes the file; convert('RGB') forces three
        # channels for grayscale / CMYK JPEGs.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))
def exists(x):
    """Return True for any value other than ``None``."""
    if x is None:
        return False
    return True
def default(val, d):
    """
    Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When the fallback ``d`` is callable it is invoked and its result returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged; extra positional and keyword arguments are ignored."""
    return t
def cycle(dl):
    """Yield the items of ``dl`` forever, restarting iteration whenever it is exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """
    Return True when ``num`` is a perfect square.

    The check uses exact integer arithmetic (``math.isqrt``) rather than
    squaring a floating-point square root, which misreports large perfect
    squares because of rounding.

    Args:
        num: a non-negative integer, or a float holding an integral value.

    Returns:
        bool: True if some integer ``k`` satisfies ``k * k == num``.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if num < 0:
        raise ValueError('math domain error')
    if isinstance(num, float):
        # Non-integral floats (including inf / nan) cannot be perfect squares.
        if not num.is_integer():
            return False
        num = int(num)
    return math.isqrt(num) ** 2 == num
def num_to_groups(num, divisor):
    """
    Split ``num`` into chunks of size ``divisor`` plus one trailing remainder chunk.

    Example: ``num_to_groups(10, 4)`` gives ``[4, 4, 2]``.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes
def normalize_to_neg_one_to_one(img):
    """Map values from the [0, 1] range onto [-1, 1] (``img * 2 - 1``)."""
    doubled = img * 2
    return doubled - 1
def unnormalize_to_zero_to_one(t):
    """Map values from the [-1, 1] range back onto [0, 1] (``(t + 1) * 0.5``)."""
    shifted = t + 1
    return shifted * 0.5
class Residual(nn.Module):
    """Wrap a callable ``fn`` so that calling the module returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Apply ``fn`` (forwarding any extra arguments) and add the input back."""
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x
def Upsample(dim, dim_out=None):
    """
    Build a 2x nearest-neighbour upsampling stage followed by a 3x3 convolution.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; falls back to ``dim`` when None.

    Returns:
        An ``nn.Sequential`` of ``nn.Upsample`` and ``nn.Conv2d``.
    """
    out_channels = dim if dim_out is None else dim_out
    enlarge = nn.Upsample(scale_factor=2, mode='nearest')
    project = nn.Conv2d(dim, out_channels, 3, padding=1)
    return nn.Sequential(enlarge, project)
def Downsample(dim, dim_out=None):
    """
    Build a 2x downsampling stage.

    Each 2x2 spatial patch is folded into the channel axis (4x the channels,
    half the height and width), then a 1x1 convolution maps ``dim * 4``
    channels to the output width.

    Args:
        dim: number of input channels.
        dim_out: number of output channels; falls back to ``dim`` when None.

    Returns:
        An ``nn.Sequential`` of the rearrangement and the 1x1 ``nn.Conv2d``.
    """
    out_channels = dim if dim_out is None else dim_out
    fold_patches = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)
    project = nn.Conv2d(dim * 4, out_channels, 1)
    return nn.Sequential(fold_patches, project)
class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """

    def forward(self, x):
        """Convolve ``x`` with the kernel standardized per output channel."""
        # A larger epsilon is used for anything other than float32 inputs.
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5

        kernel = self.weight
        # Statistics are taken over every axis except the output-channel axis.
        per_filter_axes = (1, 2, 3)
        mu = kernel.mean(dim=per_filter_axes, keepdim=True)
        sigma_sq = kernel.var(dim=per_filter_axes, unbiased=False, keepdim=True)
        standardized = (kernel - mu) * (sigma_sq + eps).rsqrt()

        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class LayerNorm(nn.Module):
    """Normalize over dimension 1 (no bias) and scale by a learned per-channel gain ``g``."""

    def __init__(self, dim):
        super().__init__()
        # Gain of shape (1, dim, 1, 1), initialised to ones.
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + eps) * g`` with statistics over dim 1."""
        # A larger epsilon is used for anything other than float32 inputs.
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5
        variance = x.var(dim=1, unbiased=False, keepdim=True)
        centered = x - x.mean(dim=1, keepdim=True)
        return centered * (variance + eps).rsqrt() * self.g
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before handing it to ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        return self.fn(self.norm(x))
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D batch of positions / timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        Embed ``x`` of shape ``(batch,)`` into ``(batch, 2 * (dim // 2))``.

        The first half of the output holds sines, the second half cosines, over
        geometrically spaced frequencies running from 1 down to 1/10000.
        """
        n_freqs = self.dim // 2
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """
    Following @crowsonkb's lead with random (learned optional) sinusoidal pos emb.

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    The frequencies are drawn from a standard normal at construction time.
    They are trainable when ``is_random`` is False and frozen otherwise.
    """

    def __init__(self, dim, is_random = False):
        """
        Args:
            dim: number of Fourier features; must be even (half sine, half cosine).
            is_random: when True the frequencies are fixed (``requires_grad=False``).
        """
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)

    def forward(self, x):
        """
        Embed ``x`` of shape ``(batch,)`` into ``(batch, dim + 1)``.

        The output is the raw input concatenated with its sine and cosine features.
        """
        # (batch,) -> (batch, 1)
        x = rearrange(x, 'b -> b 1')
        # Phase for every (sample, frequency) pair.
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # Prepend the raw input so the network also sees the unencoded value.
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
# building block modules
class Block(nn.Module):
    """
    Basic unit: weight-standardized 3x3 convolution -> GroupNorm -> SiLU.

    An optional ``(scale, shift)`` pair modulates the normalized features
    before the activation.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Run conv + norm, apply ``x * (scale + 1) + shift`` if given, then activate."""
        normed = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)
class ResnetBlock(nn.Module):
    """
    Two ``Block`` layers with a residual connection.

    When ``time_emb_dim`` is given, a SiLU + Linear head maps the time
    embedding to a (scale, shift) pair that modulates the first block.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()

        if time_emb_dim is not None:
            # Produces 2 * dim_out values: one half for scale, one half for shift.
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # A 1x1 convolution matches channel counts on the skip path when they differ.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x, scale_shift)) + res_conv(x)``."""
        scale_shift = None

        if self.mlp is not None and time_emb is not None:
            cond = self.mlp(time_emb)
            cond = rearrange(cond, 'b c -> b c 1 1')
            # Split the channel axis into (scale, shift).
            scale_shift = cond.chunk(2, dim=1)

        hidden = self.block1(x, scale_shift=scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
class LinearAttention(nn.Module):
    """
    Multi-head linear attention over the spatial positions of a feature map.

    Keys and values are combined into a per-head context matrix first, and
    the queries are applied to that context afterwards.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # A single 1x1 convolution produces q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        # Output projection followed by LayerNorm.
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        """Attend over all spatial positions of ``x`` (a 4-D tensor) and project back to ``dim`` channels."""
        _, _, height, width = x.shape

        def split_heads(t):
            return rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1))

        # Softmax over the feature axis for q and over the position axis for k.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        v = v / (height * width)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h=self.heads, x=height, y=width)
        return self.to_out(attended)
class Attention(nn.Module):
    """Multi-head softmax self-attention over the spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # A single 1x1 convolution produces q, k and v; another maps the result back to ``dim``.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """Attend over all spatial positions of ``x`` (a 4-D tensor) and project back to ``dim`` channels."""
        _, _, height, width = x.shape

        def split_heads(t):
            return rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1))
        q = q * self.scale

        # Similarity between every query position i and key position j.
        scores = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        weights = scores.softmax(dim=-1)

        # Weighted sum of values for each query position.
        attended = torch.einsum('b h i j, b h d j -> b h i d', weights, v)

        attended = rearrange(attended, 'b h (x y) d -> b (h d) x y', x=height, y=width)
        return self.to_out(attended)
# model
class Unet(nn.Module):
    """
    U-Net used as the denoising network.

    Structure: initial 7x7 convolution -> down stages (two ResnetBlocks,
    linear attention, downsample) -> middle (ResnetBlock, attention,
    ResnetBlock) -> up stages (mirror of the down stages, fed by skip
    connections) -> final ResnetBlock and 1x1 convolution. Every ResnetBlock
    is conditioned on an embedding of the timestep.
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_sinusoidal_cond = False,
        random_fourier_features = False,
        learned_sinusoidal_dim = 16
    ):
        """
        Args:
            dim: base feature width; stage widths are ``dim * m`` for ``m`` in ``dim_mults``.
            init_dim: width after the initial convolution (defaults to ``dim``).
            out_dim: number of output channels (defaults to ``channels``).
            dim_mults: width multiplier per resolution stage.
            channels: number of input image channels.
            resnet_block_groups: GroupNorm group count inside every ResnetBlock.
            learned_sinusoidal_cond: use learned sinusoidal time features.
            random_fourier_features: use fixed random sinusoidal time features.
            learned_sinusoidal_dim: feature count for the learned / random embedding.
        """
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)

        # e.g. dim=64, dim_mults=(1, 2, 4, 8) -> dims=[64, 64, 128, 256, 512]
        # and in_out=[(64, 64), (64, 128), (128, 256), (256, 512)]
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings
        time_dim = dim * 4

        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            # +1 because that embedding also passes the raw timestep through.
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            # The last stage keeps the resolution and only changes the width.
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            # Each ResnetBlock input is widened by the skip tensor concatenated onto it.
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        self.out_dim = default(out_dim, channels)

        # Both the last up stage and the saved initial features have init_dim
        # channels, so the final block is sized from init_dim. This matches
        # the previous `dim`-based sizing whenever init_dim is left at its
        # default, and avoids a shape mismatch when it is not.
        self.final_res_block = block_klass(init_dim * 2, init_dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    def forward(self, x, time):
        """
        Run the network on ``x`` conditioned on ``time``.

        Args:
            x: 4-D image batch with ``channels`` channels. Each non-final down
                stage halves height and width, so they must be divisible by
                ``2 ** (len(dim_mults) - 1)``.
            time: 1-D batch of timesteps.

        Returns:
            Tensor with ``out_dim`` channels at the input resolution.
        """
        x = self.init_conv(x)
        # Kept for the final skip connection.
        r = x.clone()

        t = self.time_mlp(time)

        # Stack of intermediate activations consumed (in reverse) by the up path.
        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        # mid
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        # final
        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
# Instantiate a U-Net with base width dim = 64.
# NOTE(review): `model` is rebound to a new Unet further down in this file;
# this instance does not appear to be used in between.
model = Unet(dim=64)
class GaussianDiffusion(nn.Module):
    """
    DDPM-style Gaussian diffusion wrapped around a noise-predicting model.

    The constructor precomputes the beta schedule and every derived
    coefficient as float32 buffers. The class then provides the forward
    (noising) process, the training loss, and step-by-step reverse sampling.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps=1000,
        beta_schedule='linear',
        auto_normalize=True
    ):
        """
        Args:
            model: network called as ``model(x, t)``. It must expose
                ``channels``, ``out_dim`` and ``random_or_learned_sinusoidal_cond``.
            image_size: side length of the (square) images.
            timesteps: number of diffusion steps.
            beta_schedule: name of the beta schedule; only ``'linear'`` is supported.
            auto_normalize: map inputs from [0, 1] to [-1, 1] (and samples back).

        Raises:
            ValueError: if ``beta_schedule`` is not ``'linear'``.
        """
        super().__init__()

        # The exact-type comparison means only this class itself (not
        # subclasses) requires equal input and output channel counts.
        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
        # Random / learned sinusoidal time conditioning is not allowed here.
        assert not model.random_or_learned_sinusoidal_cond

        self.model = model
        self.channels = self.model.channels
        self.image_size = image_size

        if beta_schedule != 'linear':
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        # beta_t, alpha_t = 1 - beta_t, and the cumulative products of alpha.
        betas = linear_beta_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shifted right by one step and padded with 1 so both have equal length.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)

        # Sampling uses as many steps as training.
        self.sampling_timesteps = timesteps

        def as_float32_buffer(name, val):
            # The schedule is computed in float64 but stored as float32 buffers.
            self.register_buffer(name, val.to(torch.float32))

        as_float32_buffer('betas', betas)
        as_float32_buffer('alphas_cumprod', alphas_cumprod)
        as_float32_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Coefficients for the forward process q(x_t | x_0) and its inversions.
        as_float32_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        as_float32_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        as_float32_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        as_float32_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        as_float32_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Posterior q(x_{t-1} | x_t, x_0):
        # variance_t = beta_t * (1 - alphas_cumprod_{t-1}) / (1 - alphas_cumprod_t)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        as_float32_buffer('posterior_variance', posterior_variance)

        # The variance is 0 at the first step, so clamp before taking the log.
        as_float32_buffer(
            'posterior_log_variance_clipped',
            torch.log(posterior_variance.clamp(min=1e-20))
        )
        as_float32_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
        )
        as_float32_buffer(
            'posterior_mean_coef2',
            (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
        )

        # Loss weight derived from the signal-to-noise ratio
        # (https://arxiv.org/abs/2303.09556). No clipping is applied here, so
        # the weight is snr / snr.
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        as_float32_buffer('loss_weight', maybe_clipped_snr / snr)

        # [0, 1] <-> [-1, 1] conversion, or pass-through when auto_normalize is off.
        if auto_normalize:
            self.normalize = normalize_to_neg_one_to_one
            self.unnormalize = unnormalize_to_zero_to_one
        else:
            self.normalize = identity
            self.unnormalize = identity

    def predict_start_from_noise(self, x_t, t, noise):
        """
        Recover the x_0 estimate from x_t and a noise estimate:
        x_0 = sqrt(1 / alphas_cumprod) * x_t - sqrt(1 / alphas_cumprod - 1) * noise
        """
        scaled_x_t = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
        scaled_noise = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        return scaled_x_t - scaled_noise

    def predict_noise_from_start(self, x_t, t, x0):
        """
        Recover the noise estimate from x_t and an x_0 estimate:
        noise = (sqrt(1 / alphas_cumprod) * x_t - x_0) / sqrt(1 / alphas_cumprod - 1)
        """
        numerator = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        return numerator / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        """
        Mean, variance and clipped log-variance of the posterior q(x_{t-1} | x_t, x_0).
        """
        shape = x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, shape) * x_start +
            extract(self.posterior_mean_coef2, t, shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, clip_x_start=False, rederive_pred_noise=False):
        """
        Run the model on ``(x, t)`` and return ``(pred_noise, x_start)``.

        The model output is treated as the predicted noise. With
        ``clip_x_start`` the x_0 estimate is clamped to [-1, 1]; with
        ``rederive_pred_noise`` as well, the noise is recomputed from it.
        """
        pred_noise = self.model(x, t)
        x_start = self.predict_start_from_noise(x, t, pred_noise)

        if clip_x_start:
            x_start = torch.clamp(x_start, min=-1., max=1.)
            if rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        return pred_noise, x_start

    def p_mean_variance(self, x, t, clip_denoised=True):
        """
        Mean / variance / log-variance of p(x_{t-1} | x_t), plus the x_0 estimate.
        """
        _, x_start = self.model_predictions(x, t)

        if clip_denoised:
            # In-place clamp of the x_0 estimate to [-1, 1].
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_start,
            x_t=x,
            t=t
        )
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int):
        """
        Draw x_{t-1} from p(x_{t-1} | x_t) for a single integer timestep ``t``.

        Returns:
            ``(pred_img, x_start)``.
        """
        batch, *_ = x.shape
        # The same timestep is used for every sample in the batch.
        batched_times = torch.full((batch,), t, device=x.device, dtype=torch.long)

        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            clip_denoised=True
        )

        # No noise is added at the final step (t == 0).
        noise = torch.randn_like(x) if t > 0 else 0.
        std = (0.5 * model_log_variance).exp()
        pred_img = model_mean + std * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps=False):
        """
        Start from Gaussian noise of ``shape`` and apply ``p_sample`` for every
        timestep from ``num_timesteps - 1`` down to 0.

        Returns:
            The final images, or all intermediate images stacked along dim 1
            when ``return_all_timesteps`` is True, passed through ``unnormalize``.
        """
        batch, device = shape[0], self.betas.device

        img = torch.randn(shape, device=device)
        imgs = [img]

        x_start = None

        ###########################################
        ## TODO: plot the sampling process       ##
        ###########################################
        step_order = reversed(range(0, self.num_timesteps))
        for t in tqdm(step_order, desc='sampling loop time step', total=self.num_timesteps):
            img, x_start = self.p_sample(img, t)
            imgs.append(img)

        if return_all_timesteps:
            ret = torch.stack(imgs, dim=1)
        else:
            ret = img

        return self.unnormalize(ret)

    @torch.no_grad()
    def sample(self, batch_size=16, return_all_timesteps=False):
        """Public sampling entry point: generate ``batch_size`` images via ``p_sample_loop``."""
        shape = (batch_size, self.channels, self.image_size, self.image_size)
        return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)

    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: sample x_t from x_0.
        x_t = sqrt(alphas_cumprod) * x_0 + sqrt(1 - alphas_cumprod) * noise

        Fresh Gaussian noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        signal = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        perturbation = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        return signal + perturbation

    @property
    def loss_fn(self):
        """Loss used for training (mean squared error)."""
        return F.mse_loss

    def p_losses(self, x_start, t, noise=None):
        """
        Training loss for a batch ``x_start`` at timesteps ``t``.

        The batch is noised to x_t, the model predicts the noise, and the
        weighted squared error is averaged over the whole batch.
        """
        # Unpacking requires a 4-D input.
        b, c, h, w = x_start.shape

        if noise is None:
            noise = torch.randn_like(x_start)

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # predict the noise and compare element-wise
        model_out = self.model(x, t)
        loss = self.loss_fn(model_out, noise, reduction='none')

        # NOTE(review): this pattern keeps every element and appears only to
        # flatten the non-batch dimensions; the averaging happens in the
        # final mean() below.
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """
        Training entry point: pick a random timestep per image and return the loss.

        Extra arguments are forwarded to ``p_losses``.
        """
        b, c, h, w = img.shape
        device = img.device
        img_size = self.image_size

        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
# ---- Hyper-parameters -------------------------------------------------------

# Folder holding the training images (searched recursively for .jpg files).
path = './faces/faces'

# Side length of the square images used by the diffusion model.
IMG_SIZE = 64

# Number of images per training batch.
batch_size = 16

# Number of training steps handed to the Trainer.
train_num_steps = 10000

# Learning rate for the optimizer.
lr = 1e-3

# Gradient-accumulation factor handed to the Trainer.
grad_steps = 1

# Decay rate of the exponential moving average of the model weights.
ema_decay = 0.995

# Base feature width of the U-Net (passed as `dim`, not the image channel count).
channels = 16

# Width multiplier for each resolution stage of the U-Net.
dim_mults = (1, 2, 4)

# Number of diffusion timesteps T.
timesteps = 100

# Beta schedule name; GaussianDiffusion only accepts 'linear'.
beta_schedule = 'linear'

# ---- Model construction -----------------------------------------------------

# U-Net with base width `channels` and stage widths given by `dim_mults`.
model = Unet(dim=channels, dim_mults=dim_mults)

# Wrap the U-Net in the diffusion process, which implements both the forward
# (noising) and the reverse (denoising) direction.
diffusion = GaussianDiffusion(
    model,
    image_size=IMG_SIZE,
    timesteps=timesteps,
    beta_schedule=beta_schedule,
)
class Trainer:
    """
    Training loop for a diffusion model on a folder of images.

    One training step is one optimizer update, made of
    ``gradient_accumulate_every`` forward / backward passes. After each
    update an exponential moving average (EMA) of the weights is refreshed.
    Every ``save_and_sample_every`` steps a batch of images is sampled from
    the EMA model and written to ``results_folder``.
    """

    def __init__(
        self,
        diffusion,
        path,
        train_batch_size=16,
        train_lr=1e-3,
        train_num_steps=10000,
        gradient_accumulate_every=1,
        ema_decay=0.995,
        save_and_sample_every=1000,
        results_folder='./results',
    ):
        """
        Args:
            diffusion: diffusion module; calling it on an image batch returns the loss.
            path: folder with the training images.
            train_batch_size: images per batch (also the number of images sampled).
            train_lr: Adam learning rate.
            train_num_steps: number of optimizer updates.
            gradient_accumulate_every: batches accumulated per optimizer update.
            ema_decay: EMA decay factor.
            save_and_sample_every: sample and save every this many steps.
            results_folder: directory that receives the sample grids.

        Raises:
            ValueError: if no training images are found under ``path``.
        """
        self.accelerator = Accelerator()
        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps
        self.save_and_sample_every = save_and_sample_every
        self.path = path
        self.results_folder = Path(results_folder)

        self.dataset = Dataset(path, image_size=diffusion.image_size)
        if len(self.dataset) == 0:
            # An empty loader would make the endless batch iterator spin forever.
            raise ValueError(f'no training images found under {path}')

        dataloader = DataLoader(self.dataset, batch_size=train_batch_size, shuffle=True)
        optimizer = Adam(diffusion.parameters(), lr=train_lr)

        self.diffusion, self.optimizer, self.dataloader = self.accelerator.prepare(
            diffusion, optimizer, dataloader
        )

        # ema_pytorch names the decay factor `beta`. The EMA copy is moved to
        # the training device so it can be updated from, and sampled next to,
        # the online model.
        self.ema = EMA(diffusion, beta=ema_decay)
        self.ema.to(self.accelerator.device)

    def _save_samples(self, step):
        """Sample one batch from the EMA model and write it as an image grid."""
        self.ema.ema_model.eval()
        samples = self.ema.ema_model.sample(batch_size=self.batch_size)

        print(f"Step {step}: Saving samples...")
        self.results_folder.mkdir(parents=True, exist_ok=True)
        utils.save_image(
            samples,
            str(self.results_folder / f'sample-{step}.png'),
            nrow=max(1, math.isqrt(self.batch_size)),
        )
        return samples

    def train(self):
        """Run ``train_num_steps`` optimizer updates, sampling periodically."""
        accelerator = self.accelerator
        # Endless stream of batches: steps are decoupled from epoch boundaries.
        batches = cycle(self.dataloader)

        with tqdm(
            initial=0,
            total=self.train_num_steps,
            disable=not accelerator.is_main_process,
        ) as progress_bar:
            progress_bar.set_description("Training")

            for step in range(1, self.train_num_steps + 1):
                total_loss = 0.

                for _ in range(self.gradient_accumulate_every):
                    # The dataset yields image tensors only (no labels).
                    img = next(batches)
                    # Scale so the accumulated gradient is the mean over micro-batches.
                    loss = self.diffusion(img) / self.gradient_accumulate_every
                    total_loss += loss.item()
                    accelerator.backward(loss)

                self.optimizer.step()
                self.optimizer.zero_grad()
                self.ema.update()

                if accelerator.is_main_process and step % self.save_and_sample_every == 0:
                    self._save_samples(step)

                progress_bar.set_postfix(loss=f'{total_loss:.4f}')
                progress_bar.update(1)

        accelerator.end_training()
# Build the Trainer from the hyper-parameters defined above:
#   - diffusion model and image folder
#   - batch size, learning rate and number of training steps
#   - gradient-accumulation factor (grad_steps)
#   - EMA decay
#   - sampling interval of 1000 steps
trainer_settings = dict(
    train_batch_size=batch_size,
    train_lr=lr,
    train_num_steps=train_num_steps,
    gradient_accumulate_every=grad_steps,
    ema_decay=ema_decay,
    save_and_sample_every=1000,
)
trainer = Trainer(diffusion, path, **trainer_settings)

# Start training.
trainer.train()
# 运行环境 (Runtime environment):