当前位置:   article > 正文

深度学习--生成对抗网络GAN_深度学习gan网络

深度学习gan网络

GAN简介

让我们先来简单了解一下GAN

GAN的全称是Generative Adversarial Networks,中文称为“生成对抗网络”,是一种在深度学习领域广泛使用的无监督学习方法。

GAN主要由两部分组成:生成器判别器。生成器的目标是尽可能地生成真实的样本数据,而判别器的目标是尽可能准确地辨别出生成样本与真实样本。这两个组件通过竞争和对抗的方式共同工作,以提升各自的能力。这种网络结构能够处理没有标注数据的问题,并且在图像处理、自然语言处理等多个领域都有广泛应用。

它通过对抗生成来训练,目的是估测数据样本的潜在分布并生成新的数据样本。

GAN结构图

原理 

生成器根据噪声,也就是随机值,来生成样本,而判别器判断哪些是真实数据,哪些是生成数据,然后将学习的经验反向传播给生成器,让生成器生成的样本不断向真实样本靠拢。

在训练过程中,生成器努力让生成的数据更加真实,而判别器努力的去判别数据的真假,二者·形成了对抗。最终两个网络形成了动态平衡,生成样本接近真实样本,而判别器也分辨不出来样本的真假,最终对给定图像预测为真的概率基本接近0.5,也就相当于随即猜测类别了。

公式

在公式中,

z代表输入G网络的噪声,

x代表真实图片

G(z)表示G网络生成的图片,

D(*)表示D网络判断图片是否真实的概率

2.GAN的算法流程和公式详解_哔哩哔哩_bilibili 

在这个视频里有对这个公式的详解,这里就不详细说了/

 我们经过简单了解之后,就要开始搭建GAN网络了,这里我们以手写字体识别数据集为例。

构建GAN网络的步骤

GAN生成对抗网络,步骤:
首先编写生成器和判别器
然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近1为目的优化我们的生成器
生成器的代码(针对手写字体识别)

 预备知识

transforms.Normalize

transforms.Normalize()函数用于对图像数据进行【标准化】处理。在深度学习中,数据标准化是一个常见的预处理步骤,它有助于模型更快地收敛,并提高模型的性能

作用

数据标准化:如上所述,transforms.Normalize()函数可以对图像数据进行标准化处理,使数据分布符合标准正态分布。这有助于模型更快地收敛,并提高模型的性能。


提高模型泛化能力:通过对数据进行标准化,我们可以减少模型对特定数据集的过拟合,从而提高模型在未见过的数据上的泛化能力。


加速模型训练:标准化的数据可以使模型在训练过程中更快地学习到数据的特征,从而加速模型的训练速度。

参数
  • mean:(list)长度与输入的通道数相同,代表每个通道上所有数值的平均值
  • std:(list)长度与输入的通道数相同,代表每个通道上所有数值的标准差

Datadoder

参数

dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

LeakyReLu函数

图像及参数

我们可以与ReLu函数对比,看一下区别:

 

主要区别就是在小于0的部分了 

代码

导库

  1. import matplotlib.pyplot as plt
  2. import matplotlib
  3. import torch
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision import transforms
  7. import numpy as np

数据集处理

  1. transform = transforms.Compose([
  2. transforms.ToTensor(),
  3. transforms.Normalize(0.5, 0.5)
  4. ])
  5. traindata = torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集', train=True, download=True,
  6. transform=transform) # 训练集60,000张用于训练

在加载数据集时,我们要将数据进行归一化,在GAN中,我们就需要将数据归一化到(-1,1)之间,这是为什么呢?原因是我们在下面会用到Tanh激活函数,而Tanh函数的范围是在-1到1之间的,见下图:

在我们既然知道了为什么要这样,下面就要学会如何做到了 

ToTensor中,我们是将数据的范围限制在了(0,1)之间,而后面的Normalize是将数据限制在(-1,1)之间,计算公式为(x-均值)/方差 

