跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

深度生成模型对比:VAE、GAN、AR、Flow 与 Diffusion 原理及实现

综述由AI生成对比了 VAE、GAN、自回归模型、流模型和扩散模型五大深度生成架构。涵盖核心原理、训练目标函数及优缺点分析,并提供基于 PyTorch 的代码实现示例。重点解析了潜在空间分布学习、对抗博弈机制、序列预测、可逆变换及去噪过程,帮助读者理解不同模型在图像、文本等场景下的适用性与权衡。

清酒独酌发布于 2026/4/8更新于 2026/6/1221 浏览
深度生成模型对比:VAE、GAN、AR、Flow 与 Diffusion 原理及实现

随着 Sora、Diffusion、GPT 等模型的爆发,深度生成模型再次成为技术焦点。这类模型能从输入数据学习潜在分布,生成与训练数据相似的新样本,在计算机视觉、自然语言处理等领域广泛应用。

本文汇总了五种主流深度学习生成模型:VAE(变分自编码器)、GAN(生成对抗网络)、AR(自回归模型)、Flow(流模型)和 Diffusion(扩散模型),深入解析其原理、损失函数及代码实现。

模型核心目标原理优点缺点应用场景
VAE学习潜在空间分布,重构与生成样本基于变分推断,映射到正态分布,优化重构误差与 KL 散度训练稳定,支持插值;多样性较好生成图像模糊;KL 约束可能丢失信息数据填充、特征提取、图像修复
GAN生成器与判别器对抗,生成逼真样本零和博弈优化,达到纳什均衡细节丰富;推理速度快训练不稳定;多样性不足;调参难艺术创作、风格迁移、超分辨率
AR自回归预测序列下一个元素概率条件概率分解,捕捉长程依赖建模能力强;训练稳定生成速度慢;高维计算成本高文本生成、时序预测、图像生成
Flow可逆变换转换分布,精确密度估计设计可逆层,利用变量变换公式计算似然支持精确密度估计;生成重建可逆高维下变换复杂;雅可比行列式开销大语音合成、密度估计
Diffusion逐步去噪重建数据分布,高质量生成正向加噪与逆向去噪结合,马尔可夫链建模生成质量最高;训练稳定推理慢;显存占用高高清图像、多模态/视频生成

1 变分自编码器(VAE)

1.1 概念

VAE 在自编码器的基础上结合了变分推断和贝叶斯理论。它的目标是学习一个能生成与训练数据相似样本的模型。VAE 假设隐变量服从某种先验分布(如标准正态分布),通过编码器将输入映射到隐变量的后验分布,再通过解码器还原生成样本。

简单来说,VAE 不仅要求解码器能把隐变量还原成接近原图的样子,还强制隐变量的分布符合常识(如正态分布)。这就像学习绘画时,既要准确临摹,又要符合透视比例规则。

1.2 训练损失

VAE 的训练损失包含两部分:重构损失(衡量重建能力)和 KL 散度(约束潜在分布与先验分布的差异)。

损失函数逻辑:

  • 重构项:常用均方误差或交叉熵,确保解码器能还原输入。
  • KL 散度项:约束潜在分布 $q(z|x)$ 与先验分布 $p(z)$ 的相似性,平衡参数为 $eta$。

优化目标是最大化证据下界(ELBO),同时保证潜在空间的结构化和连续性。

直观理解:VAE 的损失函数像'双面裁判'。一面监督'重建能力',另一面监督'规则意识'。如果只关注重建,模型可能生成奇形怪状的样本;如果过度强调规则,样本又会千篇一律。$eta$ 参数就像音量旋钮,调节这两者的权重。

1.3 VAE 的实现

下面是一个基于 PyTorch 的简化实现,展示了编码器、解码器及重参数化技巧。

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        # 编码器:输入 → 隐藏层 → 均值和方差
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        # 解码器:潜在变量 → 隐藏层 → 重构输入
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid() 
        )

    def reparameterize(self, mu, log_var):
        """重参数化技巧:从 N(μ, σ²) 采样潜在变量 z"""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu, log_var = torch.chunk(h, 2, dim=1)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def loss_function(self, x_recon, x, mu, log_var):
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + kl_div

2 生成对抗网络(GAN)

2.1 概念

GAN 由**生成器(Generator)**和 **判别器(Discriminator)**组成。生成器试图制造逼真的假数据,判别器则负责区分真假。两者通过竞争进化,最终生成器能产出难以分辨的样本。

训练过程:

  1. 判别器接受真实数据和生成数据,进行二分类训练。
  2. 生成器根据判别器反馈,尝试欺骗它。
  3. 交替训练,直到判别器无法区分真伪。

2.2 训练损失

a. 判别器的损失函数

目标是最大化正确判断的概率。对真实样本输出 1,对生成样本输出 0。 损失是两部分交叉熵的总和:惩罚对真实样本判断错误和对生成样本判断错误的情况。

b. 生成器的损失函数

目标是让判别器误判生成样本为真(即让判别器输出趋近于 1)。

c. 对抗训练的动态过程

