
CGAN: Introduction and Code Walkthrough


1. Introduction
  The original GAN (see the earlier post GAN 简介与代码实战_天竺街潜水的八角的博客-CSDN博客) can in theory approximate the real data distribution, but it offers little control: it does fine on small images, while larger generated images can be incoherent. We therefore need to impose some constraints on the GAN so that it generates the images we want, and this is where CGAN comes in. For a more detailed treatment, see the paper: Conditional Generative Adversarial Nets.
 

2. Model Structure
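
For reference, the two objectives as given in the Conditional Generative Adversarial Nets paper are:

Equation 1 (original GAN objective):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

Equation 2 (conditional objective; both the generator and the discriminator receive the condition y):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y)))]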

 Equation 1 is the loss function of the original GAN. Equation 2 adds a condition y on top of Equation 1; this y can be a class label, the part of an image that needs to be restored (for example, an animal), and so on.

 Looking at Equation 2 alone, it is hard to see how y can actually be fed into the network as a condition. Once you look at the figure below, it becomes clear: concatenating (concat) the condition y with the image to be discriminated achieves exactly this effect.
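
As a rough sketch of this concat-style conditioning (illustrative only; the layer sizes are my own choices, and note that the Keras code in section 4 instead fuses the label via an embedding and an element-wise multiply):

from keras.layers import Input, Dense, Flatten, Concatenate
from keras.models import Model

# Condition the discriminator by concatenating a one-hot label vector
# with the flattened image before scoring real/fake.
img = Input(shape=(28, 28, 1))                 # image to be judged
label = Input(shape=(10,))                     # condition y as a one-hot vector
x = Concatenate()([Flatten()(img), label])     # fuse image and condition
x = Dense(512, activation='relu')(x)
validity = Dense(1, activation='sigmoid')(x)   # probability the pair is real
toy_discriminator = Model([img, label], validity)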

 

3. Model Characteristics

Conditioning the model on additional information y can guide the data-generation process.

4. Code Implementation (Keras)

import os

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Embedding
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam


class CGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes the generated image and its label as input
        # and determines validity
        valid = self.discriminator([img, label])

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=optimizer)

    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        # Embed the label into a latent_dim-sized vector and fuse it with the noise
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')
        # Embed the label into a vector the size of the flattened image and fuse them
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)
        model_input = multiply([flat_img, label_embedding])
        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Configure input: rescale to [-1, 1] and add a channel dimension
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images and their labels
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator on real and generated images
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on randomly sampled labels
            sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)

            # Train the generator (via the combined model, discriminator frozen)
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 2, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from [-1, 1] back to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].set_title("Digit: %d" % sampled_labels[cnt][0])
                axs[i, j].axis('off')
                cnt += 1

        # Make sure the output directory exists before saving
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/%d.png" % epoch)
        plt.close()
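
To run the class above, a minimal entry point might look like the following; the epoch count, batch size, and sampling interval are illustrative values, not prescribed by the original post:

if __name__ == '__main__':
    cgan = CGAN()
    # Illustrative hyperparameters; adjust to taste
    cgan.train(epochs=20000, batch_size=32, sample_interval=200)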