生成器

  1. class Generator(torch.nn.Module):
  2. def __init__(self):
  3. super(Generator, self).__init__()
  4. self.main = torch.nn.Sequential(
  5. torch.nn.Linear(100, 256),
  6. torch.nn.ReLU(),
  7. torch.nn.Linear(256, 512),
  8. torch.nn.ReLU(),
  9. torch.nn.Linear(512, 28 * 28),
  10. torch.nn.Tanh()
  11. )
  12. def forward(self, x):
  13. img = self.main(x)
  14. img = img.reshape(-1, 28, 28)
  15. return img

在这里,我们需要知道,生成器的输入和输出是什么,输入时我们的噪音,而输出一张图片。

在后向传播中,我们最后再将图片进行展平。

判别器

  1. class Discraiminator(torch.nn.Module):
  2. def __init__(self):
  3. super(Discraiminator, self).__init__()
  4. self.mainf = torch.nn.Sequential(
  5. torch.nn.Linear(28 * 28, 512),
  6. torch.nn.LeakyReLU(),
  7. torch.nn.Linear(512, 256),
  8. torch.nn.LeakyReLU(),
  9. torch.nn.Linear(256, 1),
  10. torch.nn.Sigmoid()
  11. )
  12. def forward(self, x):
  13. x = x.view(-1, 28 * 28)
  14. x = self.mainf(x)
  15. return x

我们同样需要了解判别器的输入和输出,输入是一张(1,28,28)图片,输出为二分类的概率值。

在判别器中,我们如果使用ReLu函数,在小于0的部分就会出现梯度消失的问题,这时候我们就可以用到LeadkyReLu了,它能够优化GAN的训练。

最后的Sigmoid激活函数,将输出压缩到0到1之间,这通常用于二分类问题,但在这里,它用于表示输入是真实数据的概率。

而在后向传播中,我们需要先对图片进行展平。


定义损失函数,优化函数和优化器

  1. # 定义损失函数和优化函数
  2. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  3. gen = Generator().to(device)
  4. dis = Discraiminator().to(device)
  5. # 定义优化器
  6. gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
  7. dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
  8. loss_fn = torch.nn.BCELoss() # 损失函数

在这里,我们选择使用BCELoss,交叉熵损失函数,这是因为在GAN中,判别器通常被视为一个二分类器,它试图区分输入是真实样本还是由生成器生成的假样本,而BCELoss就是用来做二分类的损失函数,正好对应。

在优化器部分,它们分别对生成器和判别器的参数进行优化。

图像显示

  1. def gen_img_plot(model, testdata):
  2. pre = np.squeeze(model(testdata).detach().cpu().numpy())
  3. # tensor.detach()
  4. # 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
  5. # 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
  6. # 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播
  7. plt.figure()
  8. for i in range(16):
  9. plt.subplot(4, 4, i + 1)
  10. plt.imshow(pre[i])
  11. plt.show()

因为我们最终要得到要得到的是处理数据输出的数组,所以我们要用squeeze将额外的单维度删除。

detach是单独开辟空间来保存数据,从而保证数据的稳定性。

plt.figure用来生成一个新画布。

 使用subplot函数在一个4x4的网格中定位每个子图。i + 1是因为子图的索引是从1开始的,而不是从0开始。 

imshow是在子图中显示图像。

最后的show来显示整体的图像。

后向传播与训练模型

  1. dis_loss = [] # 判别器损失值记录
  2. gen_loss = [] # 生成器损失值记录
  3. lun = [] # 轮数
  4. for epoch in range(60):
  5. d_epoch_loss = 0
  6. g_epoch_loss = 0
  7. cout = len(trainload) # 938批次
  8. for step, (img, _) in enumerate(trainload):
  9. img = img.to(device) # 图像数据
  10. # print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])
  11. size = img.size(0) # 一批次的图片数量64
  12. # 随机生成一批次的100维向量样本,或者说100个像素点
  13. random_noise = torch.randn(size, 100, device=device)
  14. # 判断器的后向传播
  15. dis_opt.zero_grad()
  16. real_output = dis(img)
  17. d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实数据的损失函数值
  18. d_real_loss.backward()
  19. gen_img = gen(random_noise)
  20. fake_output = dis(gen_img.detach())
  21. d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 人造的数据的损失函数值
  22. d_fake_loss.backward()
  23. d_loss = d_real_loss + d_fake_loss
  24. dis_opt.step()
  25. # 生成器的后向传播
  26. gen_opt.zero_grad()
  27. fake_output = dis(gen_img)
  28. g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
  29. g_loss.backward()
  30. gen_opt.step()
  31. d_epoch_loss += d_loss
  32. g_epoch_loss += g_loss
  33. dis_loss.append(float(d_epoch_loss))
  34. gen_loss.append(float(g_epoch_loss))
  35. print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')
  36. lun.append(epoch + 1)

