当前位置:   article > 正文

用keras写一个生成对抗网络(GAN)_基于keras搭建gan网络

基于keras搭建gan网络

用 keras 写一个生成对抗网络(GAN)

1、GAN的作用

GAN的作用:基于现有数据生成类似的新数据。

 

2、GAN的结构:

理解GAN的两大护法GD

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

这样可以简单的看作是两个网络的博弈过程。在最原始的GAN论文里面,G和D都是两个多层感知机网络。

首先,注意一点,GAN操作的数据不一定非得是图像数据,不过为了更方便解释,我在这里用图像数据为例解释以下GAN:

 

GAN是怎么训练的?

 

3、GAN的代码案例:

  1. from __future__ import print_function, division
  2. from keras.datasets import mnist
  3. from keras.layers import Input, Dense, Reshape, Flatten, Dropout
  4. from keras.layers import BatchNormalization, Activation, ZeroPadding2D
  5. from keras.layers.advanced_activations import LeakyReLU
  6. from keras.layers.convolutional import UpSampling2D, Conv2D
  7. from keras.models import Sequential, Model
  8. from keras.optimizers import Adam
  9. import matplotlib.pyplot as plt
  10. import sys
  11. import numpy as np
  12. class GAN():
  13. def __init__(self):
  14. self.img_rows = 28
  15. self.img_cols = 28
  16. self.channels = 1
  17. self.img_shape = (self.img_rows, self.img_cols, self.channels)
  18. self.latent_dim = 100
  19. optimizer = Adam(0.0002, 0.5)
  20. # Build and compile the discriminator
  21. self.discriminator = self.build_discriminator()
  22. self.discriminator.compile(loss='binary_crossentropy',
  23. optimizer=optimizer,
  24. metrics=['accuracy'])
  25. # Build the generator
  26. self.generator = self.build_generator()
  27. # The generator takes noise as input and generates imgs
  28. z = Input(shape=(self.latent_dim,))
  29. img = self.generator(z)
  30. # For the combined model we will only train the generator
  31. self.discriminator.trainable = False
  32. # The discriminator takes generated images as input and determines validity
  33. validity = self.discriminator(img)
  34. # The combined model (stacked generator and discriminator)
  35. # Trains the generator to fool the discriminator
  36. self.combined = Model(z, validity)
  37. self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
  38. def build_generator(self):
  39. model = Sequential()
  40. model.add(Dense(256, input_dim=self.latent_dim))
  41. model.add(LeakyReLU(alpha=0.2))
  42. model.add(BatchNormalization(momentum=0.8))
  43. model.add(Dense(512))
  44. model.add(LeakyReLU(alpha=0.2))
  45. model.add(BatchNormalization(momentum=0.8))
  46. model.add(Dense(1024))
  47. model.add(LeakyReLU(alpha=0.2))
  48. model.add(BatchNormalization(momentum=0.8))
  49. model.add(Dense(np.prod(self.img_shape), activation='tanh'))
  50. model.add(Reshape(self.img_shape))
  51. model.summary()
  52. noise = Input(shape=(self.latent_dim,))
  53. img = model(noise)
  54. return Model(noise, img)
  55. def build_discriminator(self):
  56. model = Sequential()
  57. model.add(Flatten(input_shape=self.img_shape))
  58. model.add(Dense(512))
  59. model.add(LeakyReLU(alpha=0.2))
  60. model.add(Dense(256))
  61. model.add(LeakyReLU(alpha=0.2))
  62. model.add(Dense(1, activation='sigmoid'))
  63. model.summary()
  64. img = Input(shape=self.img_shape)
  65. validity = model(img)
  66. return Model(img, validity)
  67. def train(self, epochs, batch_size=128, sample_interval=50):
  68. # Load the dataset
  69. (X_train, _), (_, _) = mnist.load_data()
  70. # Rescale -1 to 1
  71. X_train = X_train / 127.5 - 1.
  72. X_train = np.expand_dims(X_train, axis=3)
  73. # Adversarial ground truths
  74. valid = np.ones((batch_size, 1))
  75. fake = np.zeros((batch_size, 1))
  76. for epoch in range(epochs):
  77. # ---------------------
  78. # Train Discriminator
  79. # ---------------------
  80. # Select a random batch of images
  81. idx = np.random.randint(0, X_train.shape[0], batch_size)
  82. imgs = X_train[idx]
  83. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  84. # Generate a batch of new images
  85. gen_imgs = self.generator.predict(noise)
  86. # Train the discriminator
  87. d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  88. d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  89. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  90. # ---------------------
  91. # Train Generator
  92. # ---------------------
  93. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  94. # Train the generator (to have the discriminator label samples as valid)
  95. g_loss = self.combined.train_on_batch(noise, valid)
  96. # Plot the progress
  97. print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
  98. # If at save interval => save generated image samples
  99. if epoch % sample_interval == 0:
  100. self.sample_images(epoch)
  101. def sample_images(self, epoch):
  102. r, c = 5, 5
  103. noise = np.random.normal(0, 1, (r * c, self.latent_dim))
  104. gen_imgs = self.generator.predict(noise)
  105. # Rescale images 0 - 1
  106. gen_imgs = 0.5 * gen_imgs + 0.5
  107. fig, axs = plt.subplots(r, c)
  108. cnt = 0
  109. for i in range(r):
  110. for j in range(c):
  111. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
  112. axs[i,j].axis('off')
  113. cnt += 1
  114. fig.savefig("images/%d.png" % epoch)
  115. plt.close()
  116. if __name__ == '__main__':
  117. gan = GAN()
  118. gan.train(epochs=30000, batch_size=32, sample_interval=200)

 

4.对抗性样本:

特点:

在原有真实样本的基础上稍加处理,使原有分类器对其真实类别无法判别。

比如:

带了一个面具之后你不认识我了,或者说是批了一个羊皮,你就识别不出来我是一个人了。

对于人脸识别,我带上一个特制眼镜就识别不出来我是谁了。

对于安保摄像识别人的系统,我穿了一个特制衬衫,摄像头就识别不出来我是一个人了。

路标指示牌上我贴了一个贴纸之后,自动驾驶就识别指示牌识别错误了。

作用:

可以利用对抗样本来进行训练,提高模型的抗干扰能力,因此有了对抗训练的概念。

通过对抗训练,相当于加了一种形式的正则,可以提高模型的鲁棒性。

【参考链接】

https://blog.csdn.net/leviopku/article/details/81292192

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/744657
推荐阅读
相关标签
  

闽ICP备14008679号