当前位置:   article > 正文

生成对抗网络cGAN(条件GAN)

cgan

1.介绍

论文:Conditional Generative Adversarial Nets

论文地址:https://arxiv.org/abs/1411.1784

针对原始GAN的缺点:生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强等问题。

改进方法:cGAN的中心思想是希望可以控制 GAN 生成的图片,而不是单纯的随机生成图片。 Conditional GAN 在生成器和判别器的输入中增加了额外的条件信息,生成器生成的图片只有足够真实且与条件相符,才能够通过判别器。其核心在于将属性信息融入生成器G和判别器D中,属性可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

2.模型结构

在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层引入判别器和生成器。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torchvision import transforms, datasets, utils
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import torch.optim as optim
  9. from torchvision.datasets import ImageFolder
  10. import tqdm
  11. ROOT_TRAIN = r'D:\CNN\AlexNet\data1\train'
  12. def one_hot(x, num_class=2): #转化为独热标签
  13. return torch.eye(num_class)[x, :]
  14. train_transform = transforms.Compose([transforms.ToTensor(),
  15. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  16. train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform, target_transform=one_hot) # 加载训练集
  17. dataloader = torch.utils.data.DataLoader(train_dataset,
  18. batch_size=64,
  19. shuffle=True,
  20. num_workers=0)
  21. # print(train_dataset[0]) #返回数据和标签, 引入one_hot编码后,标签就为长度为num_class的tensor tensor([1., 0.]
  22. # 定义生成器,输入是长度为100的噪声(正态分布随机数),和标签独热编码(condition)
  23. # 输出为3*224*224的图片(tensor)
  24. class Generator(nn.Module):
  25. def __init__(self):
  26. super(Generator, self).__init__()
  27. self.linear1 = nn.Linear(100, 128*56*56)
  28. self.bn1 = nn.BatchNorm1d(128*56*56)
  29. self.linear2 = nn.Linear(2, 128*56*56)
  30. self.bn2 = nn.BatchNorm1d(128*56*56)
  31. self.deconv1 = nn.ConvTranspose2d(256, 128,
  32. kernel_size=(3, 3),
  33. stride=1,
  34. padding=1) #128*56*56
  35. self.bn3 = nn.BatchNorm2d(128)
  36. self.deconv2 = nn.ConvTranspose2d(128, 64,
  37. kernel_size=(4, 4),
  38. stride=2,
  39. padding=1) # 64*112*112
  40. self.bn4 = nn.BatchNorm2d(64)
  41. self.deconv3 = nn.ConvTranspose2d(64, 3,
  42. kernel_size=(4, 4),
  43. stride=2,
  44. padding=1) # 3*224*224
  45. def forward(self, x1, x2): #x1为噪声输入,x2标签独热编码输入(condition)
  46. x1 = F.relu(self.linear1(x1)) #100 -- 128*56*56
  47. x1 = self.bn1(x1)
  48. x2 = F.relu(self.linear2(x2)) #num_class -- 128*56*56
  49. x2 = self.bn2(x2)
  50. x1 = x1.view(-1, 128, 56, 56)
  51. x2 = x2.view(-1, 128, 56, 56)
  52. x = torch.cat([x1, x2], dim=1) #256*56*56
  53. x = F.relu(self.deconv1(x)) #256*56*56 -- 128*56*56
  54. x = self.bn3(x)
  55. x = F.relu(self.deconv2(x)) #128*56*56 -- 64*112*112
  56. x = self.bn4(x)
  57. x = torch.tanh(self.deconv3(x)) #64*112*112 -- 3*224*224 生成器的输出不使用bn层
  58. return x
  59. # 定义判别器,输入为3*224*224的图片,输出为二分类概率值
  60. class Discriminator(nn.Module):
  61. def __init__(self):
  62. super(Discriminator, self).__init__()
  63. self.linear = nn.Linear(2, 3*224*224)
  64. self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=2)
  65. self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
  66. self.bn = nn.BatchNorm2d(128)
  67. self.fc = nn.Linear(128*55*55, 1)
  68. def forward(self, x1, x2): #x1为真实图像输入,x2标签独热编码输入(condition)
  69. x2 = self.linear(x2)
  70. x2 = x2.view(-1, 3, 224, 224)
  71. x = torch.cat([x1, x2], dim=1) #batchsize, 6, 224, 224
  72. x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3) #64*111*111 判别器的输入不使用bn层
  73. x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3) #128*55*55
  74. x = self.bn(x)
  75. x = x.view(-1, 128*55*55) #展平
  76. x = torch.sigmoid(self.fc(x))
  77. return x
  78. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  79. gen = Generator().to(device)
  80. dis = Discriminator().to(device)
  81. # 判别器优化器
  82. d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4) #通过减小判别器的学习率降低其能力
  83. # 生成器优化器
  84. g_optim = torch.optim.Adam(gen.parameters(), lr=1e-3)
  85. loss_fn = torch.nn.BCELoss() # 二元交叉熵损失
  86. # 绘图函数,将每一个epoch中生成器生成的图片绘制
  87. def gen_img_plot(model, epoch, noise_input, label_input): # model为Generator,test_input代表生成器输入的随机数,label_input为标签输入
  88. # prediction = np.squeeze(model(test_input).detach().cpu().numpy()) #squeeze为去掉通道维度
  89. prediction = model(noise_input, label_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
  90. plt.figure(figsize=(10, 10))
  91. for i in range(prediction.shape[0]): #prediction.shape[0]=noise_input的batchsize
  92. plt.subplot(2, 2, i + 1)
  93. plt.imshow((prediction[i]+1)/2) #从-1~1 --> 0~1
  94. plt.axis('off')
  95. plt.savefig('./CGAN_img/image_CGAN_{}.png'.format(epoch))
  96. # if epoch == 99:
  97. # plt.show()
  98. # 设置生成绘图图片的随机张量,这里可视化4张图片
  99. noise_input = torch.randn(4, 100, device=device) #测试输入:16个长度为100的随机数
  100. # print(noise_input)
  101. label_input0 = torch.randint(0, 1, size=(4, )) #生成4个从01的随机整数
  102. # print(label_input)
  103. label_input_onehot = one_hot(label_input0).to(device) #将tensor转化为独热编码形式
  104. # print(label_input_onehot)
  105. # CGAN训练
  106. D_loss = []
  107. G_loss = []
  108. for epoch in range(500):
  109. d_epoch_loss = 0 #判别器损失
  110. g_epoch_loss = 0 #生成器损失
  111. count = len(dataloader) #len(dataloader)返回批次数
  112. count1 = len(train_dataset) #len(train_dataset)返回样本数
  113. for step, (img, label) in enumerate(tqdm.tqdm(dataloader)): #此时返回的label已经是独热标签
  114. img = img.to(device)
  115. label = label.to(device)
  116. size = img.size(0) #该批次包含多少张图片
  117. random_noise = torch.randn(size, 100, device=device) #创建生成器的噪声输入
  118. d_optim.zero_grad() #判别器梯度清0
  119. real_output = dis(img, label) #将真实图像放到判别器上进行判断,得到对真实图像的预测结果
  120. d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) #real_output应该被判定为1(真),得到判别器在真实图像上的损失
  121. d_real_loss.backward() #计算梯度
  122. gen_img = gen(random_noise, label) #得到生成图像
  123. fake_output = dis(gen_img.detach(), label) #将生成图像和对应的标签同时放到判别器上进行判断,得到对生成图像的预测结果,detach()为截断梯度
  124. d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) #fake_output应该被判定为0(假),得到判别器在生成图像上的损失
  125. d_fake_loss.backward() # 计算梯度
  126. d_loss = d_real_loss + d_fake_loss #判别器的损失包含两部分
  127. d_optim.step() #判别器优化
  128. # 生成器
  129. g_optim.zero_grad() #生成器梯度清零
  130. fake_output = dis(gen_img, label) #将生成图像放到判别器上进行判断
  131. g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) #此处希望生成的图像能被判定为1
  132. g_loss.backward() # 计算梯度
  133. g_optim.step() #生成器优化
  134. with torch.no_grad(): # loss累加的过程不需要计算梯度
  135. d_epoch_loss += d_loss.item() #将每一个批次的损失累加
  136. g_epoch_loss += g_loss.item() #将每一个批次的损失累加
  137. with torch.no_grad(): # loss累加的过程不需要计算梯度
  138. g_epoch_loss /= count
  139. d_epoch_loss /= count
  140. D_loss.append(d_epoch_loss) #保存每一个epoch的平均loss
  141. G_loss.append(g_epoch_loss) #保存每一个epoch的平均loss
  142. print('Epoch:', epoch)
  143. gen_img_plot(gen, epoch, noise_input, label_input_onehot) #每个epoch会生成一张图
  144. plt.figure(figsize=(10, 10))
  145. plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
  146. plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
  147. plt.xlabel('epoch') # 横轴名称
  148. plt.legend()
  149. plt.savefig('loss.png') # 保存图片

cGAN生成的图像虽有很多缺陷,如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路!!!

最后放上我训练的结果(数据量不大,只有四百张狗的图片,效果不太明显!!!) 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/415189
推荐阅读
相关标签
  

闽ICP备14008679号