当前位置:   article > 正文

生成对抗网络(Generative Adversial Network,GAN)原理简介_生成对抗网络的原理

生成对抗网络的原理

  生成对抗网络(GAN)是深度学习中一类比较大的家族,主要功能是实现图像、音乐或文本等生成(或者说是创作),生成对抗网络的主要思想是:通过生成器(generator)与判别器(discriminator)不断对抗进行训练。最终使得判别器难以分辨生成器生成的数据(图片,音频等)和真实的数据。所以,对于生成对抗网络,我们最终的目标一般是得到生成器,因为训练结束后我们是需要得到神经网络创作出来的作品。

一:基本原理

  生成对抗网络的基本思想是训练过程中生成网络(生成器,generator)与判别网络(判别器,discriminator)不断对抗的过程。所以,无论GAN模型多么复杂,基本思想仍是这两种网络的对抗,基本结构一定要有生成器与判别器。

  为了了解GAN的一般过程,兔兔以一个经典的例子来说明。我们假设生成器是一个普通的画家,判别器是名画鉴别师。一开始这个画家的画的赝品很容易被鉴别师识别出来,然后画家再想办法使画更像真的名画,但是鉴别师之后还是能够识别。不过在一次次被鉴别的过程中,画家的画技逐渐炉火纯青,而鉴别师的鉴别水平也逐渐提高。所以最终画家可能画得十分逼真,与真的名画相差无几,以至于鉴别师也分辨不出来了。

  以上例子只是一个很笼统的比喻。如果更细致一些,实际的过程是:把一些真的画和画家的假画给鉴别师,告诉鉴别师哪个是真,哪个是假,这个过程实际上是训练鉴别师;然后画家画几幅画给鉴别师,如果有一些画被鉴别师鉴别出来了,画家需要根据这些被鉴别出来的画进行反思,使得之后的画作尽量不要被鉴别师鉴别出来,这个过程是训练画家;然后还是将真画与画家的假画交给鉴别师,告诉他哪个是真,哪个是假......这个过程依次交替进行,直到训练结束。

  以上这个过程实际上就是GAN的大致思路了。

  接下来是GAN的具体结构。兔兔以图片的生成为例,这也是GAN最经典的应用,并且这里先不讲解以文字生成图像的stack-GAN等模型,仅仅是最简单的GAN。

  GAN一般情况下可以看做是无监督学习,因为我们训练的数据只有真实的图片,并没有标签。这里使用的标签也仅仅是真实数据的"真"与生成器生成图片的"假"。

1.生成器(generator)

  首先,对于生成器,其内部一般是多层卷积、全连接层等构成的网络,并且采用上采样,通过接收的噪声来生成一个大小合适的图片。这里的噪声其实就是服从某些分布的随机数,一般选取正态分布随机数,该随机数组成一个长度为n的向量,每个这样的向量最终都会生成一个对应的图片。从某种意义来说,正是这种随机数的存在,每次生成的图片都会不同,但图片是否长得像真的,则取决于generator。当然,随机数的概率分布对图像也是有影响的,如果训练时用高斯噪声,那么使用时也要用这个噪声。

2.判别器(discriminator)

  对于判别器,其内部一般也是多层卷积,全连接等层组成的网络,并且采用下采样,它接收一个批次的图片(batch,c,w,h),每个图片都相应标记为真或假(batch,1)(标签也可以用one-hot编码表示(batch,2)),其中假图片是generator生成的。判别器的训练过程和我们以往的监督学习的训练方法是一致的。

3.生成对抗模型(GAN,adversarial model)

  在构建好generator与discriminator后,将两个组合起来,形成对抗网络GAN,用来训练生成网络。

  对于GAN,它接收一批次噪声,输出为“真”或“假”标签,如果为真,说明生成器生成的这个图片骗过了判别器,否则根据这个损失来调整generator内部参数。这个训练过程也和以往的监督学习的训练方法相同,不过这里的discriminator参数不能更改,只有generator内部参数可以改,毕竟这是训练生成器的过程,而不是训练判别器;并且标签是"True",因为我们是想让generator生成的图像看起来更像是真的。

  整个训练的过程为:

