跳到主要内容5 种主流深度生成模型对比:VAE、GAN、AR、Flow 与 Diffusion 原理及实现 | 极客日志PythonAI算法
5 种主流深度生成模型对比:VAE、GAN、AR、Flow 与 Diffusion 原理及实现
综述由AI生成深度生成模型涵盖 VAE、GAN、AR、Flow 和 Diffusion 五大类。VAE 通过变分推断学习潜在分布,训练稳定但生成略模糊;GAN 利用对抗博弈提升细节,但训练难收敛;AR 模型基于序列预测,适合文本与时序,推理速度受限;Flow 模型通过可逆变换实现精确密度估计;Diffusion 模型凭借逐步去噪机制在图像生成质量上表现最佳。本文对比了各模型的核心原理、损失函数及 PyTorch 代码实现,分析了优缺点与适用场景,为技术选型提供参考。
利刃15 浏览 5 种主流深度生成模型对比:VAE、GAN、AR、Flow 与 Diffusion 原理及实现
随着 Sora、Diffusion、GPT 等模型的兴起,深度生成模型再次成为技术焦点。这类模型能从输入数据中学习潜在分布,生成与训练数据相似的新样本,在计算机视觉、自然语言处理等领域应用广泛。
本文汇总了五种常用的深度学习生成模型:VAE(变分自编码器)、GAN(生成对抗网络)、AR(自回归模型)、Flow(流模型)和 Diffusion(扩散模型),深入解析其原理、损失函数及代码实现。
| 模型 | 核心目标 | 原理 | 优点 | 缺点 | 应用场景 |
|---|
| VAE | 学习潜在空间分布,重构样本 | 基于变分推断,优化重构误差与 KL 散度 | 训练稳定,支持插值 | 生成图像模糊 | 数据填充、特征提取 |
| GAN | 生成器与判别器对抗,生成逼真样本 | 零和博弈优化,达到纳什均衡 | 细节丰富,推理快 | 训练不稳定,多样性不足 | 艺术创作、风格迁移 |
| AR | 自回归预测序列下一个元素 | 条件概率分解,捕捉长程依赖 | 建模能力强,训练稳 | 生成速度慢,计算成本高 | 文本生成、时序预测 |
| Flow | 可逆变换转换分布,精确密度估计 | 设计可逆层,利用变量变换公式 | 支持精确密度估计 | 高维下变换设计复杂 | 语音合成、密度估计 |
| Diffusion | 逐步去噪重建数据分布 | 正向加噪与逆向去噪结合 | 生成质量最高,训练稳 | 推理慢,显存占用高 | 高清图像、视频生成 |
1. 变分自编码器(VAE)
1.1 概念
VAE 结合了自编码器与变分推断。它假设隐变量服从某种先验分布(如标准正态分布),通过编码器将输入映射到隐变量的后验分布,再由解码器还原样本。
简单来说,VAE 不仅要求能还原数据,还要求隐空间是连续且有序的,这样我们才能从隐空间中采样出新的合理数据。
1.2 训练损失
VAE 的损失函数包含两部分:重构损失和KL 散度。
- 重构项:衡量解码器重建输入的能力,常用均方误差或交叉熵。
- KL 散度项:约束潜在分布 $q(z|x)$ 与先验分布 $p(z)$ 的相似性。
优化目标是最大化证据下界(ELBO)。这就像在学习绘画时,既要准确临摹原图(重构),又要符合透视比例规则(KL 约束)。
1.3 VAE 的实现
下面是一个基于 PyTorch 的简化实现,展示了编码、重参数化技巧和解码过程。
import torch
import torch.nn as nn
import torch.nn.functional F
(nn.Module):
():
(VAE, ).__init__()
.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * )
)
.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid()
)
():
std = torch.exp( * log_var)
eps = torch.randn_like(std)
mu + eps * std
():
h = .encoder(x)
mu, log_var = torch.chunk(h, , dim=)
z = .reparameterize(mu, log_var)
x_recon = .decoder(z)
x_recon, mu, log_var
():
recon_loss = F.binary_cross_entropy(x_recon, x, reduction=)
kl_div = - * torch.( + log_var - mu.() - log_var.exp())
recon_loss + kl_div
as
class
VAE
def
__init__
self, input_dim=784, hidden_dim=400, latent_dim=20
super
self
self
2
self
def
reparameterize
self, mu, log_var
"""重参数化技巧:从 N(mu, sigma^2) 采样 z"""
0.5
return
def
forward
self, x
self
2
1
self
self
return
def
loss_function
self, x_recon, x, mu, log_var
'sum'
0.5
sum
1
pow
2
return
2. 生成对抗网络(GAN)
2.1 概念
GAN 由**生成器(Generator)和判别器(Discriminator)**组成。生成器试图制造假数据欺骗判别器,判别器则努力区分真假。两者在对抗中共同进化。
2.2 训练损失
a. 判别器的损失
判别器希望真实样本输出 1,生成样本输出 0。损失为二分类交叉熵之和。
b. 生成器的损失
生成器希望判别器将生成的样本误判为 1。损失函数旨在最小化判别器对生成样本的正确判断概率。
c. 对抗训练的动态过程
训练初期,生成器质量差,判别器轻松识别;随着训练进行,生成器改进,判别器被迷惑,最终达到平衡,生成器能产出以假乱真的样本。
2.3 GAN 的实现
这里展示一个简化的全连接 GAN 结构,包含生成器和判别器的定义及交替训练逻辑。
class Generator(nn.Module):
def __init__(self, noise_dim=100, output_dim=784):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(noise_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, output_dim),
nn.Tanh()
)
def forward(self, z):
return self.model(z).view(-1, 1, 28, 28)
class Discriminator(nn.Module):
def __init__(self, input_dim=784):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 784)
return self.model(x)
def train_gan():
G = Generator()
D = Discriminator()
criterion = nn.BCELoss()
for real_images, _ in dataloader:
real_labels = torch.ones(real_images.size(0), 1)
fake_labels = torch.zeros(real_images.size(0), 1)
real_loss = criterion(D(real_images), real_labels)
z = torch.randn(real_images.size(0), 100)
fake_images = G(z)
fake_loss = criterion(D(fake_images.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_D.step()
g_loss = criterion(D(fake_images), real_labels)
g_loss.backward()
optimizer_G.step()
3. 自回归模型(AR)
3.1 概念
AR 模型基于序列数据,通过预测下一个元素的值来生成数据。给定序列 $(x_1, ..., x_n)$,模型学习条件概率 $P(x_t | x_{t-1}, ..., x_1)$。
Transformer 是 AR 的典型代表,通过注意力机制捕捉长程依赖。相比 RNN,它解决了梯度消失问题,更适合长序列。
3.2 训练过程
a. 核心思想
用历史预测未来。例如根据'今天天气'预测下一个词'很'。
b. Transformer 的损失计算
使用交叉熵监督预测。输入序列右移一位作为目标,模型预测每个位置下一个词的概率分布。
c. 具体步骤
- 嵌入与位置编码:将词转换为向量并添加位置信息。
- 因果掩码:屏蔽未来信息,确保预测只依赖过去。
- 多头注意力:整合历史信息。
- 输出层:映射到词表概率分布。
- 计算损失:对比预测与真实标签。
3.3 代码实现(Transformer-AR)
这是一个基于 Transformer 的自回归图像生成模型示例,将图像展平为序列进行处理。
class TransformerAR(nn.Module):
def __init__(self, vocab_size=256, embed_dim=128, num_heads=4, num_layers=3):
super(TransformerAR, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.positional_enc = nn.Parameter(torch.randn(784, embed_dim))
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=num_heads, dim_feedforward=512
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
self.fc = nn.Linear(embed_dim, vocab_size)
def forward(self, x):
x = self.embedding(x) + self.positional_enc
mask = torch.triu(torch.ones(784, 784), diagonal=1).bool()
out = self.transformer(x, mask=mask)
logits = self.fc(out)
return logits
def generate(self, start_token, max_len=784):
generated = start_token
for _ in range(max_len):
logits = self(generated)
next_pixel = torch.multinomial(F.softmax(logits[:, -1, :], dim=-1), 1)
generated = torch.cat([generated, next_pixel], dim=1)
return generated
4. 流模型(Flow)
4.1 概念
流模型通过一系列可逆变换,将简单分布(如正态分布)转换为复杂的数据分布。核心在于变换的可逆性,这使得我们可以精确计算概率密度。
想象橡皮泥变形,既能捏成任意形状,又能完美恢复原状。这种特性让 Flow 模型支持精确的密度估计。
4.2 训练过程
训练目标是最大化数据的对数似然。通过变量变换公式,将复杂分布的对数似然转化为简单分布的对数似然加上雅可比行列式的对数。
4.3 代码实现(Flow)
以下是一个基于 RealNVP 的简化流模型实现。
class FlowModel(nn.Module):
def __init__(self, input_dim=784, hidden_dim=512):
super(FlowModel, self).__init__()
self.scale_net = nn.Sequential(
nn.Linear(input_dim//2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim//2)
)
self.shift_net = nn.Sequential(
nn.Linear(input_dim//2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim//2)
)
def forward(self, x):
x1, x2 = x.chunk(2, dim=1)
s = self.scale_net(x1)
t = self.shift_net(x1)
z2 = x2 * torch.exp(s) + t
z = torch.cat([x1, z2], dim=1)
log_det = s.sum(dim=1)
return z, log_det
def inverse(self, z):
z1, z2 = z.chunk(2, dim=1)
s = self.scale_net(z1)
t = self.shift_net(z1)
x2 = (z2 - t) * torch.exp(-s)
x = torch.cat([z1, x2], dim=1)
return x
def flow_loss(self, z, log_det):
prior_logprob = -0.5 * (z ** 2).sum(dim=1)
return (-prior_logprob - log_det).mean()
5. 扩散模型(Diffusion)
5.1 概念
- 正向扩散:逐步给数据加噪声,直到变成纯噪声。
- 反向扩散:学习如何从噪声逐步去噪,恢复数据。
DDPM 证明了只需学习均值即可取得很好的生成效果。它通常采用 U-Net 架构,并在每个时间步共享参数,通过时间嵌入告诉模型当前进度。
5.2 训练过程
损失函数的核心是衡量'预测噪声'与'真实噪声'的差距,常用均方误差(MSE)。U-Net 根据带噪样本预测噪声,通过最小化损失学会还原清晰数据。
5.3 代码实现(Diffusion)
这是一个基于 UNet 的简化扩散模型实现,展示了前向加噪和反向采样过程。
class DiffusionModel(nn.Module):
def __init__(self, image_size=28, channels=1):
super(DiffusionModel, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(channels, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, channels, 3, padding=1)
)
self.num_steps = 1000
self.betas = torch.linspace(1e-4, 0.02, self.num_steps)
self.alphas = 1 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def forward(self, x, t):
return self.net(x)
def train_step(self, x0):
t = torch.randint(0, self.num_steps, (x0.size(0),))
sqrt_alpha_bar = torch.sqrt(self.alpha_bars[t]).view(-1, 1, 1, 1)
sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bars[t]).view(-1, 1, 1, 1)
epsilon = torch.randn_like(x0)
xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * epsilon
epsilon_pred = self(xt, t)
loss = F.mse_loss(epsilon_pred, epsilon)
return loss
def sample(self, num_samples=16):
xt = torch.randn(num_samples, 1, 28, 28)
for t in reversed(range(self.num_steps)):
epsilon_pred = self(xt, t)
xt = (xt - self.betas[t] * epsilon_pred) / torch.sqrt(self.alphas[t])
if t > 0:
xt += torch.sqrt(self.betas[t]) * torch.randn_like(xt)
return xt
6. 小结
- VAE 和 GAN 是经典方案,前者基于概率理论,后者基于对抗训练。
- AR 擅长处理时序依赖,但推理较慢。
- Flow 支持精确密度估计,但高维变换复杂。
- Diffusion 目前生成质量最高,适合高质量图像和视频生成,但计算成本较大。
未来的研究方向可能集中在混合架构融合(如 Diffusion-GAN)、轻量化部署以及物理约束嵌入等方面,以兼顾生成质量与推理效率。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online