当前位置:   article > 正文

GAN(生成对抗网络)代码详解(一)

video gan代码

 

  1. from __future__ import print_function, division
  2. from keras.datasets import mnist
  3. from keras.layers import Input, Dense, Reshape, Flatten, Dropout
  4. from keras.layers import BatchNormalization, Activation, ZeroPadding2D
  5. from keras.layers.advanced_activations import LeakyReLU
  6. from keras.layers.convolutional import UpSampling2D, Conv2D
  7. from keras.models import Sequential, Model
  8. from keras.optimizers import Adam
  9. import matplotlib.pyplot as plt
  10. import sys
  11. import numpy as np
  12. class GAN():
  13. def __init__(self):
  14. self.img_rows = 28
  15. self.img_cols = 28
  16. self.channels = 1
  17. self.img_shape = (self.img_rows, self.img_cols, self.channels)
  18. self.latent_dim = 100
  19. optimizer = Adam(0.0002, 0.5)
  20. # Build and compile the discriminator
  21. # 建立和编译判别器
  22. self.discriminator = self.build_discriminator()
  23. self.discriminator.compile(loss='binary_crossentropy',
  24. optimizer=optimizer,
  25. metrics=['accuracy'])
  26. # Build the generator
  27. # 建立生成器
  28. self.generator = self.build_generator()
  29. # The generator takes noise as input and generates imgs
  30. # 生成器输入随机数值(噪声)生成图片
  31. z = Input(shape=(self.latent_dim,))
  32. img = self.generator(z)
  33. # For the combined model we will only train the generator
  34. # 联合模型,只训练生成器
  35. self.discriminator.trainable = False
  36. # The discriminator takes generated images as input and determines validity
  37. # 判别器以生成的图片为输入判别有效性
  38. validity = self.discriminator(img)
  39. # The combined model (stacked generator and discriminator)
  40. # 联合模型(叠加生成器和判别器)
  41. # Trains the generator to fool the discriminator
  42. # 训练生成器欺骗判别器
  43. self.combined = Model(z, validity)
  44. self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
  45. def build_generator(self):
  46. model = Sequential()
  47. model.add(Dense(256, input_dim=self.latent_dim))
  48. model.add(LeakyReLU(alpha=0.2))
  49. model.add(BatchNormalization(momentum=0.8))
  50. model.add(Dense(512))
  51. model.add(LeakyReLU(alpha=0.2))
  52. model.add(BatchNormalization(momentum=0.8))
  53. model.add(Dense(1024))
  54. model.add(LeakyReLU(alpha=0.2))
  55. model.add(BatchNormalization(momentum=0.8))
  56. model.add(Dense(np.prod(self.img_shape), activation='tanh'))
  57. model.add(Reshape(self.img_shape))
  58. model.summary()
  59. noise = Input(shape=(self.latent_dim,))
  60. img = model(noise)
  61. return Model(noise, img)
  62. def build_discriminator(self):
  63. model = Sequential()
  64. model.add(Flatten(input_shape=self.img_shape))
  65. model.add(Dense(512))
  66. model.add(LeakyReLU(alpha=0.2))
  67. model.add(Dense(256))
  68. model.add(LeakyReLU(alpha=0.2))
  69. model.add(Dense(1, activation='sigmoid'))
  70. model.summary()
  71. img = Input(shape=self.img_shape)
  72. validity = model(img)
  73. return Model(img, validity)
  74. def train(self, epochs, batch_size=128, sample_interval=50):
  75. # Load the dataset
  76. # 加载数据
  77. (X_train, _), (_, _) = mnist.load_data()
  78. # Rescale -1 to 1
  79. # 数据缩放到-11之间
  80. X_train = X_train / 127.5 - 1.
  81. X_train = np.expand_dims(X_train, axis=3)
  82. # Adversarial ground truths
  83. # 对抗性的基本事实
  84. valid = np.ones((batch_size, 1)) # 321列,每个值都是1
  85. fake = np.zeros((batch_size, 1)) # 321列,每个值都是0
  86. for epoch in range(epochs):
  87. # ---------------------
  88. # Train Discriminator 训练判别器
  89. # ---------------------
  90. # Select a random batch of images
  91. # 随机选择一批图像
  92. idx = np.random.randint(0, X_train.shape[0], batch_size) # X_train.shape (6000,28,28,1
  93. imgs = X_train[idx]
  94. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  95. # Generate a batch of new images
  96. # 随机生成一批图像
  97. gen_imgs = self.generator.predict(noise)
  98. # Train the discriminator
  99. # 训练判别器
  100. d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  101. d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  102. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  103. # ---------------------
  104. # Train Generator 训练生成器
  105. # ---------------------
  106. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  107. # Train the generator (to have the discriminator label samples as valid)
  108. # 训练生成器(使判别器标签样本有效)
  109. g_loss = self.combined.train_on_batch(noise, valid)
  110. # Plot the progress
  111. # 规划进度
  112. print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
  113. # If at save interval => save generated image samples
  114. # 如果在保证图像间隔 => 保存生成的图像样本
  115. if epoch % sample_interval == 0:
  116. self.sample_images(epoch)
  117. def sample_images(self, epoch):
  118. r, c = 5, 5
  119. noise = np.random.normal(0, 1, (r * c, self.latent_dim))
  120. gen_imgs = self.generator.predict(noise)
  121. # Rescale images 0 - 1
  122. # 数据缩放到-11之间
  123. gen_imgs = 0.5 * gen_imgs + 0.5
  124. fig, axs = plt.subplots(r, c)
  125. cnt = 0
  126. for i in range(r):
  127. for j in range(c):
  128. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
  129. axs[i,j].axis('off')
  130. cnt += 1
  131. fig.savefig("images/%d.png" % epoch)
  132. plt.close()
  133. if __name__ == '__main__':
  134. gan = GAN()
  135. gan.train(epochs=30000, batch_size=32, sample_interval=200)

 

代码解释

from __future__ import print_function, division

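在开头加上 from __future__ import print_function, division 这句之后,即使在 Python 2.x 中,print 和除法也会按 Python 3.x 的方式工作:print 必须像函数一样加括号调用,例如 print("hello");`/` 执行真除法,例如 1/2 得到 0.5 而不是 0(整除需改用 `//`)。Python 2.x 中 print 是语句,不需要括号;而在 Python 3.x 中 print 是函数,必须加括号。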

详解:

https://blog.csdn.net/xiaotao_1/article/details/79460365

https://blog.csdn.net/feixingfei/article/details/7081446

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/118682
推荐阅读
相关标签
  

闽ICP备14008679号