人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将展示训练过程中的损失值和准确率。

文章目录:

  1. DCGAN模型简介
  2. DCGAN模型原理
  3. 使用PyTorch搭建DCGAN模型
  4. 数据样例
  5. 训练模型
  6. 测试模型
  7. 总结

1. DCGAN模型简介

DCGAN全称:Deep Convolutional Generative Adversarial Networks,它是一种生成对抗网络(GAN)的变体,它使用卷积神经网络(CNN)作为生成器和判别器。DCGAN在图像生成任务中表现出色,能够生成具有高分辨率和清晰度的图像。

2. DCGAN模型原理

DCGAN模型由两个部分组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成图像,而判别器负责判断图像是否为真实图像。在训练过程中,生成器和判别器相互竞争,生成器试图生成越来越逼真的图像,而判别器试图更准确地识别生成的图像是否为真实图像。这个过程持续进行,直到生成器生成的图像足够逼真,以至于判别器无法区分生成的图像和真实图像。

DCGAN模型的数学原理表示:

生成器(Generator):

G ( z ) = x G(z) = x G(z)=x

其中, z z z是输入的随机噪声向量, x x x是生成的图像。

判别器(Discriminator):

D ( x ) = y D(x) = y D(x)=y

其中, x x x是输入的图像, y y y是判别器对图像的判断结果,表示图像是否为真实图像。

GAN的损失函数:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1-D(G(z)))] Gmin​Dmax​V(D,G)=Ex∼pdata​(x)[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]

其中, p d a t a ( x ) p_{data}(x) pdata​(x)表示真实数据的分, p z ( z ) p_z(z) pz​(z)表示噪声向量的分布, D ( x ) D(x) D(x)表示判别器对图像 x x x的判断结果, G ( z ) G(z) G(z)表示生成器生成的图像, log ⁡ D ( x ) \log D(x) logD(x)表示判别器将真实图像判断为真实图像的概率, log ⁡ ( 1 − D ( G ( z ) ) ) \log(1-D(G(z))) log(1−D(G(z)))表示判别器将生成图像判断为真实图像的概率。

www.zeeklog.com - 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

3. 使用PyTorch搭建DCGAN模型

首先,我们需要导入所需的库:

import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms import torchvision.datasets as dset from torch.autograd import Variable 

接下来,我们定义生成器和判别器的网络结构:

