当前位置:   article > 正文

【对抗网络】CycleGAN模型讲解和代码实现

cyclegan模型

cycleGAN理论讲解:

论文地址:https://arxiv.org/pdf/1703.10593.pdf
        CycleGAN适用于非配对的图像到图像转换,它解决了必须使用成对(配对)数据才能进行训练的困难。

        可以看到上图中左边是配对图片,鞋子的素描和鞋子的真实图片为一对。右边为非配对图片,X是真实图片,Y是油画风格图片。

        CycleGAN的原理可以概述为:将一类图片转换成另一类图片。也就是说,现在有两个样本空间,X 和 Y, 我们希望把 X 空间中的样本转换成 Y 空间中的样本。可以理解为一种风格上的转换。

        这样来看:实际的目标就是学习从 X 到 Y 的映射。我们假设这个映射为F,它就对应着GAN中的生成器,F可以将X中的图片x转换为Y中的图片F(x)。对于生成的图片,我们还需要GAN中的判别器来判别它是否为真实图片,由此构成对抗生成网络。

CycleGAN的整体架构:

关于损失函数:

        这里有一个问题是在足够大的样本容量下,网络可以将相同的输入图像集合映射到目标域中图像的任何随机排列,其中任何学习的映射可以归纳出与目标分布匹配的输出分布。换句话说,映射F完全可以将所有 X 都映射为 Y 空间中的同一张图片,使得损失无效化。因此单独的对抗损失Loss不能保证学习函数可以将单个输入 Xi 映射到期望的输出 Yi。对此,论文作者提出了所谓的“循环一致性损失”(cycle consistency loss)。

循环一致损失:

还有一个identity loss:

可以理解为:生成器负责从域 X 到域 Y 的图像生成,如果输入的本来就是域 Y 的图片 y,生成器仍然应该输出域 Y 的图片 y'',于是计算 y'' 与输入 y 之间的 loss。

总损失:

训练结果:

           

epoch = 1

           

epoch = 15

           

epoch = 30

           

epoch = 45

导入的库:

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchvision import transforms
  5. from torch.utils import data
  6. import matplotlib.pyplot as plt
  7. import glob
  8. from PIL import Image
  9. import itertools
  10. import numpy as np

