当前位置:   article > 正文

基础GAN实例(pytorch代码实现)

gan实例

目录

导入库

数据准备

定义生成器

定义判别器

 初始化模型,优化器及损失计算函数

 绘图函数

GAN的训练 

运行结果 

                ​编辑


导入库

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim #优化
  5. import numpy as np
  6. import matplotlib.pyplot as plt #绘图
  7. import torchvision #加载图片
  8. from torchvision import transforms #图片变换

数据准备

  1. #对数据做归一化(-11
  2. transform=transforms.Compose([
  3. #将shanpe为(H,W,C)的数组或img转为shape为(C,H,W)的tensor
  4. transforms.ToTensor(), #转为张量并归一化到【01】;数据只是范围变了,并没有改变分布
  5. transforms.Normalize(0.5,0.5)#数据归一化处理,将数据整理到[-1,1]之间;可让数据呈正态分布
  6. ])

 transforms.Compose(): 将多个预处理依次累加在一起, 每次执行transform都会依次执行其中包含的多个预处理程序

transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor

transforms.Normalize([0.5], [0.5]):归一化,这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,因为图像是灰色的只有一个通道,所以分别指定一了一个值,如果有多个通道,需要有多个数字,如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])

  1. #下载数据到指定的文件夹
  2. train_ds = torchvision.datasets.MNIST('data',
  3. train=True,
  4. transform=transform,
  5. download=True)

root :需要下载至地址的根目录位置

train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True

transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本

download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载

datalodar=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)

PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。

该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。

torch.utils.data.DataLoader(onject)的可用参数如下:

dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。

batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)

shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False

定义生成器

输出是长度为100的噪声(正态分布随机数)

输出为(1,28,28)的图片

linear 1:100---256

linear 2: 256--512

linear 3:512--784(28*28)

reshape: 784---(1,28,28)

  1. class Generator(nn.Module):
  2. def __init__(self):
  3. super(Generator,self).__init__()
  4. self.main=nn.Sequential(
  5. nn.Linear(100,256),
  6. nn.ReLU(),
  7. nn.Linear(256,512),
  8. nn.ReLU(),
  9. nn.Linear(512,784),
  10. nn.Tanh()#对于生成器,最后一个激活函数是tanh,值域:-11
  11. )
  12. #定义前向传播
  13. def forward(self,x): #x表示长度为100的noise输入
  14. img = self.main(x)
  15. img=img.view(-1,28,28)#转换成图片的形式
  16. return img

 

定义判别器

输入为为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid的激活0-1

BCEloss计算交叉熵损失

在判别器中一般推荐使用nn.LeakyReLU

  1. class Discriminator(nn.Module):
  2. def __init__(self):
  3. super(Discriminator,self).__init__()
  4. self.main = nn.Sequential(
  5. nn.Linear(784,512),
  6. nn.LeakyReLU(),
  7. nn.Linear(512,256),
  8. nn.LeakyReLU(),
  9. nn.Linear(256,1),
  10. nn.Sigmoid()
  11. )
  12. def forward(self,x):
  13. x =x.view(-1,784) #展平
  14. x =self.main(x)
  15. return x

 初始化模型,优化器及损失计算函数

  1. #设备的配置
  2. device='cuda' if torch.cuda.is_available() else 'cpu'
  1. #初始化生成器和判别器把他们放到相应的设备上
  2. gen = Generator().to(device)
  3. dis = Discriminator().to(device)
  1. #训练器的优化器
  2. d_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
  3. #训练生成器的优化器
  4. g_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
  1. #交叉熵损失函数
  2. loss_fn = torch.nn.BCELoss()

 绘图函数

  1. def gen_img_plot(model,test_input):
  2. prediction = np.squeeze(model(test_input).detach().cpu().numpy())
  3. fig = plt.figure(figsize=(4,4))
  4. for i in range(16):
  5. plt.subplot(4,4,i+1)
  6. plt.imshow((prediction[i]+1)/2)
  7. plt.axis('off')
  8. plt.show()
test_input = torch.randn(16,100 ,device=device) #16个长度为100的随机数

GAN的训练 

  1. D_loss = []
  2. G_loss = []
  1. #训练循环
  2. for epoch in range(20):
  3. #初始化损失值
  4. d_epoch_loss = 0
  5. g_epoch_loss = 0
  6. count = len(dataloader) #返回批次数
  7. #对数据集进行迭代
  8. for step,(img,_) in enumerate(dataloader):
  9. img =img.to(device) #把数据放到设备上
  10. size = img.size(0) #img的第一位是size,获取批次的大小
  11. random_noise = torch.randn(size,100,device=device)
  12. #判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化
  13. d_optim.zero_grad()#梯度归零
  14. #判别器对于真实图片产生的损失
  15. real_output = dis(img) #判别器输入真实的图片,real_output对真实图片的预测结果
  16. d_real_loss = loss_fn(real_output,
  17. torch.ones_like(real_output)
  18. )
  19. d_real_loss.backward()#计算梯度
  20. #在生成器上去计算生成器的损失,优化目标是判别器上的参数
  21. gen_img = gen(random_noise) #得到生成的图片
  22. #因为优化目标是判别器,所以对生成器上的优化目标进行截断
  23. fake_output = dis(gen_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了
  24. #判别器在生成图像上产生的损失
  25. d_fake_loss = loss_fn(fake_output,
  26. torch.zeros_like(fake_output)
  27. )
  28. d_fake_loss.backward()
  29. #判别器损失
  30. d_loss = d_real_loss + d_fake_loss
  31. #判别器优化
  32. d_optim.step()
  33. #生成器上损失的构建和优化
  34. g_optim.zero_grad() #先将生成器上的梯度置零
  35. fake_output = dis(gen_img)
  36. g_loss = loss_fn(fake_output,
  37. torch.ones_like(fake_output)
  38. ) #生成器损失
  39. g_loss.backward()
  40. g_optim.step()
  41. #累计每一个批次的loss
  42. with torch.no_grad():
  43. d_epoch_loss +=d_loss
  44. g_epoch_loss +=g_loss
  45. #求平均损失
  46. with torch.no_grad():
  47. d_epoch_loss /=count
  48. g_epoch_loss /=count
  49. D_loss.append(d_epoch_loss)
  50. G_loss.append(g_epoch_loss)
  51. print('Epoch:',epoch)
  52. gen_img_plot(gen,test_input)

运行结果 

一共有20张,因篇幅有限,这里先展示前六张和最后一张

                

 

 

 

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/365715
推荐阅读
相关标签
  

闽ICP备14008679号