class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.main = nn.Sequential( # 输入是一个100维的向量 nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), # 输出为(512, 4, 4) nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), # 输出为(256, 8, 8) nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), # 输出为(128, 16, 16) nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False), nn.Tanh() # 输出为(3, 32, 32) ) def forward(self, input): return self.main(input) class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main = nn.Sequential( # 输入为(3, 32, 32) nn.Conv2d(3, 128, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), # 输出为(128, 16, 16) nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), # 输出为(256, 8, 8) nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), # 输出为(512, 4, 4) nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, input): return self.main(input).view(-1) 

4. 数据样例

我们将使用CIFAR-10数据集进行训练。首先,我们需要对数据进行预处理:

if __name__ =="__main__": transform = transforms.Compose([ transforms.Resize(32), transforms.CenterCrop(32), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) trainset = dset.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2) 

5. 训练模型

接下来,我们将训练DCGAN模型:

# 初始化生成器和判别器 netG = Generator() netD = Discriminator() # 设置损失函数和优化器 criterion = nn.BCELoss() optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 训练模型 num_epochs = 10 for epoch in range(num_epochs): for i, data in enumerate(trainloader, 0): # 更新判别器 netD.zero_grad() real, _ = data batch_size = real.size(0) label = torch.full((batch_size,), 1) output = netD(real) errD_real = criterion(output, label) errD_real.backward() noise = torch.randn(batch_size, 100, 1, 1) fake = netG(noise) label.fill_(0) output = netD(fake.detach()) errD_fake = criterion(output, label) errD_fake.backward() errD = errD_real + errD_fake optimizerD.step() # 更新生成器 netG.zero_grad() label.fill_(1) output = netD(fake) errG = criterion(output, label) errG.backward() optimizerG.step() if i%5==0: # 打印损失值 print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, num_epochs, i, len(trainloader), errD.item(), errG.item())) 

6. 测试模型

训练完成后,我们可以使用生成器生成一些图像进行测试:

import matplotlib.pyplot as plt import numpy as np def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() noise = torch.randn(64, 100, 1, 1) fake = netG(noise) imshow(torchvision.utils.make_grid(fake.detach())) 

7. 总结

本文详细介绍了DCGAN模型的原理,并使用PyTorch搭建了一个简单的DCGAN模型。我们提供了模型代码,并使用CIFAR-10数据集进行训练和测试。最后,我们展示了训练过程中的损失值和生成的图像。希望本文能帮助您更好地理解DCGAN模型,并在实际项目中应用。

Read more

C++ 继承入门(上):从基础概念定义到默认成员函数,吃透类复用的核心逻辑

C++ 继承入门(上):从基础概念定义到默认成员函数,吃透类复用的核心逻辑

🔥小叶-duck:个人主页 ❄️个人专栏:《Data-Structure-Learning》 《C++入门到进阶&自我学习过程记录》《算法题讲解指南》--从优选到贪心 ✨未择之路,不须回头 已择之路,纵是荆棘遍野,亦作花海遨游 目录 前言 一. 继承的概念与定义   1、继承的核心概念   2、继承的定义格式   3、继承方式与成员访问权限 二. 基类与派生类的转换:子类对象能当父类用吗? 三. 继承中的作用域:同名成员会冲突吗?   1、变量隐藏   2、函数隐藏 四、派生类的默认成员函数:构造、拷贝、析构怎么写?   1、构造函数:先调用父类构造,再初始化子类成员   2、拷贝构造:先拷贝父类,再拷贝子类   3、 赋值重载:

By Ne0inhk
《C++ 多态》三大面向对象编程——多态:虚函数机制、重写规范与现代C++多态控制全概要

《C++ 多态》三大面向对象编程——多态:虚函数机制、重写规范与现代C++多态控制全概要

🔥个人主页:Cx330🌸 ❄️个人专栏:《C语言》《LeetCode刷题集》《数据结构-初阶》《C++知识分享》 《优选算法指南-必刷经典100题》《Linux操作系统》:从入门到入魔 🌟心向往之行必能至 🎥Cx330🌸的简介: 目录 前言: 一、认识多态:面向对象编程的灵魂 1.1  多态的核心概念解析 1.2  联系实际:现实世界中的多态类比 二、多态的实现机制深度探索 2.1  多态的本质与构成必要条件 2.1.1  多态的科学定义 2.1.2  实现多态的两个关键条件 2.2  虚函数:多态的基石 2.3  虚函数重写(覆盖)详解 2.4

By Ne0inhk
C++ 继承入门(下):友元、静态成员与菱形继承的底层逻辑

C++ 继承入门(下):友元、静态成员与菱形继承的底层逻辑

🔥小叶-duck:个人主页 ❄️个人专栏:《Data-Structure-Learning》 《C++入门到进阶&自我学习过程记录》《算法题讲解指南》--从优选到贪心 ✨未择之路,不须回头 已择之路,纵是荆棘遍野,亦作花海遨游 目录 前言 一. 友元 —— 友元关系不可继承   1、错误版本   2、正确版本 二. 静态成员 —— 继承体系中静态成员的共享性 三. 多继承及菱形继承问题:本质特点与解决方案   1、单继承与多继承模型   2、菱形继承:虚继承解决“数据冗余”与“二义性”     2.1 菱形继承出现的坑(解决二义性问题)     2.2 虚继承:彻底解决菱形继承问题     3、多继承中指针偏移问题 友元,静态成员,

By Ne0inhk
《C++进阶之STL》【哈希表】

《C++进阶之STL》【哈希表】

【哈希表】目录 * 前言 * ------------概念介绍------------ * 1. 什么是哈希? * ------------核心术语------------ * 一、哈希函数 * 1. 哈希函数的核心特点是什么? * 2. 哈希函数的设计目标是什么? * 3. 常见的哈希函数有哪些? * 直接定址法 * 除法散列法 * 乘法散列法 * 全域散列法 * 二、负载因子 * 1. 什么是负载因子? * 2. 负载因子对哈希表的性能有什么影响? * 3. 负载因子超过阈值时会发什么? * 三、哈希冲突 * 四、冲突处理 * 方法一:开放定址法 * 线性探测 * 二次探测 * 双重散列 * 方法二:链地址法 * ------------基本操作------------ * 怎么解决键key不能取模的问题? * 一、开放定址法 * 哈希结构 * 删除操作 * 扩容操作 * 二、链地址法 * 哈希结构 *

By Ne0inhk