当前位置:   article > 正文

使用生成式对抗网络(GAN)生成动漫人物图像

使用生成式对抗网络(GAN)生成动漫人物图像

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

如今AI艺术创作能力越来越强大,Google发布的ImageGen项目基于文本提示作画的结果和真实艺术家的成品难辨真假。本项目将使用PyTorch实现生成式对抗网络生成式对抗网络来完成AI生成动漫人物图像。

本项目中使用的数据集是一个由63 632个高质量动画人脸组成的数据集,从www.getchu.com中抓取,然后使用https://github.com/nagadomi/lbpcascade_animeface中的动画人脸检测算法进行裁剪。图像大小从90×90到120×120不等。该数据集包含高质量的动漫角色图像,具有干净的背景和丰富的颜色。数据集下载链接:https://github.com/bchao1/Anime-Face-Dataset

我们知道在生成式对抗网络中有两个模型——生成模型(Generative Model,G)和判别模型(Discriminative Model,D)。G就是一个生成图片的网络,它接收一个随机的噪声z,然后通过这个噪声生成图片,生成的数据记作G(z)。D是一个判别网络,判别一幅图片是不是“真实的”(是不是捏造的)。它的输入参数是x,x代表一幅图片,输出D(x)代表x为真实图片的概率,如果为1,就代表是真实的图片,而输出为0,就代表不可能是真实的图片。

  1. 定义生成器Generator:生成器的输入为100维的高斯噪声,生成器会利用这个噪声生成指定大小的图片,关于最初的噪声,可以看成10011的特征图,然后利用转置卷积来进行尺寸还原操作,标准的卷积操作是不断缩小尺寸,转置卷积就可以理解为它的逆操作,这样就可以不断放大图像。
  2. 定义判别器Discriminator:判别器就是一个典型的二分类网络,首先它的输入是我们输入的图片,我们会利用一系列卷积操作来形成一维特征图进行分类操作,这里可以发现判别器的网络和生成器的相关操作是可逆的,唯独不一样的是激活函数。

模型训练的步骤如下:

   步骤1:首先固定生成器,训练判别器,提高真实样本被判别为真的概率,同时降低生成器生成的假图像被判别为真的概率,目标是判别器能准确进行分类。

   步骤2:固定判别器,训练生成器,生成器生成图像,尽可能提高该图像被判别器判别为真的概率,目标是生成器的结果能够骗过判别器。

   步骤3:重复,循环交替训练,最终生成器生成的样本足够逼真,使得鉴别器只有大约50%的判断正确率(相当于乱猜)。