使用enumerate遍历训练数据集trainload,其中img是图像数据,但_表示我们在这里不使用标签(因为GAN是无监督的)。

step()用来更新判别器的模型参数。

在生成器的后向传播部分,

我们先进行梯度清零,然后通过生成器生成假图像,然后进行前向传播。

我们期望判别器对假图像的评分接近1(真实),因此我们将目标标签设置为与fake_output形状相同的全1张量torch.ones_like(fake_output)

在这里,d_loss和g_loss是一张图像中的损失值,而d_epoch_loss和g_epoch_loss是每一轮损失值的累加,用于最后图像的绘制。

生成图像

  1. matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
  2. plt.figure()
  3. plt.plot(lun, dis_loss, 'r', label='判别器损失值')
  4. plt.plot(lun, gen_loss, 'b', label='生成器损失值')
  5. plt.xlabel('训练轮数', fontsize=12)
  6. plt.ylabel('损失值', fontsize=12)
  7. plt.title('损失值随着训练轮数得变化情况:', fontsize=18)
  8. plt.legend()
  9. plt.show()
  10. random_noise = torch.randn(16, 100, device=device)
  11. gen_img_plot(gen, random_noise)

随机生成的噪声有16个样本,100个维度

全部代码

  1. import matplotlib.pyplot as plt
  2. import matplotlib
  3. import torch
  4. from torch.utils.data import DataLoader
  5. import torchvision
  6. from torchvision import transforms
  7. import numpy as np
  8. # 导入数据集并且进行数据处理
  9. transform = transforms.Compose([
  10. transforms.ToTensor(),
  11. transforms.Normalize(0.5, 0.5)
  12. ])
  13. traindata = torchvision.datasets.MNIST(root='./data', train=True, download=True,
  14. transform=transform) # 训练集60,000张用于训练
  15. # 利用DataLoader加载数据集
  16. trainload = DataLoader(dataset=traindata, shuffle=True, batch_size=64)
  17. # GAN生成对抗网络,步骤:
  18. # 首先编写生成器和判别器
  19. # 然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
  20. # 接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
  21. # 生成器的代码(针对手写字体识别)
  22. class Generator(torch.nn.Module):
  23. def __init__(self):
  24. super(Generator, self).__init__()
  25. self.main = torch.nn.Sequential(
  26. torch.nn.Linear(100, 256),
  27. torch.nn.ReLU(),
  28. torch.nn.Linear(256, 512),
  29. torch.nn.ReLU(),
  30. torch.nn.Linear(512, 28 * 28),
  31. torch.nn.Tanh()
  32. )
  33. def forward(self, x):
  34. img = self.main(x)
  35. img = img.reshape(-1, 28, 28)
  36. return img
  37. # 判别器,最后判断0,1,这意味着最后可以是一个神经元或者两个神经元
  38. class Discraiminator(torch.nn.Module):
  39. def __init__(self):
  40. super(Discraiminator, self).__init__()
  41. self.mainf = torch.nn.Sequential(
  42. torch.nn.Linear(28 * 28, 512),
  43. torch.nn.LeakyReLU(),
  44. torch.nn.Linear(512, 256),
  45. torch.nn.LeakyReLU(),
  46. torch.nn.Linear(256, 1),
  47. torch.nn.Sigmoid()
  48. )
  49. def forward(self, x):
  50. x = x.view(-1, 28 * 28)
  51. x = self.mainf(x)
  52. return x
  53. # 定义损失函数和优化函数
  54. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  55. gen = Generator().to(device)
  56. dis = Discraiminator().to(device)
  57. # 定义优化器
  58. gen_opt = torch.optim.Adam(gen.parameters(), lr=0.0001)
  59. dis_opt = torch.optim.Adam(dis.parameters(), lr=0.0001)
  60. loss_fn = torch.nn.BCELoss() # 损失函数
  61. def gen_img_plot(model, testdata):
  62. pre = np.squeeze(model(testdata).detach().cpu().numpy())
  63. # tensor.detach()
  64. # 返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
  65. # 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
  66. # 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播
  67. plt.figure()
  68. for i in range(16):
  69. plt.subplot(4, 4, i + 1)
  70. plt.imshow(pre[i])
  71. plt.show()
  72. # 后向传播
  73. dis_loss = [] # 判别器损失值记录
  74. gen_loss = [] # 生成器损失值记录
  75. lun = [] # 轮数
  76. for epoch in range(60):
  77. d_epoch_loss = 0
  78. g_epoch_loss = 0
  79. cout = len(trainload) # 938批次
  80. for step, (img, _) in enumerate(trainload):
  81. img = img.to(device) # 图像数据
  82. # print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])
  83. size = img.size(0) # 一批次的图片数量64
  84. # 随机生成一批次的100维向量样本,或者说100个像素点
  85. random_noise = torch.randn(size, 100, device=device)
  86. # 判断器的后向传播
  87. dis_opt.zero_grad()
  88. real_output = dis(img)
  89. d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) # 真实数据的损失函数值
  90. d_real_loss.backward()
  91. gen_img = gen(random_noise)
  92. fake_output = dis(gen_img.detach())
  93. d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) # 人造的数据的损失函数值
  94. d_fake_loss.backward()
  95. d_loss = d_real_loss + d_fake_loss
  96. dis_opt.step()
  97. # 生成器的后向传播
  98. gen_opt.zero_grad()
  99. fake_output = dis(gen_img)
  100. g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
  101. g_loss.backward()
  102. gen_opt.step()
  103. d_epoch_loss += d_loss
  104. g_epoch_loss += g_loss
  105. dis_loss.append(float(d_epoch_loss))
  106. gen_loss.append(float(g_epoch_loss))
  107. print(f'第{epoch + 1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')
  108. lun.append(epoch + 1)
  109. matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
  110. plt.figure()
  111. plt.plot(lun, dis_loss, 'r', label='判别器损失值')
  112. plt.plot(lun, gen_loss, 'b', label='生成器损失值')
  113. plt.xlabel('训练轮数', fontsize=12)
  114. plt.ylabel('损失值', fontsize=12)
  115. plt.title('损失值随着训练轮数得变化情况:', fontsize=18)
  116. plt.legend()
  117. plt.show()
  118. random_noise = torch.randn(16, 100, device=device)
  119. gen_img_plot(gen, random_noise)

