当前位置:   article > 正文

基于深度学习神经网络CycleGan转卡通头像系统_cyclegan 卡通头像

cyclegan 卡通头像

第一步:CycleGAN介绍

        CycleGAN主要用于图像域之间的转换。假设有两个图像域X和Y,且两个域中的样本不需要一一配对,算法通过训练学习X与Y之间的双向映射(即“自动相互转换”)。训练时不需要成对的配对样本,只需要源域和目标域各自的图像;训练完成后,网络即可实现图像从源域到目标域的迁移。CycleGAN适用于非配对的图像到图像转换,解决了模型必须依赖成对数据进行训练的困难。

第二步:CycleGAN网络结构

第三步:模型代码展示

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn.parameter import Parameter
  5. class ResnetGenerator(nn.Module):
  6. def __init__(self, ngf=64, img_size=256, light=False):
  7. super(ResnetGenerator, self).__init__()
  8. self.light = light
  9. self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
  10. nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
  11. nn.InstanceNorm2d(ngf),
  12. nn.ReLU(True))
  13. self.HourGlass1 = HourGlass(ngf, ngf)
  14. self.HourGlass2 = HourGlass(ngf, ngf)
  15. # Down-Sampling
  16. self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
  17. nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
  18. nn.InstanceNorm2d(ngf * 2),
  19. nn.ReLU(True))
  20. self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
  21. nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
  22. nn.InstanceNorm2d(ngf*4),
  23. nn.ReLU(True))
  24. # Encoder Bottleneck
  25. self.EncodeBlock1 = ResnetBlock(ngf*4)
  26. self.EncodeBlock2 = ResnetBlock(ngf*4)
  27. self.EncodeBlock3 = ResnetBlock(ngf*4)
  28. self.EncodeBlock4 = ResnetBlock(ngf*4)
  29. # Class Activation Map
  30. self.gap_fc = nn.Linear(ngf*4, 1)
  31. self.gmp_fc = nn.Linear(ngf*4, 1)
  32. self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
  33. self.relu = nn.ReLU(True)
  34. # Gamma, Beta block
  35. if self.light:
  36. self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
  37. nn.ReLU(True),
  38. nn.Linear(ngf*4, ngf*4),
  39. nn.ReLU(True))
  40. else:
  41. self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
  42. nn.ReLU(True),
  43. nn.Linear(ngf*4, ngf*4),
  44. nn.ReLU(True))
  45. # Decoder Bottleneck
  46. self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
  47. self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
  48. self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
  49. self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)
  50. # Up-Sampling
  51. self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
  52. nn.ReflectionPad2d(1),
  53. nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
  54. LIN(ngf*2),
  55. nn.ReLU(True))
  56. self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
  57. nn.ReflectionPad2d(1),
  58. nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
  59. LIN(ngf),
  60. nn.ReLU(True))
  61. self.HourGlass3 = HourGlass(ngf, ngf)
  62. self.HourGlass4 = HourGlass(ngf, ngf, False)
  63. self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
  64. nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
  65. nn.Tanh())
  66. def forward(self, x):
  67. x = self.ConvBlock1(x)
  68. x = self.HourGlass1(x)
  69. x = self.HourGlass2(x)
  70. x = self.DownBlock1(x)
  71. x = self.DownBlock2(x)
  72. x = self.EncodeBlock1(x)
  73. content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
  74. x = self.EncodeBlock2(x)
  75. content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
  76. x = self.EncodeBlock3(x)
  77. content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
  78. x = self.EncodeBlock4(x)
  79. content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
  80. gap = F.adaptive_avg_pool2d(x, 1)
  81. gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
  82. gap_weight = list(self.gap_fc.parameters())[0]
  83. gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
  84. gmp = F.adaptive_max_pool2d(x, 1)
  85. gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
  86. gmp_weight = list(self.gmp_fc.parameters())[0]
  87. gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
  88. cam_logit = torch.cat([gap_logit, gmp_logit], 1)
  89. x = torch.cat([gap, gmp], 1)
  90. x = self.relu(self.conv1x1(x))
  91. heatmap = torch.sum(x, dim=1, keepdim=True)
  92. if self.light:
  93. x_ = F.adaptive_avg_pool2d(x, 1)
  94. style_features = self.FC(x_.view(x_.shape[0], -1))
  95. else:
  96. style_features = self.FC(x.view(x.shape[0], -1))
  97. x = self.DecodeBlock1(x, content_features4, style_features)
  98. x = self.DecodeBlock2(x, content_features3, style_features)
  99. x = self.DecodeBlock3(x, content_features2, style_features)
  100. x = self.DecodeBlock4(x, content_features1, style_features)
  101. x = self.UpBlock1(x)
  102. x = self.UpBlock2(x)
  103. x = self.HourGlass3(x)
  104. x = self.HourGlass4(x)
  105. out = self.ConvBlock2(x)
  106. return out, cam_logit, heatmap

第四步:运行

第五步:整个工程的内容

代码的下载路径(新窗口打开链接)基于深度学习CycleGan转卡通头像系统

有问题可以私信或者留言,有问必答

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/873418
推荐阅读
相关标签
  

闽ICP备14008679号