训练代码:

  # Preview a few raw samples from each training domain (A = men, B = women).
  # NOTE: the path originally read "trianA"; every other split in this script
  # follows the trainA/trainB/testA/testB layout, and a misspelled directory
  # makes glob silently return an empty list.
  mans_path = glob.glob("data/man_woman/trainA/*.jpg")
  print(len(mans_path))
  plt.figure(figsize=(12, 8))
  for i, man_path in enumerate(mans_path[:4]):
      man_img = Image.open(man_path)
      man_np_img = np.array(man_img)
      plt.subplot(2, 2, i + 1)
      plt.imshow(man_np_img)
      plt.title(str(man_np_img.shape))  # title shows the raw array shape
  plt.show()

  womans_path = glob.glob("data/man_woman/trainB/*.jpg")
  print(len(womans_path))
  plt.figure(figsize=(12, 8))
  for i, woman_path in enumerate(womans_path[:4]):
      woman_img = Image.open(woman_path)
      woman_np_img = np.array(woman_img)
      plt.subplot(2, 2, i + 1)
      plt.imshow(woman_np_img)
      plt.title(str(woman_np_img.shape))
  plt.show()

  # Shared preprocessing: convert to tensor, resize to 256x256, then shift and
  # scale with mean 0.5 / std 0.5 so [0, 1] inputs end up in [-1, 1].
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Resize((256, 256)),
      transforms.Normalize(0.5, 0.5),
  ])
  class MW_dataset(data.Dataset):
      """Dataset that loads and preprocesses images from a list of file paths.

      Args:
          img_path: Sequence of image file paths.
      """

      def __init__(self, img_path):
          self.img_path = img_path

      def __getitem__(self, index):
          """Open the image at ``index`` and apply the module-level ``transform``."""
          imgpath = self.img_path[index]
          # Force three channels: the networks' first convolution takes 3 input
          # channels, so grayscale or CMYK files would otherwise fail downstream.
          pil_img = Image.open(imgpath).convert("RGB")
          return transform(pil_img)

      def __len__(self):
          """Return the number of image paths."""
          return len(self.img_path)
  # Training datasets and loaders for both domains.
  man_dataset = MW_dataset(mans_path)
  woman_dataset = MW_dataset(womans_path)

  BATCHSIZE = 4
  man_dl = data.DataLoader(man_dataset, batch_size=BATCHSIZE, shuffle=True)
  woman_dl = data.DataLoader(woman_dataset, batch_size=BATCHSIZE, shuffle=True)

  man_batch = next(iter(man_dl))
  woman_batch = next(iter(woman_dl))

  # Show the preprocessed images: up to three man/woman pairs, one pair per row.
  fig = plt.figure(figsize=(8, 15))
  for i, (m, w) in enumerate(zip(man_batch[:3], woman_batch[:3])):
      # Channel-first -> channel-last for imshow, and undo the 0.5/0.5
      # normalization ((x + 1) / 2) so values are displayable.
      m = (m.permute(1, 2, 0).numpy() + 1) / 2
      w = (w.permute(1, 2, 0).numpy() + 1) / 2
      plt.subplot(3, 2, 2 * i + 1)
      plt.title("man")
      plt.imshow(m)
      plt.subplot(3, 2, 2 * i + 2)
      plt.title("woman")
      plt.imshow(w)
  plt.show()

  # Test datasets and loaders.
  mans_path_test = glob.glob("data/man_woman/testA/*.jpg")
  womans_path_test = glob.glob("data/man_woman/testB/*.jpg")
  man_dataset_test = MW_dataset(mans_path_test)
  woman_dataset_test = MW_dataset(womans_path_test)
  man_dl_test = data.DataLoader(man_dataset_test, batch_size=BATCHSIZE, shuffle=True)
  woman_dl_test = data.DataLoader(woman_dataset_test, batch_size=BATCHSIZE, shuffle=True)
  # Downsampling block
  class Downsample(nn.Module):
      """Stride-2 3x3 convolution + LeakyReLU, optionally followed by InstanceNorm.

      Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.conv_relu = nn.Sequential(
              nn.Conv2d(in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1),
              # inplace=True: the activation overwrites its input tensor.
              nn.LeakyReLU(inplace=True),
          )
          # InstanceNorm2d is used because each image is optimized individually.
          self.bn = nn.InstanceNorm2d(out_channels)

      def forward(self, x, is_bn=True):
          """Run conv + activation; apply normalization unless ``is_bn`` is False."""
          x = self.conv_relu(x)
          if is_bn:
              x = self.bn(x)
          return x
  # Upsampling block
  class Upsample(nn.Module):
      """Stride-2 transposed convolution + LeakyReLU + InstanceNorm, optional dropout.

      Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.upconv1 = nn.Sequential(
              nn.ConvTranspose2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 output_padding=1),
              # inplace=True: the activation overwrites its input tensor.
              nn.LeakyReLU(inplace=True),
          )
          # InstanceNorm2d is used because each image is optimized individually.
          self.bn = nn.InstanceNorm2d(out_channels)

      def forward(self, x, is_drop=False):
          """Upsample ``x``; apply channel dropout when ``is_drop`` is True."""
          x = self.upconv1(x)
          x = self.bn(x)
          if is_drop:
              # F.dropout2d defaults to training=True, which would keep dropping
              # channels after model.eval(); tie it to the module's mode instead.
              x = F.dropout2d(x, training=self.training)
          return x
  # Generator: 6 downsampling stages, 5 upsampling stages and 1 output layer.
  # Tip: sketching the model helps to see how encoder and decoder layers pair
  # up (U-Net).
  class Generator(nn.Module):
      """Encoder/decoder with skip connections; 3-channel in, 3-channel tanh out.

      The shape comments below are (channels, height, width) for a 256x256 input.
      """

      def __init__(self):
          super().__init__()
          self.down1 = Downsample(3, 64)       # (64, 128, 128)
          self.down2 = Downsample(64, 128)     # (128, 64, 64)
          self.down3 = Downsample(128, 256)    # (256, 32, 32)
          self.down4 = Downsample(256, 512)    # (512, 16, 16)
          self.down5 = Downsample(512, 512)    # (512, 8, 8)
          self.down6 = Downsample(512, 512)    # (512, 4, 4)
          # From up2 on, the input width is doubled by the skip concatenation.
          self.up1 = Upsample(512, 512)        # (512, 8, 8)
          self.up2 = Upsample(1024, 512)       # (512, 16, 16)
          self.up3 = Upsample(1024, 256)       # (256, 32, 32)
          self.up4 = Upsample(512, 128)        # (128, 64, 64)
          self.up5 = Upsample(256, 64)         # (64, 128, 128)
          self.last = nn.ConvTranspose2d(128, 3,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         output_padding=1)

      def forward(self, x):
          """Translate a batch of images; the output passes through tanh."""
          # Encoder: keep every intermediate map for the skip connections.
          x1 = self.down1(x)
          x2 = self.down2(x1)
          x3 = self.down3(x2)
          x4 = self.down4(x3)
          x5 = self.down5(x4)
          x6 = self.down6(x5)

          # Decoder: upsample, then concatenate the matching encoder map along
          # the channel axis. Only the first three stages request dropout.
          x6 = self.up1(x6, is_drop=True)
          x6 = torch.cat([x6, x5], dim=1)
          x6 = self.up2(x6, is_drop=True)
          x6 = torch.cat([x6, x4], dim=1)
          x6 = self.up3(x6, is_drop=True)
          x6 = torch.cat([x6, x3], dim=1)
          x6 = self.up4(x6)
          x6 = torch.cat([x6, x2], dim=1)
          x6 = self.up5(x6)
          x6 = torch.cat([x6, x1], dim=1)

          x6 = torch.tanh(self.last(x6))
          return x6
  # Discriminator (PatchGAN style): scores a grid of patches for one image.
  class Discriminator(nn.Module):
      """Two downsampling stages and a 3x3 convolution, followed by a sigmoid.

      The shape comments below are (channels, height, width) for a 256x256 input.
      """

      def __init__(self):
          super().__init__()
          # The input is a single 3-channel image; nothing is concatenated to it
          # (the earlier "6 channels: anno + img" remark did not match the code).
          self.down1 = Downsample(3, 64)      # (64, 128, 128)
          self.down2 = Downsample(64, 128)    # (128, 64, 64)
          self.last = nn.Conv2d(128, 1, 3)    # (1, 62, 62)

      def forward(self, img):
          """Return per-patch probabilities in (0, 1) for ``img``."""
          x = self.down1(img)
          x = self.down2(x)
          x = torch.sigmoid(self.last(x))     # (batch, 1, 62, 62)
          return x
  device = "cuda" if torch.cuda.is_available() else "cpu"

  # Two generators (A->B and B->A) and one discriminator per domain.
  gen_AB = Generator().to(device)
  gen_BA = Generator().to(device)
  dis_A = Discriminator().to(device)
  dis_B = Discriminator().to(device)

  # Loss functions: BCE for the GAN loss, L1 for the cycle-consistency and
  # identity losses.
  bceloss = torch.nn.BCELoss()
  l1_loss = torch.nn.L1Loss()

  # Optimizers. Both generators share one optimizer.
  gen_optimizer = torch.optim.Adam(
      itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
      lr=2e-4,
      betas=(0.5, 0.999),
  )
  # Each discriminator optimizer owns ONLY its discriminator's parameters.
  # Previously both also contained gen_BA.parameters(), so a discriminator
  # step could clear or update generator weights.
  dis_optimizer_A = torch.optim.Adam(
      dis_A.parameters(),
      lr=2e-4,
      betas=(0.5, 0.999),
  )
  dis_optimizer_B = torch.optim.Adam(
      dis_B.parameters(),
      lr=2e-4,
      betas=(0.5, 0.999),
  )
  # Plotting helper
  def generate_image(model, test_input):
      """Plot the first input image next to the model's output for it.

      Args:
          model: Callable mapping a channel-first image batch to a batch of the
              same layout.
          test_input: Image batch tensor, channel-first.
      """
      # Visualization only, so skip building the autograd graph.
      with torch.no_grad():
          predictions = model(test_input).permute(0, 2, 3, 1).detach().cpu().numpy()
      test_input = test_input.permute(0, 2, 3, 1).cpu().numpy()
      title_list = ["input", "output"]
      display_list = [test_input[0], predictions[0]]
      plt.figure(figsize=(10, 6))
      for i in range(2):
          plt.subplot(1, 2, i + 1)
          plt.title(title_list[i])
          # x * 0.5 + 0.5 undoes the mean-0.5 / std-0.5 normalization.
          plt.imshow(display_list[i] * 0.5 + 0.5)
          plt.axis("off")
      plt.show()
  test_batch = next(iter(man_dl_test))
  # Only one image is used as test_input, so it has no batch axis; unsqueeze
  # adds a leading batch dimension of size 1.
  test_input = torch.unsqueeze(test_batch[0], 0).to(device)

  # Train the model
  D_loss = []
  G_loss = []
  best_gen_loss = float("inf")
  EPOCH = 5
  for epoch in range(EPOCH):
      D_epoch_loss = 0
      G_epoch_loss = 0
      # zip() below stops at the shorter loader, so that is the step count.
      count = min(len(man_dl), len(woman_dl))
      for step, (real_A, real_B) in enumerate(zip(man_dl, woman_dl)):
          real_A = real_A.to(device)
          real_B = real_B.to(device)

          # ---- Train the generators ----
          gen_optimizer.zero_grad()

          # Identity loss: feeding a generator an image from its target domain
          # should return that image.
          same_A = gen_BA(real_A)
          same_A_loss = l1_loss(same_A, real_A)
          same_B = gen_AB(real_B)
          same_B_loss = l1_loss(same_B, real_B)

          # GAN loss: generated images should be scored as real (target 1).
          fake_A = gen_BA(real_B)
          fake_A_output = dis_A(fake_A)
          fake_A_output_loss = bceloss(fake_A_output, torch.ones_like(fake_A_output, device=device))
          fake_B = gen_AB(real_A)
          fake_B_output = dis_B(fake_B)
          fake_B_output_loss = bceloss(fake_B_output, torch.ones_like(fake_B_output, device=device))

          # Cycle-consistency loss: A -> B -> A and B -> A -> B should
          # reconstruct the original image.
          recovered_A = gen_BA(fake_B)
          cycle_ABA_loss = l1_loss(recovered_A, real_A)
          recovered_B = gen_AB(fake_A)
          cycle_BAB_loss = l1_loss(recovered_B, real_B)

          g_loss = (same_A_loss + same_B_loss + fake_A_output_loss + fake_B_output_loss + cycle_ABA_loss + cycle_BAB_loss)
          g_loss.backward()
          gen_optimizer.step()

          # ---- Train the discriminators ----
          # dis_A: real images -> 1, generated images -> 0. detach() keeps the
          # discriminator loss from back-propagating into the generator.
          dis_optimizer_A.zero_grad()
          real_A_output = dis_A(real_A)
          real_A_loss = bceloss(real_A_output, torch.ones_like(real_A_output))
          fake_A_output = dis_A(fake_A.detach())
          fake_A_loss = bceloss(fake_A_output, torch.zeros_like(fake_A_output))
          dis_A_loss = real_A_loss + fake_A_loss
          dis_A_loss.backward()
          dis_optimizer_A.step()

          # dis_B
          dis_optimizer_B.zero_grad()
          real_B_output = dis_B(real_B)
          real_B_loss = bceloss(real_B_output, torch.ones_like(real_B_output))
          fake_B_output = dis_B(fake_B.detach())
          fake_B_loss = bceloss(fake_B_output, torch.zeros_like(fake_B_output))
          dis_B_loss = real_B_loss + fake_B_loss
          dis_B_loss.backward()
          dis_optimizer_B.step()

          # .item() already returns plain Python floats; no no_grad() needed.
          G_epoch_loss += g_loss.item()
          D_epoch_loss += (dis_A_loss + dis_B_loss).item()

      D_epoch_loss /= count
      G_epoch_loss /= count
      D_loss.append(D_epoch_loss)
      G_loss.append(G_epoch_loss)

      # Keep the best generator. Compare the epoch MEAN after the epoch has
      # finished: comparing the running per-step sum (as before) is only ever
      # small at the first step of an epoch, so later improvements were missed.
      if G_epoch_loss < best_gen_loss:
          best_gen_loss = G_epoch_loss
          # Save the A->B generator's state dict.
          torch.save(gen_AB.state_dict(), 'best_cycleGAN_model.pth')

      print("Epoch:{}".format(epoch),
            "g_epoch_loss:{}".format(G_epoch_loss),
            "d_epoch_loss:{}".format(D_epoch_loss))
      # Optional visual check every 5 epochs:
      # if epoch % 5 == 0:
      #     generate_image(gen_AB, test_input)