运行结果
 

  1. 1轮的生成器损失值:2226.86328125,判别器损失值461.5265808105469
  2. 2轮的生成器损失值:2378.969970703125,判别器损失值459.3701477050781
  3. 3轮的生成器损失值:2422.438232421875,判别器损失值355.0154113769531
  4. 4轮的生成器损失值:3410.994873046875,判别器损失值172.3834686279297
  5. 5轮的生成器损失值:3589.7734375,判别器损失值168.7844696044922
  6. 6轮的生成器损失值:3944.258544921875,判别器损失值125.10688781738281
  7. 7轮的生成器损失值:4293.7861328125,判别器损失值138.3419952392578
  8. 8轮的生成器损失值:4436.89404296875,判别器损失值159.64407348632812
  9. 9轮的生成器损失值:4485.7646484375,判别器损失值177.5517578125
  10. 10轮的生成器损失值:4136.85986328125,判别器损失值210.64602661132812
  11. 11轮的生成器损失值:4072.7958984375,判别器损失值246.29910278320312
  12. 12轮的生成器损失值:4298.8623046875,判别器损失值183.00152587890625
  13. 13轮的生成器损失值:4899.4794921875,判别器损失值171.33628845214844
  14. 14轮的生成器损失值:4851.458984375,判别器损失值161.920654296875
  15. 15轮的生成器损失值:4995.62646484375,判别器损失值155.28732299804688
  16. 16轮的生成器损失值:4987.4140625,判别器损失值142.6618194580078
  17. 17轮的生成器损失值:5511.90673828125,判别器损失值126.41560363769531
  18. 18轮的生成器损失值:5509.65771484375,判别器损失值157.1754913330078
  19. 19轮的生成器损失值:5164.8671875,判别器损失值143.5445556640625
  20. 20轮的生成器损失值:5490.17236328125,判别器损失值156.86929321289062
  21. 21轮的生成器损失值:5189.4921875,判别器损失值177.5731201171875
  22. 22轮的生成器损失值:5293.32080078125,判别器损失值168.159912109375
  23. 23轮的生成器损失值:4971.2646484375,判别器损失值189.78167724609375
  24. 24轮的生成器损失值:4590.87158203125,判别器损失值211.07289123535156
  25. 25轮的生成器损失值:4739.5732421875,判别器损失值214.7382354736328
  26. 26轮的生成器损失值:4700.568359375,判别器损失值218.89926147460938
  27. 27轮的生成器损失值:4146.5048828125,判别器损失值269.0607604980469
  28. 28轮的生成器损失值:3846.898681640625,判别器损失值287.00604248046875
  29. 29轮的生成器损失值:3559.870361328125,判别器损失值317.5647888183594
  30. 30轮的生成器损失值:3378.71240234375,判别器损失值336.30572509765625
  31. 31轮的生成器损失值:4269.37060546875,判别器损失值257.89910888671875
  32. 32轮的生成器损失值:5209.896484375,判别器损失值191.99989318847656
  33. 33轮的生成器损失值:4632.1728515625,判别器损失值261.9479064941406
  34. 34轮的生成器损失值:2979.66015625,判别器损失值363.874267578125
  35. 35轮的生成器损失值:2710.74462890625,判别器损失值405.0263671875
  36. 36轮的生成器损失值:2661.800048828125,判别器损失值421.5466613769531
  37. 37轮的生成器损失值:2625.377197265625,判别器损失值414.751708984375
  38. 38轮的生成器损失值:2809.101318359375,判别器损失值399.09942626953125
  39. 39轮的生成器损失值:3797.715087890625,判别器损失值314.6676025390625
  40. 40轮的生成器损失值:6223.8974609375,判别器损失值151.0428924560547
  41. 41轮的生成器损失值:3305.96533203125,判别器损失值355.9456481933594
  42. 42轮的生成器损失值:2672.400634765625,判别器损失值395.23834228515625
  43. 43轮的生成器损失值:2538.265625,判别器损失值425.629638671875
  44. 44轮的生成器损失值:2496.415283203125,判别器损失值443.06085205078125
  45. 45轮的生成器损失值:2451.716796875,判别器损失值449.18194580078125
  46. 46轮的生成器损失值:2397.526123046875,判别器损失值467.0350341796875
  47. 47轮的生成器损失值:2427.2900390625,判别器损失值459.0263977050781
  48. 48轮的生成器损失值:2440.54736328125,判别器损失值469.6186218261719
  49. 49轮的生成器损失值:2597.76953125,判别器损失值439.3223876953125
  50. 50轮的生成器损失值:2724.003173828125,判别器损失值438.4668273925781
  51. 51轮的生成器损失值:2539.636474609375,判别器损失值459.2343444824219
  52. 52轮的生成器损失值:2288.4130859375,判别器损失值498.2747802734375
  53. 53轮的生成器损失值:2244.51513671875,判别器损失值506.4640197753906
  54. 54轮的生成器损失值:2242.865478515625,判别器损失值502.57275390625
  55. 55轮的生成器损失值:2198.66552734375,判别器损失值506.5917053222656
  56. 56轮的生成器损失值:2217.268310546875,判别器损失值502.77081298828125
  57. 57轮的生成器损失值:2246.22802734375,判别器损失值502.93206787109375
  58. 58轮的生成器损失值:2165.259033203125,判别器损失值516.4965209960938
  59. 59轮的生成器损失值:2146.760009765625,判别器损失值519.462890625
  60. 60轮的生成器损失值:2110.582763671875,判别器损失值528.8636474609375
  61. 进程已结束,退出代码为 0

我们得生成器损失值是波动的,判别器损失值也是,很难说他们的趋势走向(当然估计和我的训练轮数有关) 

 

这是我们生成器生成的“伪造的图片”,从这里可以看出来已经很不错了。  

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/873760
推荐阅读
相关标签
  

闽ICP备14008679号