当前位置:   article > 正文




论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks



与 pix2pix 的区别:二者都可以做图像到图像的转换,但 pix2pix 模型必须使用成对数据(paired data)训练,而 CycleGAN 使用非成对数据(unpaired data)也能进行训练。




CycleGAN 其实就是一个 A→B 的单向 GAN 加上一个 B→A 的单向 GAN。两个单向 GAN 共享两个生成器($G_{AB}$ 与 $G_{BA}$),并各自带一个判别器,所以加起来总共有两个生成器和两个判别器。每个单向 GAN 有两个对抗 loss(生成器一个、判别器一个),因此对抗部分总共有四个 loss;除此之外,CycleGAN 还加入了下面介绍的循环一致损失和 identity loss。

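循环一致损失:网络需要保证生成的图像保留原始图像的特性。因此,如果我们用生成器 Generator A→B 生成一张假图像,那么必须能够用另一个生成器 Generator B→A 把它恢复成原始图像,即满足循环一致性:$\lVert G_{BA}(G_{AB}(a)) - a \rVert_1 \approx 0$,反方向 $\lVert G_{AB}(G_{BA}(b)) - b \rVert_1 \approx 0$ 同理。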

identity loss:生成器 $G$ 负责把域 X 的图像转换为域 Y 的图像;如果输入本身就是域 Y 的图片,它应当原样输出这张图片,即 $\lVert G(y) - y \rVert_1 \approx 0$。

  1. # 用狗的图像生成猫的图像
  2. import itertools
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torchvision import transforms, datasets, utils
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import torch.optim as optim
  10. from torch.utils.data.dataset import Dataset
  11. from PIL import Image
  12. import tqdm
  13. import glob
  14. dogs_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Dog/*.jpg') #获取数据集中的.jpg图片
  15. cats_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Cat/*.jpg') #获取数据集中的.jpg图片
  16. # print(cats_path[:3])
  17. # print(dogs_path[:3])
  18. cats_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Cat/*.jpg') #获取数据集中的.jpg图片
  19. dogs_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Dog/*.jpg') #获取数据集中的.jpg图片
  20. transform = transforms.Compose([transforms.ToTensor(),
  21. transforms.Resize((256, 256)),
  22. transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间
  23. # 定义数据读取
  24. class SGANDataset(Dataset):
  25. def __init__(self, imgs_path): #初始化
  26. super(SGANDataset, self).__init__()
  27. self.imgs_path = imgs_path #定义属性
  28. def __len__(self):
  29. return len(self.imgs_path)
  30. def __getitem__(self, index): #对数据切片
  31. img_path = self.imgs_path[index]
  32. # 从文件中读取图像
  33. pil_img = Image.open(img_path)
  34. pil_img = transform(pil_img)
  35. return pil_img
  36. # 初始化训练集
  37. dog_dataset = SGANDataset(dogs_path) #创建dataset
  38. cat_dataset = SGANDataset(cats_path) #创建dataset
  39. # 初始化测试集
  40. dog_dataset_test = SGANDataset(dogs_path_test) #创建dataset
  41. cat_dataset_test = SGANDataset(cats_path_test) #创建dataset
  42. dog_dataloader = torch.utils.data.DataLoader(dog_dataset, batch_size=4, shuffle=True)
  43. cat_dataloader = torch.utils.data.DataLoader(cat_dataset, batch_size=4, shuffle=True)
  44. dog_dataloader_test = torch.utils.data.DataLoader(dog_dataset_test, batch_size=4)
  45. cat_dataloader_test = torch.utils.data.DataLoader(cat_dataset_test, batch_size=4)
  46. # cat_bath = next(iter(cat_dataloader)) #查看
  47. # dog_bath = next(iter(dog_dataloader)) #查看
  48. # print(dog_bath.shape) #torch.Size([4, 3, 256, 256])
  49. # print(cat_bath.shape) #torch.Size([4, 3, 256, 256])
  50. # 查看数据集
  51. # plt.figure(figsize=(8, 12))
  52. # for i, (dog, cat) in enumerate(zip(dog_bath[:3], cat_bath[:3])): #zip代表元组
  53. # # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
  54. # dog = (dog.permute(1, 2, 0).numpy() + 1) / 2
  55. # cat = (cat.permute(1, 2, 0).numpy() + 1) / 2
  56. # plt.subplot(3, 2, 2*i+1)
  57. # plt.title('dog')
  58. # plt.imshow(dog)
  59. # plt.subplot(3, 2, 2*i+2)
  60. # plt.title('cat')
  61. # plt.imshow(cat)
  62. # plt.show()
  63. #定义下采样模块
  64. class Downsample(nn.Module):
  65. def __init__(self, in_channels, out_channels):
  66. super(Downsample, self).__init__()
  67. self.conv_relu = nn.Sequential(
  68. nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
  69. nn.LeakyReLU(inplace=True)
  70. )
  71. self.bn = nn.InstanceNorm2d(out_channels)
  72. def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
  73. x = self.conv_relu(x)
  74. if is_bn:
  75. x = self.bn(x)
  76. return x
  77. #定义上采样模块
  78. class Upsample(nn.Module):
  79. def __init__(self, in_channels, out_channels):
  80. super(Upsample, self).__init__()
  81. self.upconv_relu = nn.Sequential(
  82. nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
  83. nn.LeakyReLU(inplace=True)
  84. )
  85. self.bn = nn.InstanceNorm2d(out_channels)
  86. def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
  87. x = self.upconv_relu(x)
  88. x = self.bn(x)
  89. if is_drop:
  90. x = F.dropout2d(x)
  91. return x
  92. # 定义生成器,包含6个下采样层,6个上采样层
  93. class Generator(nn.Module):
  94. def __init__(self):
  95. super(Generator, self).__init__()
  96. self.down1 = Downsample(3, 64) #3,256,256 -- 64,128,128
  97. self.down2 = Downsample(64, 128) #64,128,128 -- 128,64,64
  98. self.down3 = Downsample(128, 256) #128,64,64 -- 256,32,32
  99. self.down4 = Downsample(256, 512) #256,32,32 -- 512,16,16
  100. self.down5 = Downsample(512, 512) #512,16,16 -- 512,8,8
  101. self.down6 = Downsample(512, 512) #512,8,8 -- 512,4,4
  102. self.up1 = Upsample(512, 512) #512,4,4 -- 512,8,8
  103. self.up2 = Upsample(1024, 512) #1024,8,8 -- 512,16,16
  104. self.up3 = Upsample(1024, 256) #1024,16,16 -- 256,32,32
  105. self.up4 = Upsample(512, 128) #512,32,32 -- 128,64,64
  106. self.up5 = Upsample(256, 64) #256,64,64 -- 64,128,128
  107. #128,128,128 -- 3,256,256
  108. self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
  109. def forward(self, x):
  110. x1 = self.down1(x)
  111. x2 = self.down2(x1)
  112. x3 = self.down3(x2)
  113. x4 = self.down4(x3)
  114. x5 = self.down5(x4)
  115. x6 = self.down6(x5)
  116. x6 = self.up1(x6, is_drop=True)
  117. x6 = torch.cat([x6, x5], dim=1)
  118. x6 = self.up2(x6, is_drop=True)
  119. x6 = torch.cat([x6, x4], dim=1)
  120. x6 = self.up3(x6, is_drop=True)
  121. x6 = torch.cat([x6, x3], dim=1)
  122. x6 = self.up4(x6)
  123. x6 = torch.cat([x6, x2], dim=1)
  124. x6 = self.up5(x6)
  125. x6 = torch.cat([x6, x1], dim=1)
  126. x6 = torch.tanh(self.last(x6))
  127. return x6
  128. # 定义判别器
  129. class Discriminator(nn.Module):
  130. def __init__(self):
  131. super(Discriminator, self).__init__()
  132. self.down1 = Downsample(3, 64)
  133. self.down2 = Downsample(64, 128)
  134. self.last = nn.Conv2d(128, 1, 3)
  135. def forward(self, img):
  136. x = self.down1(img)
  137. x = self.down2(x)
  138. x =torch.sigmoid(self.last(x))
  139. return x
  140. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  141. # 初始化两个生成器
  142. gen_AB = Generator().to(device)
  143. gen_BA = Generator().to(device)
  144. # 初始化两个判别器
  145. dis_A = Discriminator().to(device)
  146. dis_B = Discriminator().to(device)
  147. # 损失函数 1.gan loss 2.cycle consistance loss 3.identity loss
  148. bce_loss = torch.nn.BCELoss()
  149. l1_loss = torch.nn.L1Loss()
  150. # 初始化优化器
  151. # 对两个生成器同时进行优化, 使用itertools.chain对二者同时进行迭代
  152. gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))
  153. # 对两个判别器分别进行优化
  154. dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
  155. dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))
  156. # 绘图函数,将每一个epoch中生成器生成的图片绘制
  157. def gen_img_plot(model, epoch, test_input): # model为gen_AB/gen_BA,test_input
  158. generate = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
  159. test_input = test_input.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
  160. plt.figure(figsize=(10, 6))
  161. display_list = [test_input[0], generate[0]]
  162. title = ['Input image', 'Generate image']
  163. for i in range(2):
  164. plt.subplot(1, 2, i + 1)
  165. plt.title(title[i])
  166. plt.imshow((display_list[i]+1)/2) #从-1~1 --> 0~1
  167. plt.axis('off')
  168. plt.savefig('./image/image_at_{}.png'.format(epoch))
  169. test_batch = next(iter(dog_dataloader_test)) #batch_size,3,256,256
  170. # 测试输入:选取test_batch中的第一张图片,并添加一个batch_size维度 3,256,256--1,3,256,256
  171. test_input = torch.unsqueeze(test_batch[0], 0).to(device)
  172. # cycleGAN训练
  173. D_loss = []
  174. G_loss = []
  175. epochs = 50
  176. for epoch in range(epochs):
  177. d_epoch_loss = 0
  178. g_epoch_loss = 0
  179. for step, (real_A, real_B) in enumerate(zip(dog_dataloader, cat_dataloader)): #取出真实的狗,猫图片
  180. real_A = real_A.to(device)
  181. real_B = real_B.to(device)
  182. #--------------------begin--------------------#
  183. # 生成器训练
  184. gen_optimizer.zero_grad() #训练之前梯度清0
  185. # identity loss
  186. same_B = gen_AB(real_B) #真实的B经过生成器gen_AB还是要得到真实的B
  187. identity_B_loss = l1_loss(same_B, real_B)
  188. same_A = gen_AB(real_A) #真实的A经过生成器gen_BA还是要得到真实的A
  189. identity_A_loss = l1_loss(same_A, real_A)
  190. # 对抗损失 gan loss
  191. fake_B = gen_AB(real_A) #真实A通过生成器生成了B,此时生成器希望判别器将其判别为真
  192. D_pred_fake_B = dis_B(fake_B)
  193. gen_loss_AB = bce_loss(D_pred_fake_B, torch.ones_like(D_pred_fake_B, device=device))
  194. fake_A = gen_BA(real_B) #真实B通过生成器生成了A,此时生成器希望判别器将其判别为真
  195. D_pred_fake_A = dis_A(fake_A)
  196. gen_loss_BA = bce_loss(D_pred_fake_A, torch.ones_like(D_pred_fake_A, device=device))
  197. # 循环一致损失
  198. recovered_A = gen_BA(fake_B)
  199. cycle_loss_ABA = l1_loss(recovered_A, real_A)
  200. recovered_B = gen_AB(fake_A)
  201. cycle_loss_BAB = l1_loss(recovered_B, real_B)
  202. # 生成器总的损失
  203. g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + gen_loss_BA +cycle_loss_ABA + cycle_loss_BAB
  204. g_loss.backward()
  205. gen_optimizer.step()
  206. # --------------------end--------------------#
  207. # --------------------begin--------------------#
  208. # 判别器训练
  209. # dis_A训练
  210. dis_A_optimizer.zero_grad()
  211. dis_A_real_output = dis_A(real_A) #输入为真,期望判定为真
  212. dis_A_real_loss = bce_loss(dis_A_real_output, torch.ones_like(dis_A_real_output, device=device))
  213. dis_A_fake_output = dis_A(fake_A.detach()) #输入为假,期望判定为假,梯度截断
  214. dis_A_fake_loss = bce_loss(dis_A_fake_output, torch.zeros_like(dis_A_fake_output, device=device))
  215. dis_A_loss = dis_A_real_loss + dis_A_fake_loss #生成器A的总损失
  216. dis_A_loss.backward()
  217. dis_A_optimizer.step()
  218. # dis_B训练
  219. dis_B_optimizer.zero_grad()
  220. dis_B_real_output = dis_B(real_B) #输入为真,期望判定为真
  221. dis_B_real_loss = bce_loss(dis_B_real_output, torch.ones_like(dis_B_real_output, device=device))
  222. dis_B_fake_output = dis_B(fake_B.detach()) #输入为假,期望判定为假,梯度截断
  223. dis_B_fake_loss = bce_loss(dis_B_fake_output, torch.zeros_like(dis_B_fake_output, device=device))
  224. dis_B_loss = dis_B_real_loss + dis_B_fake_loss #生成器B的总损失
  225. dis_B_loss.backward()
  226. dis_B_optimizer.step()
  227. # --------------------end--------------------#
  228. with torch.no_grad():
  229. g_epoch_loss += g_loss.item() #将每一个批次的loss累加
  230. d_epoch_loss += (dis_A_loss + dis_B_loss).item() # 将每一个批次的loss累加
  231. with torch.no_grad():
  232. g_epoch_loss /= (step + 1) #求得每一轮的平均loss
  233. d_epoch_loss /= (step + 1) #求得每一轮的平均loss
  234. D_loss.append(d_epoch_loss)
  235. G_loss.append(g_epoch_loss)
  236. print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
  237. gen_img_plot(gen_AB, epoch, test_input)