使用训练好的模型:

  1. import os
  2. import torch
  3. import torchvision.utils
  4. from torchvision import transforms
  5. from PIL import Image
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  # Downsampling block
  class Downsample(nn.Module):
      """Stride-2 3x3 convolution + LeakyReLU, optionally followed by InstanceNorm.

      Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.conv_relu = nn.Sequential(
              nn.Conv2d(in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1),
              # inplace=True: the activation overwrites its input tensor.
              nn.LeakyReLU(inplace=True),
          )
          # InstanceNorm2d is used because each image is optimized individually.
          self.bn = nn.InstanceNorm2d(out_channels)

      def forward(self, x, is_bn=True):
          """Run conv + activation; apply normalization unless ``is_bn`` is False."""
          x = self.conv_relu(x)
          if is_bn:
              x = self.bn(x)
          return x
  # Upsampling block
  class Upsample(nn.Module):
      """Stride-2 transposed convolution + LeakyReLU + InstanceNorm, optional dropout.

      Args:
          in_channels: Number of input channels.
          out_channels: Number of output channels.
      """

      def __init__(self, in_channels, out_channels):
          super().__init__()
          self.upconv1 = nn.Sequential(
              nn.ConvTranspose2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=3,
                                 stride=2,
                                 padding=1,
                                 output_padding=1),
              # inplace=True: the activation overwrites its input tensor.
              nn.LeakyReLU(inplace=True),
          )
          # InstanceNorm2d is used because each image is optimized individually.
          self.bn = nn.InstanceNorm2d(out_channels)

      def forward(self, x, is_drop=False):
          """Upsample ``x``; apply channel dropout when ``is_drop`` is True."""
          x = self.upconv1(x)
          x = self.bn(x)
          if is_drop:
              # F.dropout2d defaults to training=True, which would randomly zero
              # channels at inference time; follow the module's mode instead.
              x = F.dropout2d(x, training=self.training)
          return x
  class Generator(nn.Module):
      """Encoder/decoder with skip connections; 3-channel in, 3-channel tanh out.

      The shape comments below are (channels, height, width) for a 256x256 input.
      """

      def __init__(self):
          super().__init__()
          self.down1 = Downsample(3, 64)       # (64, 128, 128)
          self.down2 = Downsample(64, 128)     # (128, 64, 64)
          self.down3 = Downsample(128, 256)    # (256, 32, 32)
          self.down4 = Downsample(256, 512)    # (512, 16, 16)
          self.down5 = Downsample(512, 512)    # (512, 8, 8)
          self.down6 = Downsample(512, 512)    # (512, 4, 4)
          # From up2 on, the input width is doubled by the skip concatenation.
          self.up1 = Upsample(512, 512)        # (512, 8, 8)
          self.up2 = Upsample(1024, 512)       # (512, 16, 16)
          self.up3 = Upsample(1024, 256)       # (256, 32, 32)
          self.up4 = Upsample(512, 128)        # (128, 64, 64)
          self.up5 = Upsample(256, 64)         # (64, 128, 128)
          self.last = nn.ConvTranspose2d(128, 3,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         output_padding=1)

      def forward(self, x):
          """Translate a batch of images; the output passes through tanh."""
          # Encoder: keep every intermediate map for the skip connections.
          x1 = self.down1(x)
          x2 = self.down2(x1)
          x3 = self.down3(x2)
          x4 = self.down4(x3)
          x5 = self.down5(x4)
          x6 = self.down6(x5)

          # Decoder: upsample, then concatenate the matching encoder map along
          # the channel axis. Only the first three stages request dropout.
          x6 = self.up1(x6, is_drop=True)
          x6 = torch.cat([x6, x5], dim=1)
          x6 = self.up2(x6, is_drop=True)
          x6 = torch.cat([x6, x4], dim=1)
          x6 = self.up3(x6, is_drop=True)
          x6 = torch.cat([x6, x3], dim=1)
          x6 = self.up4(x6)
          x6 = torch.cat([x6, x2], dim=1)
          x6 = self.up5(x6)
          x6 = torch.cat([x6, x1], dim=1)

          x6 = torch.tanh(self.last(x6))
          return x6
  # Make sure the output folder exists.
  output_folder = "output"
  os.makedirs(output_folder, exist_ok=True)

  # Build the generator and load the saved state dict. map_location="cpu" lets
  # a checkpoint written on a GPU machine load where no GPU is available.
  gen_AB = Generator()
  gen_AB.load_state_dict(torch.load("best_cycleGAN_model.pth", map_location="cpu"))
  gen_AB.eval()  # inference mode

  # Load and preprocess the input image.
  img_path = "input.jpg"
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Resize((256, 256)),
      transforms.Normalize(0.5, 0.5),
  ])
  # The generator's first convolution takes 3 channels, so force RGB.
  img = Image.open(img_path).convert("RGB")
  img = transform(img)
  img = img.unsqueeze(0)  # add a batch dimension of size 1

  with torch.no_grad():
      output = gen_AB(img).detach().cpu()

  # Save the GENERATED image. The script previously wrote `img`, i.e. the
  # untouched input, to output.jpg. (x + 1) / 2 maps tanh output to [0, 1].
  torchvision.utils.save_image((output + 1) / 2, os.path.join(output_folder, "output.jpg"))


 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/442497
推荐阅读
相关标签
  

闽ICP备14008679号