完整代码如下:

  1. #####################GANDEMO.py####################
  2. import os
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch.utils.data import TensorDataset
  7. import torchvision
  8. from torchvision import transforms, datasets
  9. from tqdm import tqdm
  10. class Config(object):
  11. data_path = './gandata/data/'
  12. image_size = 96
  13. batch_size = 32
  14. epochs = 200
  15. lr1 = 2e-3
  16. lr2 = 2e-4
  17. beta1 = 0.5
  18. gpu = False
  19. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  20. nz = 100
  21. ngf = 64
  22. ndf = 64
  23. save_path = './gandata/images'
  24. generator_path = './gandata/generator.pkl' #模型保存路径
  25. discriminator_path = './gandata/discriminator.pkl' #模型保存路径
  26. gen_img = './gandata/result.png'
  27. gen_num = 64
  28. gen_search_num = 5000
  29. gen_mean = 0
  30. gen_std = 1
  31. config = Config()
  32. # 1.数据转换
  33. data_transform = transforms.Compose([
  34. transforms.Resize(config.image_size),
  35. transforms.CenterCrop(config.image_size),
  36. transforms.ToTensor(),
  37. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  38. ])
  39. # 2.形成训练集
  40. train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
  41. transform=data_transform)
  42. # 3.形成迭代器
  43. train_loader = torch.utils.data.DataLoader(train_dataset,
  44. config.batch_size,
  45. True,
  46. drop_last=True)
  47. print('using {} images for training.'.format(len(train_dataset)))
  48. class Generator(nn.Module):
  49. def __init__(self, config):
  50. super().__init__()
  51. ngf = config.ngf
  52. self.model = nn.Sequential(
  53. nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
  54. nn.BatchNorm2d(ngf * 8),
  55. nn.ReLU(True),
  56. nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
  57. nn.BatchNorm2d(ngf * 4),
  58. nn.ReLU(True),
  59. nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
  60. nn.BatchNorm2d(ngf * 2),
  61. nn.ReLU(True),
  62. nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
  63. nn.BatchNorm2d(ngf),
  64. nn.ReLU(True),
  65. nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
  66. nn.Tanh()
  67. )
  68. def forward(self, x):
  69. output = self.model(x)
  70. return output
  71. class Discriminator(nn.Module):
  72. def __init__(self, config):
  73. super().__init__()
  74. ndf = config.ndf
  75. self.model = nn.Sequential(
  76. nn.Conv2d(3, ndf, 5, 3, 1),
  77. nn.LeakyReLU(0.2, inplace=True),
  78. nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
  79. nn.BatchNorm2d(ndf * 2),
  80. nn.LeakyReLU(0.2, inplace=True),
  81. nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
  82. nn.BatchNorm2d(ndf * 4),
  83. nn.LeakyReLU(0.2, inplace=True),
  84. nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
  85. nn.BatchNorm2d(ndf * 8),
  86. nn.LeakyReLU(0.2, inplace=True),
  87. nn.Conv2d(ndf * 8, 1, 4, 1, 0)
  88. )
  89. def forward(self, x):
  90. output = self.model(x)
  91. return output.view(-1)
  92. generator = Generator(config)
  93. discriminator = Discriminator(config)
  94. optimizer_generator = torch.optim.Adam(generator.parameters(),
  95. config.lr1,
  96. betas=(config.beta1, 0.999))
  97. optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
  98. config.lr2,
  99. betas=(config.beta1, 0.999))
  100. true_labels = torch.ones(config.batch_size)
  101. fake_labels = torch.zeros(config.batch_size)
  102. fix_noises = torch.randn(config.batch_size, config.nz, 1, 1)
  103. noises = torch.randn(config.batch_size, config.nz, 1, 1)
  104. for epoch in range(config.epochs):
  105. for ii, (img, _) in tqdm(enumerate(train_loader)):
  106. real_img = img.to(config.device)
  107. if ii % 2 == 0:
  108. optimizer_discriminator.zero_grad()
  109. r_preds = discriminator(real_img)
  110. noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
  111. fake_img = generator(noises).detach()
  112. f_preds = discriminator(fake_img)
  113. r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)
  114. f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)
  115. loss_d_real = (1 - r_f_diff).mean()
  116. loss_d_fake = (1 + f_r_diff).mean()
  117. loss_d = loss_d_real + loss_d_fake
  118. loss_d.backward()
  119. optimizer_discriminator.step()
  120. else:
  121. optimizer_generator.zero_grad()
  122. noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
  123. fake_img = generator(noises)
  124. f_preds = discriminator(fake_img)
  125. r_preds = discriminator(real_img)
  126. r_f_diff = r_preds - torch.mean(f_preds)
  127. f_r_diff = f_preds - torch.mean(r_preds)
  128. loss_g = torch.mean(F.relu(1 + r_f_diff)) \
  129. + torch.mean(F.relu(1 - f_r_diff))
  130. loss_g.backward()
  131. optimizer_generator.step()
  132. if epoch == config.epochs - 1:
  133. # 保存模型
  134. torch.save(discriminator.state_dict(), config.discriminator_path)
  135. torch.save(generator.state_dict(), config.generator_path)
  136. print('Finished Training')
  137. generator = Generator(config)
  138. discriminator = Discriminator(config)
  139. noises = torch.randn(config.gen_search_num,
  140. config.nz, 1, 1).normal_(config.gen_mean,
  141. config.gen_std)
  142. noises = noises.to(config.device)
  143. generator.load_state_dict(torch.load(config.generator_path,
  144. map_location='cpu'))
  145. discriminator.load_state_dict(torch.load(config.discriminator_path,
  146. map_location='cpu'))
  147. generator.to(config.device)
  148. discriminator.to(config.device)
  149. fake_img = generator(noises)
  150. scores = discriminator(fake_img).detach()
  151. indexs = scores.topk(config.gen_num)[1]
  152. result = []
  153. for ii in indexs:
  154. result.append(fake_img.data[ii])
  155. torchvision.utils.save_image(torch.stack(result), config.gen_img,
  156. normalize=True, value_range=(-1, 1))

代码运行结果如下:

  1. using 900 images for training.
  2. 28it [00:20, 1.40it/s]
  3. 28it [00:20, 1.33it/s]
  4. 28it [00:21, 1.29it/s]
  5. 28it [00:26, 1.06it/s]
  6. Finished Training

效果图如图13-9所示,由于只训练了100个Epoch,因此图像生成的纹理还不算太清楚,大家计算资源允许的话,可以多训练一些Epoch来生成更多的图像细节。

图13-9

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号