初期生成器随机生成低质量样本,判别器轻松识别。随着训练进行,生成器改进技术,判别器被迷惑,损失上升。最终达到平衡,生成器能以假乱真。

2.3 GAN 的实现

这里展示简化的 Generator 和 Discriminator 结构及训练循环逻辑。

class Generator(nn.Module):
    def __init__(self, noise_dim=100, output_dim=784):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh() 
        )
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self, input_dim=784):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.model(x)

# 训练循环示例(简化版)
def train_gan():
    G = Generator()
    D = Discriminator()
    criterion = nn.BCELoss()
    for real_images, _ in dataloader:
        # 训练判别器
        real_labels = torch.ones(real_images.size(0), 1)
        fake_labels = torch.zeros(real_images.size(0), 1)
        real_loss = criterion(D(real_images), real_labels)
        z = torch.randn(real_images.size(0), 100)
        fake_images = G(z)
        fake_loss = criterion(D(fake_images.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()
        
        # 训练生成器
        g_loss = criterion(D(fake_images), real_labels)
        g_loss.backward()
        optimizer_G.step()

3 自回归模型(AR)

3.1 概念

自回归模型是一种基于序列数据的生成模型,通过预测序列中下一个元素的值来生成数据。给定序列 $(x_1, x_2, ..., x_n)$,模型学习条件概率分布 $P(x_t | x_{t-1}, ..., x_1)$。

早期 RNN 在处理长序列时存在梯度消失问题,Transformer 的出现解决了这一痛点。GPT、Bert 等大模型均基于 Transformer 架构实现了卓越性能。

3.2 训练过程

a. 核心思想:用历史预测未来

根据过去的输出预测未来的输出。例如语言模型根据'今天天气'预测下一个词。

b. Transformer 的损失计算

采用交叉熵监督预测。输入序列右移一位作为目标序列,模型对每个位置预测下一个词的概率分布。

c. 损失计算的具体步骤
  1. 嵌入与位置编码:将输入转换为向量并添加位置信息。
  2. 因果掩码:屏蔽未来信息,预测时只能看到历史。
  3. 多头注意力与前馈网络:整合历史信息。
  4. 输出层:映射到词表概率分布。
  5. 计算损失:对比预测概率与真实标签的 one-hot 编码。

为什么使用交叉熵? 这是分类问题的天然选择。每个位置的预测本质上是多分类任务(从词表中选一个词)。

3.3 代码实现(Transformer-AR)

这是一个基于 Transformer 的自回归图像生成模型示例(Pixel Transformer)。

class TransformerAR(nn.Module):
    def __init__(self, vocab_size=256, embed_dim=128, num_heads=4, num_layers=3):
        super(TransformerAR, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_enc = nn.Parameter(torch.randn(784, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=512
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.positional_enc
        mask = torch.triu(torch.ones(784, 784), diagonal=1).bool()
        out = self.transformer(x, mask=mask)
        logits = self.fc(out)
        return logits

    def generate(self, start_token, max_len=784):
        generated = start_token
        for _ in range(max_len):
            logits = self(generated)
            next_pixel = torch.multinomial(F.softmax(logits[:, -1, :], dim=-1), 1)
            generated = torch.cat([generated, next_pixel], dim=1)
        return generated

4 流模型(Flow)

4.1 概念

流模型基于可逆变换,将简单分布(如正态分布)转换为复杂的数据分布。核心思想是用'可逆魔法'转换分布,既能变形也能恢复原状。

想象有一团橡皮泥(简单分布),通过一系列可逆向操作的手法(拉伸、折叠),把它捏成跟真实数据分布一样复杂的形状。这种特性使得流模型支持精确密度估计。

4.2 训练过程

流模型通过最小化负对数似然来训练。利用变量变换公式计算数据对数似然,优化雅可比行列式。

4.3 代码实现(Flow)

基于 RealNVP 的可逆流模型示例。

class FlowModel(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512):
        super(FlowModel, self).__init__()
        self.scale_net = nn.Sequential(
            nn.Linear(input_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim//2)
        )
        self.shift_net = nn.Sequential(
            nn.Linear(input_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim//2)
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        s = self.scale_net(x1)
        t = self.shift_net(x1)
        z2 = x2 * torch.exp(s) + t
        z = torch.cat([x1, z2], dim=1)
        log_det = s.sum(dim=1)
        return z, log_det

    def inverse(self, z):
        z1, z2 = z.chunk(2, dim=1)
        s = self.scale_net(z1)
        t = self.shift_net(z1)
        x2 = (z2 - t) * torch.exp(-s)
        x = torch.cat([z1, x2], dim=1)
        return x

    def flow_loss(self, z, log_det):
        prior_logprob = -0.5 * (z ** 2).sum(dim=1)
        return (-prior_logprob - log_det).mean()

5 扩散模型(Diffusion)

5.1 概念

Diffusion Model 灵感来源于物理扩散过程。与传统模型不同,它模拟数据从随机噪声逐渐扩散到目标数据的过程。

DDPM 证明了不用像 VAE 那样学方差,只学个均值就能有很好的效果。扩散模型类似编码器 - 解码器结构,但每个时间步输出的特征图大小一致,且共享 U-Net 参数,有点像 RNN 不断循环。

核心思想:模拟'破坏 - 修复'的物理过程

  • 正向扩散:给干净数据逐步加噪声,让数据从清晰变模糊,最后接近纯噪声。
  • 反向扩散:从纯噪声出发,一步步去除噪声,恢复成清晰数据。

5.2 训练过程

损失函数的核心是衡量'预测噪声'与'真实噪声'的差距,常用均方误差(MSE)。

在正向扩散中,模型知道每个时间步加了多少真实噪声。训练时,U-Net 根据带噪样本预测噪声,损失函数要求预测值尽可能接近真实噪声。通过最小化损失,U-Net 学会分析带噪数据的特征,最终在反向扩散时还原出清晰数据。

5.3 代码实现(Diffusion)

基于 UNet 的扩散模型简化实现。

class DiffusionModel(nn.Module):
    def __init__(self, image_size=28, channels=1):
        super(DiffusionModel, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, channels, 3, padding=1)
        )
        self.num_steps = 1000
        self.betas = torch.linspace(1e-4, 0.02, self.num_steps)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x, t):
        return self.net(x)

    def train_step(self, x0):
        t = torch.randint(0, self.num_steps, (x0.size(0),))
        sqrt_alpha_bar = torch.sqrt(self.alpha_bars[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bars[t]).view(-1, 1, 1, 1)
        epsilon = torch.randn_like(x0)
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * epsilon
        epsilon_pred = self(xt, t)
        loss = F.mse_loss(epsilon_pred, epsilon)
        return loss

    def sample(self, num_samples=16):
        xt = torch.randn(num_samples, 1, 28, 28)
        for t in reversed(range(self.num_steps)):
            epsilon_pred = self(xt, t)
            xt = (xt - self.betas[t] * epsilon_pred) / torch.sqrt(self.alphas[t])
            if t > 0:
                xt += torch.sqrt(self.betas[t]) * torch.randn_like(xt)
        return xt

6 小结

回顾这五种常见的深度学习生成模型:

  • VAE 和 GAN 是基础架构,分别基于贝叶斯概率理论和对抗训练。
  • AR 模型适用于处理具有时序依赖关系的数据,如序列数据。
  • Flow 和 Diffusion 在生成样本上具有较好的稳定性和多样性,但计算成本较高。

未来研究方向:

  • 混合架构融合:结合 Diffusion 的高质量与 GAN 的速度,或 VAE 的压缩能力。
  • 轻量化:通过知识蒸馏、量化压缩降低部署资源消耗。
  • 物理约束嵌入:引入刚体动力学或流体力学方程,使生成内容符合现实规律。

不同模型各有优劣,选择时需权衡生成质量、推理速度及计算资源。

目录

  1. 1 变分自编码器(VAE)
  2. 1.1 概念
  3. 1.2 训练损失
  4. 1.3 VAE 的实现
  5. 2 生成对抗网络(GAN)
  6. 2.1 概念
  7. 2.2 训练损失
  8. a. 判别器的损失函数
  9. b. 生成器的损失函数
  10. c. 对抗训练的动态过程
  11. 2.3 GAN 的实现
  12. 训练循环示例(简化版)
  13. 3 自回归模型(AR)
  14. 3.1 概念
  15. 3.2 训练过程
  16. a. 核心思想:用历史预测未来
  17. b. Transformer 的损失计算
  18. c. 损失计算的具体步骤
  19. 3.3 代码实现(Transformer-AR)
  20. 4 流模型(Flow)
  21. 4.1 概念
  22. 4.2 训练过程
  23. 4.3 代码实现(Flow)
  24. 5 扩散模型(Diffusion)
  25. 5.1 概念
  26. 5.2 训练过程
  27. 5.3 代码实现(Diffusion)
  28. 6 小结
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • GitHub 日榜精选(2026-01-08):AI Agent、Web 分析与开发工具
  • OpenClaw 内网穿透实战:随时随地访问本地 AI 服务
  • 在线图书借阅平台设计与实现
  • Spring Boot 微服务架构设计与实现
  • Spring AI Agent Skills 接入实战与原理剖析
  • Copilot 人工智能助手介绍
  • 基于 Java 的无人化台球棋牌茶室系统架构设计
  • 基于 SpringBoot 的宠物销售系统设计与实现
  • GLM-4.7 与 MiniMax M2.1 工程级 Agent 模型接入指南
  • Spring Cloud Config 与 Apollo 配置中心架构深度解析
  • 大模型(LLM)在企业中的典型应用场景
  • LocalSend:免费开源跨平台局域网文件传输工具
  • HDFS 分布式文件系统编程实践
  • B 站生态观察:从二次元社区到 AI 创新孵化器
  • AI 大模型应用入门实战:构建你的第一个大模型指南
  • 一个完整的车辆监控管理系统,包含后端API、Web管理后台和移动端应用
  • 成为黑客的 12 个基本步骤与核心技能指南
  • 大模型检索增强生成(RAG)技术综述
  • Java 全栈工程师面试实录:从基础到项目实战
  • KES 数据库运维:资源回收与膨胀防治全攻略

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online