Step 1: Introduction to CycleGAN
CycleGAN is designed for image-to-image translation. Given two unpaired image domains X and Y, it learns mappings that translate images back and forth between the two domains; training requires no paired samples, only a set of source-domain images and a set of target-domain images. Once trained, the network can transfer an image from the source domain to the target domain. The key constraint is cycle consistency: an image translated from X to Y and back to X should reconstruct the original, which, combined with an adversarial loss in each domain, is what makes unpaired training work. CycleGAN therefore suits unpaired image-to-image translation and removes the difficulty of collecting paired training data (a minimal loss sketch follows below).
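As a quick illustration only (not this project's actual training code), here is a minimal PyTorch sketch of the two losses described above. G_XY, G_YX, D_X and D_Y are placeholder names for the two generators and two discriminators; any concrete networks with matching input/output shapes would do.

import torch
import torch.nn as nn

def cyclegan_losses(G_XY, G_YX, D_X, D_Y, real_x, real_y, lambda_cyc=10.0):
    # Sketch of the CycleGAN generator objective: adversarial + cycle-consistency terms.
    l1 = nn.L1Loss()
    mse = nn.MSELoss()  # least-squares GAN loss, as in the CycleGAN paper

    fake_y = G_XY(real_x)   # X -> Y
    fake_x = G_YX(real_y)   # Y -> X
    rec_x = G_YX(fake_y)    # X -> Y -> X, should reconstruct real_x
    rec_y = G_XY(fake_x)    # Y -> X -> Y, should reconstruct real_y

    # Adversarial losses: generators try to make the discriminators predict "real" (1)
    pred_fake_y = D_Y(fake_y)
    pred_fake_x = D_X(fake_x)
    adv = mse(pred_fake_y, torch.ones_like(pred_fake_y)) + \
          mse(pred_fake_x, torch.ones_like(pred_fake_x))

    # Cycle-consistency loss: translating there and back should recover the input
    cyc = l1(rec_x, real_x) + l1(rec_y, real_y)

    return adv + lambda_cyc * cyc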
Step 2: CycleGAN Network Architecture
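At a high level, CycleGAN trains two generators (source-to-target and target-to-source) and two discriminators, one per domain. In this project the generator is the ResnetGenerator shown in Step 3: an hourglass/ResNet encoder-decoder with a class-activation-map (CAM) attention branch and adaptive normalization in the decoder. The discriminator code is not reproduced in this post; as an illustration only, below is a minimal PatchGAN-style discriminator sketch (the discriminator type used by the original CycleGAN; the discriminator in the downloadable project may differ).

import torch.nn as nn

class PatchDiscriminator(nn.Module):
    # Illustrative 70x70 PatchGAN: outputs a grid of per-patch real/fake scores.
    def __init__(self, in_ch=3, ndf=64):
        super().__init__()
        layers = [nn.Conv2d(in_ch, ndf, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, True)]
        ch = ndf
        for _ in range(2):
            layers += [nn.Conv2d(ch, ch * 2, 4, stride=2, padding=1),
                       nn.InstanceNorm2d(ch * 2),
                       nn.LeakyReLU(0.2, True)]
            ch *= 2
        layers += [nn.Conv2d(ch, ch * 2, 4, stride=1, padding=1),
                   nn.InstanceNorm2d(ch * 2),
                   nn.LeakyReLU(0.2, True),
                   nn.Conv2d(ch * 2, 1, 4, stride=1, padding=1)]  # one score per patch
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)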
Step 3: Model Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# NOTE: HourGlass, ResnetBlock, ResnetSoftAdaLINBlock and LIN are helper modules
# defined elsewhere in the project (included in the full download at the end of this post).


class ResnetGenerator(nn.Module):
    def __init__(self, ngf=64, img_size=256, light=False):
        super(ResnetGenerator, self).__init__()
        self.light = light

        self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf),
                                        nn.ReLU(True))

        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-Sampling
        self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*2),
                                        nn.ReLU(True))

        self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*4),
                                        nn.ReLU(True))

        # Encoder Bottleneck
        self.EncodeBlock1 = ResnetBlock(ngf*4)
        self.EncodeBlock2 = ResnetBlock(ngf*4)
        self.EncodeBlock3 = ResnetBlock(ngf*4)
        self.EncodeBlock4 = ResnetBlock(ngf*4)

        # Class Activation Map
        self.gap_fc = nn.Linear(ngf*4, 1)
        self.gmp_fc = nn.Linear(ngf*4, 1)
        self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))
        else:
            self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))

        # Decoder Bottleneck
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)

        # Up-Sampling
        self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf*2),
                                      nn.ReLU(True))

        self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))

        self.HourGlass3 = HourGlass(ngf, ngf)
        self.HourGlass4 = HourGlass(ngf, ngf, False)

        self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.Tanh())

    def forward(self, x):
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)

        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        # Per-stage content features, later fed to the decoder's adaptive normalization
        x = self.EncodeBlock1(x)
        content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock2(x)
        content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock3(x)
        content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock4(x)
        content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)

        # Class Activation Map: weight feature maps by the FC layers of the
        # global-average and global-max pooled branches, then fuse them
        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        # Attention heatmap, returned for visualization
        heatmap = torch.sum(x, dim=1, keepdim=True)

        # Style features (source of gamma/beta) for the decoder blocks
        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.view(x_.shape[0], -1))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        x = self.DecodeBlock1(x, content_features4, style_features)
        x = self.DecodeBlock2(x, content_features3, style_features)
        x = self.DecodeBlock3(x, content_features2, style_features)
        x = self.DecodeBlock4(x, content_features1, style_features)

        x = self.UpBlock1(x)
        x = self.UpBlock2(x)

        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)

        return out, cam_logit, heatmap
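For reference, a minimal forward-pass sketch of the generator above. It assumes the helper modules the class references (HourGlass, ResnetBlock, ResnetSoftAdaLINBlock, LIN) are defined in the same file, as they are in the full project; the expected shapes follow from the listing (two stride-2 downsamplings and two 2x upsamplings around a 256x256 input).

if __name__ == '__main__':
    # Build the generator in light mode to keep the FC block small
    netG = ResnetGenerator(ngf=64, img_size=256, light=True)
    dummy = torch.randn(1, 3, 256, 256)   # one random RGB "photo"
    with torch.no_grad():
        out, cam_logit, heatmap = netG(dummy)
    print(out.shape)        # expected: torch.Size([1, 3, 256, 256])
    print(cam_logit.shape)  # expected: torch.Size([1, 2])
    print(heatmap.shape)    # expected: torch.Size([1, 1, 64, 64])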
Step 4: Running
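The project download contains its own training and test scripts, which are the real entry point for this step. As an illustration only, here is a hedged sketch of what single-image inference typically looks like with a generator of this kind; the file names, checkpoint path and pre/post-processing below are assumptions and may differ from the project's actual scripts.

import cv2
import numpy as np
import torch

def photo_to_cartoon(img_path='face.jpg', ckpt_path='genA2B.pt', size=256):
    # Hypothetical inference sketch: paths, checkpoint format and normalization
    # are assumptions for illustration, not the project's real test script.
    netG = ResnetGenerator(ngf=64, img_size=size, light=True)
    netG.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    netG.eval()

    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (size, size)).astype(np.float32) / 127.5 - 1.0  # scale to [-1, 1]
    x = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0)             # HWC -> NCHW

    with torch.no_grad():
        out, _, _ = netG(x)

    # Map the Tanh output back to [0, 255] and save
    cartoon = ((out[0].numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
    cv2.imwrite('cartoon.jpg', cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR))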
Step 5: Contents of the Whole Project
Code download link: CycleGAN-based deep-learning cartoon-avatar conversion system
If you run into problems, feel free to send a private message or leave a comment; every question will get an answer.