
CycleGAN: Introduction and Code Walkthrough


1. Introduction

CycleGAN comes from the paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks". The title already spells out the paper's two key ideas: Unpaired (the two image domains need no paired training examples) and Cycle (translations form a cycle whose consistency supervises training).
 

2. Model Structure

The whole architecture is one cycle operation (panel (a) of the paper's overview figure, not reproduced here). In panel (b), x is the input: the translation is y_fake = G(x) and the reconstruction is x_rec = F(y_fake); in panel (c), y is the input: x_fake = F(y) and y_rec = G(x_fake). G and F are the two generators, and D_X and D_Y are the discriminators for domains X and Y.
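In symbols (a restatement of the figure; the ≈ holds because training enforces cycle consistency):

\[
x \xrightarrow{\;G\;} \hat{y} = G(x) \xrightarrow{\;F\;} F(G(x)) \approx x,
\qquad
y \xrightarrow{\;F\;} \hat{x} = F(y) \xrightarrow{\;G\;} G(F(y)) \approx y.
\]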

3. Model Characteristics

The total loss is Equation (3). Equation (2) is the L1 cycle-consistency loss, which pulls the reconstructed image toward the original and gives noticeably sharper results than an L2 loss would; Equation (1) is the ordinary GAN loss. The discriminator is a PatchGAN (the same Patch-D used in Pix2Pix), and the generator has a U-Net structure. The three equations are restated below.
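As they appear in the paper (Equation (1) is written for the G/D_Y direction; the F/D_X direction is symmetric):

\[
\mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) = \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big] \tag{1}
\]

\[
\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{\mathrm{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big] \tag{2}
\]

\[
\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\mathrm{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\mathrm{GAN}}(F, D_X, Y, X) + \lambda \, \mathcal{L}_{\mathrm{cyc}}(G, F) \tag{3}
\]

Two details worth flagging before reading the code: the implementation below swaps the log loss in (1) for a least-squares (MSE) loss, and it also adds the paper's optional identity loss \(\mathbb{E}_{y}[\lVert G(y) - y \rVert_1] + \mathbb{E}_{x}[\lVert F(x) - x \rVert_1]\), weighted by lambda_id.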

 

 

4. Keras Implementation

The listing below is a complete single-file version, with the imports it needs added at the top. It assumes Keras 2.x with keras-contrib (for the InstanceNormalization layer) plus a DataLoader helper for the apple2orange dataset, as in the open-source Keras-GAN project that this code appears to follow.

from __future__ import print_function, division

import datetime
import os

import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input, Conv2D, UpSampling2D, Dropout, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

from data_loader import DataLoader  # dataset helper from the Keras-GAN repo


class CycleGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'apple2orange'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Calculate output shape of D (PatchGAN): four stride-2 convs halve
        # the 128x128 input four times, giving an 8x8 patch map
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64

        # Loss weights
        self.lambda_cycle = 10.0                    # cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle    # identity loss

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators (MSE loss = least-squares GAN)
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to the original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # For the combined model we only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains the generators to fool the discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       reconstr_A, reconstr_B,
                                       img_A_id, img_B_id])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer=optimizer)
    def build_generator(self):
        """U-Net generator"""

        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])  # U-Net skip connection
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)

        # Upsampling with skip connections to the matching encoder layers
        u1 = deconv2d(d4, d3, self.gf*4)
        u2 = deconv2d(u1, d2, self.gf*2)
        u3 = deconv2d(u2, d1, self.gf)

        u4 = UpSampling2D(size=2)(u3)
        # tanh output matches images rescaled to [-1, 1]
        output_img = Conv2D(self.channels, kernel_size=4, strides=1,
                            padding='same', activation='tanh')(u4)

        return Model(d0, output_img)
    def build_discriminator(self):
        """PatchGAN discriminator"""

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        # One validity score per patch rather than per image
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)
    def train(self, epochs, batch_size=1, sample_interval=50):
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths, one label per PatchGAN patch
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to the opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress (g_loss[0] is the weighted total, followed by
                # the six per-output losses; [5:7] averages both identity terms)
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] "
                      "[G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s "
                      % (epoch, epochs,
                         batch_i, self.data_loader.n_batches,
                         d_loss[0], 100*d_loss[1],
                         g_loss[0],
                         np.mean(g_loss[1:3]),
                         np.mean(g_loss[3:5]),
                         np.mean(g_loss[5:7]),
                         elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
    def sample_images(self, epoch, batch_i):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to the original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images from [-1, 1] to [0, 1] for plotting
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()
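
A minimal entry point, assuming the apple2orange dataset is already under datasets/ in the layout the DataLoader expects (the epoch count and sampling interval are illustrative values, not tuned ones):

if __name__ == '__main__':
    gan = CycleGAN()
    # Batch size 1 matches the paper's training setup; sample images are
    # written to images/apple2orange/ every 200 batches.
    gan.train(epochs=200, batch_size=1, sample_interval=200)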

 

 

 