(1).从高斯分布中采样一批次长度为n的噪声向量。

(2).利用(1)中噪声向量,使用generator生成假图像。

(3).从真实数据采一批次真实图像,与(2)中的假图像混合,做好标签,训练discriminator。

(4).再从高斯分布中采样长度为n的一批次噪声向量,标签为“True”,训练GAN,此时GAN中的discriminator参数不能更新,只训练generator。

(5).按指定轮数重复上述步骤。

二:基本框架

  1. import torch
  2. import numpy as np
  3. from torch.utils.data import DataLoader
  4. from torch import nn
  5. import argparse
  6. class Generator(nn.Module):
  7. '''生成器'''
  8. def __init__(self):
  9. super(Generator, self).__init__()
  10. pass
  11. def forward(self,input):
  12. pass
  13. class Discriminator(nn.Module):
  14. '''判别器'''
  15. def __init__(self):
  16. super(Discriminator, self).__init__()
  17. pass
  18. def forward(self,input):
  19. pass
  20. class GAN(nn.Module):
  21. '''GAN模型'''
  22. def __init__(self):
  23. super(GAN, self).__init__()
  24. self.gene=Generator()
  25. self.gene.requires_grad_(True)
  26. self.disc=Discriminator()
  27. self.disc.requires_grad_(False)
  28. def forward(self,input):
  29. out=self.gene(input)
  30. out=self.disc(out)
  31. return out
  32. class dataset:
  33. '''真实图片数据集'''
  34. def __init__(self):
  35. pass
  36. def __len__(self):
  37. pass
  38. def __getitem__(self, item):
  39. pass
  40. if __name__=='__main__':
  41. parser=argparse.ArgumentParser()
  42. parser.add_argument('--epoch',type=int,default=10,help='the train epoch')
  43. parser.add_argument('--n',type=int,default=200,help='the length of noise')
  44. parser.add_argument('--noise_batch',type=int,default=10,help='the batch size of noise')
  45. parser.add_argument('--true_batch',type=int,default=10,help='the batch size of true picture')
  46. opt=parser.parse_args()
  47. gene=Generator()
  48. disc=Discriminator()
  49. gan=GAN()
  50. disc_optim=torch.optim.Adam(disc.parameters())
  51. gan_optim=torch.optim.Adam(gan.parameters())
  52. criterion=nn.MSELoss()
  53. for i in range(opt.epoch):
  54. true_pict = DataLoader(dataset(), batch_size=opt.true_batch, shuffle=True)
  55. for batch in true_pict:
  56. noise=torch.tensor(np.random.normal((opt.noise_batch,opt.n)),dtype=torch.float32)
  57. false_pict=gene(noise)
  58. label_true=torch.tensor(np.ones(shape=opt.true_batch),dtype=torch.float32)
  59. label_fake=torch.tensor(np.ones(shape=opt.noise_batch),dtype=torch.float32)
  60. loss_true=criterion(batch,label_true)
  61. loss_fake=criterion(false_pict,label_fake)
  62. disc_optim.zero_grad()
  63. loss_fake.backward()
  64. loss_true.backward()
  65. loss_fake.step()
  66. loss_true.step()
  67. noise1=torch.tensor(np.random.normal((opt.noise_batch,opt.n)),dtype=torch.float32)
  68. pict=gene(noise1)
  69. label=torch.tensor(np.ones(shape=(opt.noise_batch,opt.n)),dtype=torch.float32)
  70. loss_gan=criterion(pict,label)
  71. loss_gan.zero_grad()
  72. loss_gan.backward()
  73. loss_gan.step()

当然,关于内部的训练批次等问题,以及判别网络与生成网络每次的训练次数,有时是需要具体问题具体分析的。

三:总结:

生成对抗网络的发展时间并不长,但是目前却已经有非常多的GAN模型,并由GAN衍生出提出许多新的方法。生成对抗网络不仅打开了深度学习在创作领域的大门,更重要的是它所带来的一种新法方法与思想,对诸多领域有着深远影响。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/499087
推荐阅读
相关标签
  

闽ICP备14008679号