
A Simple Generative Adversarial Network (GAN): Implementation and Walkthrough (Image Adversarial Example)


1. Theory, clearly explained:

Understanding Generative Adversarial Networks (GAN) in One Article: Basic Principles + 10 Typical Algorithms + 13 Applications (easyai.tech)

2. A collection of code implementations:

GitHub - eriklindernoren/Keras-GAN: Keras implementations of Generative Adversarial Networks.

3. A brief overview:

An intuitive way to understand a GAN is as two adversaries locked in a contest, each improving because of the other. An analogy: suppose a city's public order breaks down; before long the city fills with thieves. Some are master burglars, others have no skill at all. If the city then launches a crackdown on crime and the police resume patrolling the streets, a batch of unskilled thieves is quickly caught. The reason only the low-skill thieves get caught is that the police themselves are out of practice. After this first round of arrests, it is hard to say whether the city is actually any safer, but one thing is clear: the average skill of the remaining thieves has risen considerably.

In image-processing terms: you train the Discriminator on real images together with fake images produced by the Generator, until it learns to tell real from fake. The Generator, in turn, maps random input noise to images resembling the real ones, and with repeated training produces increasingly realistic output.
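This tug-of-war is the minimax game from the original GAN formulation: the discriminator tries to maximize log D(x) + log(1 − D(G(z))), while the generator tries to drive the second term down. A toy NumPy check of that value function (the probabilities below are made up for illustration only):

```python
import numpy as np

def gan_value(d_real, d_fake):
    # Value of the minimax game: log D(x) + log(1 - D(G(z)))
    # d_real: discriminator output on a real image
    # d_fake: discriminator output on a generated (fake) image
    return np.log(d_real) + np.log(1.0 - d_fake)

sharp_d = gan_value(0.9, 0.1)   # D separates real from fake well
fooled_d = gan_value(0.5, 0.5)  # G has caught up; D is reduced to guessing
print(sharp_d, fooled_d)        # the sharper discriminator scores higher
```

The discriminator pushes this value up; the generator pushes it down, and at the theoretical equilibrium D outputs 0.5 everywhere.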

(The architecture diagram from the original post is omitted here.)

Note that the example code below uses fully connected layers; the same repository also provides a convolutional variant (DCGAN) for generating images.

4. The loss functions

The two networks' losses are best considered separately. For the Discriminator, the loss is the sum of its loss on real images and its loss on generated fakes. `valid` and `fake` are the corresponding real/fake label batches (ones and zeros); `imgs` and `gen_imgs` are the real and generated images.

```python
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
```
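`train_on_batch` returns the binary cross-entropy on each half, and the full code later averages them as `d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)`. The same arithmetic by hand in NumPy, using made-up discriminator outputs for a batch of four real and four fake images:

```python
import numpy as np

def bce(y_true, y_pred):
    # binary cross-entropy, the loss the discriminator is compiled with
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# hypothetical discriminator outputs (probabilities of "real")
p_real = np.array([0.8, 0.9, 0.7, 0.85])   # on real images: should be near 1
p_fake = np.array([0.2, 0.1, 0.3, 0.15])   # on fake images: should be near 0

d_loss_real = bce(np.ones(4), p_real)       # compared against label 1
d_loss_fake = bce(np.zeros(4), p_fake)      # compared against label 0
d_loss = 0.5 * (d_loss_real + d_loss_fake)  # mirrors the averaging in the full code
print(d_loss)
```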

For the Generator, the Discriminator is first frozen (set non-trainable); the random noise is then passed through the Generator, and the resulting fake images are fed to the Discriminator, so that only the Generator's weights get updated. The labels here are 1 (real) and 0 (fake): when the Discriminator is good at its job, the Generator must work hard to keep slipping fakes past it, so its loss is large; and vice versa.

In the code this is done simply by passing `valid` (all ones, i.e. "real image") as the target for every generated sample.
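With an all-ones target, binary cross-entropy collapses to −log D(G(z)): the better the Discriminator is at flagging fakes (output near 0), the larger the Generator's loss, which is exactly the pressure described above. A quick numeric check (the discriminator outputs are hypothetical):

```python
import numpy as np

def generator_loss(d_on_fake):
    # binary cross-entropy against a target of 1 reduces to -log D(G(z))
    return -np.log(d_on_fake)

caught = generator_loss(0.05)  # D confidently flags the fake -> large loss
fooled = generator_loss(0.95)  # D is fooled                  -> small loss
print(caught, fooled)
```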

5. Full code:

The code is implemented on MNIST: the Generator learns to produce MNIST-style digit images that fool the Discriminator.

```python
from __future__ import print_function, division

import os
import tensorflow as tf

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# TF1-style session setup: cap GPU memory at 50% and allow growth
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.per_process_gpu_memory_fraction = 0.5
config.gpu_options.allow_growth = True
sess0 = tf.InteractiveSession(config=config)

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import sys
import numpy as np


class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)

    def build_discriminator(self):
        model = Sequential()
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1 (matches the generator's tanh output range)
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images from [-1, 1] back to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # make sure the output directory exists before saving
        if not os.path.exists("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    # note: BatchNormalization behaves poorly with batch_size=1;
    # larger values such as 32 are more common
    gan.train(epochs=30000, batch_size=1, sample_interval=200)
```

The results are shown below (training does not take very long; give it a try). (Generated sample images omitted.)

As the number of training steps increases, the generated images become harder and harder to tell from real ones.