赞
踩
论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
论文地址:https://arxiv.org/abs/1703.10593
什么是CycleGAN:CycleGAN主要用于图像之间的转换,假设有两个不成对的图像X和Y,算法训练去学习一个“自动相互转换”,训练时不需要成对的配对样本,只需要源域和目标域的图像。训练后网络就能实现对图像源域到目标域的迁移。CycleGAN适用于非配对的图像到图像转换,解决了模型需要成对数据进行训练的困难。
与pix2pixGAN的区别:二者都可以做图像变换,pix2pix模型必须要求成对数据(paired data),而CycleGAN利用非成对数据也能进行训练(unpaired data)。
相关工作:
CycleGAN其实就是一个 A→B 的单向 GAN 加上一个 B→A 的单向 GAN。两个 GAN 共享两个生成器,然后各自带一个判别器,所以加起来总共有两个判别器和两个生成器。一个单向 GAN 有两个 loss, 故 CycleGAN 加起来总共有四个 loss。
循环一致损失:因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GenratorA-B生 成一张假图像,那么要能够使用另外一个生成器 GenratorB-A来努力恢复成原始图像。此过程必须满足循环一致性。
identity loss:可以理解为,生成器是负责域x到域y的图像生成, 如果输入域y的图片还是应该生成域y的图片。
- # 用狗的图像生成猫的图像
- import itertools
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchvision import transforms, datasets, utils
- import matplotlib.pyplot as plt
- import numpy as np
- import torch.optim as optim
- from torch.utils.data.dataset import Dataset
- from PIL import Image
- import tqdm
- import glob
-
- dogs_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Dog/*.jpg') #获取数据集中的.jpg图片
- cats_path = glob.glob('D:\cnn\All Classfication\AlexNet\data/train\Cat/*.jpg') #获取数据集中的.jpg图片
- # print(cats_path[:3])
- # print(dogs_path[:3])
- cats_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Cat/*.jpg') #获取数据集中的.jpg图片
- dogs_path_test = glob.glob('D:\cnn\All Classfication\AlexNet\data/val\Dog/*.jpg') #获取数据集中的.jpg图片
-
- transform = transforms.Compose([transforms.ToTensor(),
- transforms.Resize((256, 256)),
- transforms.Normalize(mean=0.5, std=0.5)]) #Normalize为转化到-1~1之间
-
- # 定义数据读取
- class SGANDataset(Dataset):
- def __init__(self, imgs_path): #初始化
- super(SGANDataset, self).__init__()
- self.imgs_path = imgs_path #定义属性
-
- def __len__(self):
- return len(self.imgs_path)
-
- def __getitem__(self, index): #对数据切片
- img_path = self.imgs_path[index]
-
- # 从文件中读取图像
- pil_img = Image.open(img_path)
- pil_img = transform(pil_img)
- return pil_img
-
- # 初始化训练集
- dog_dataset = SGANDataset(dogs_path) #创建dataset
- cat_dataset = SGANDataset(cats_path) #创建dataset
-
- # 初始化测试集
- dog_dataset_test = SGANDataset(dogs_path_test) #创建dataset
- cat_dataset_test = SGANDataset(cats_path_test) #创建dataset
-
- dog_dataloader = torch.utils.data.DataLoader(dog_dataset, batch_size=4, shuffle=True)
- cat_dataloader = torch.utils.data.DataLoader(cat_dataset, batch_size=4, shuffle=True)
-
- dog_dataloader_test = torch.utils.data.DataLoader(dog_dataset_test, batch_size=4)
- cat_dataloader_test = torch.utils.data.DataLoader(cat_dataset_test, batch_size=4)
-
- # cat_bath = next(iter(cat_dataloader)) #查看
- # dog_bath = next(iter(dog_dataloader)) #查看
- # print(dog_bath.shape) #torch.Size([4, 3, 256, 256])
- # print(cat_bath.shape) #torch.Size([4, 3, 256, 256])
-
- # 查看数据集
- # plt.figure(figsize=(8, 12))
- # for i, (dog, cat) in enumerate(zip(dog_bath[:3], cat_bath[:3])): #zip代表元组
- # # 因为dataset返回的数据是tensor,需要转为numpy格式,因为Normalize为转化到-1~1之间,所以加1再除以2将其转化到0~1之间
- # dog = (dog.permute(1, 2, 0).numpy() + 1) / 2
- # cat = (cat.permute(1, 2, 0).numpy() + 1) / 2
- # plt.subplot(3, 2, 2*i+1)
- # plt.title('dog')
- # plt.imshow(dog)
- # plt.subplot(3, 2, 2*i+2)
- # plt.title('cat')
- # plt.imshow(cat)
- # plt.show()
-
-
- #定义下采样模块
- class Downsample(nn.Module):
- def __init__(self, in_channels, out_channels):
- super(Downsample, self).__init__()
- self.conv_relu = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
- nn.LeakyReLU(inplace=True)
- )
- self.bn = nn.InstanceNorm2d(out_channels)
-
- def forward(self, x, is_bn=True): #is_bn用于确定是否使用bn层,默认为True
- x = self.conv_relu(x)
- if is_bn:
- x = self.bn(x)
- return x
-
- #定义上采样模块
- class Upsample(nn.Module):
- def __init__(self, in_channels, out_channels):
- super(Upsample, self).__init__()
- self.upconv_relu = nn.Sequential(
- nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
- nn.LeakyReLU(inplace=True)
- )
- self.bn = nn.InstanceNorm2d(out_channels)
-
- def forward(self, x, is_drop=False): #is_drop用于确定是否使用drop层,默认为False
- x = self.upconv_relu(x)
- x = self.bn(x)
- if is_drop:
- x = F.dropout2d(x)
- return x
-
- # 定义生成器,包含6个下采样层,6个上采样层
- class Generator(nn.Module):
- def __init__(self):
- super(Generator, self).__init__()
- self.down1 = Downsample(3, 64) #3,256,256 -- 64,128,128
- self.down2 = Downsample(64, 128) #64,128,128 -- 128,64,64
- self.down3 = Downsample(128, 256) #128,64,64 -- 256,32,32
- self.down4 = Downsample(256, 512) #256,32,32 -- 512,16,16
- self.down5 = Downsample(512, 512) #512,16,16 -- 512,8,8
- self.down6 = Downsample(512, 512) #512,8,8 -- 512,4,4
-
- self.up1 = Upsample(512, 512) #512,4,4 -- 512,8,8
- self.up2 = Upsample(1024, 512) #1024,8,8 -- 512,16,16
- self.up3 = Upsample(1024, 256) #1024,16,16 -- 256,32,32
- self.up4 = Upsample(512, 128) #512,32,32 -- 128,64,64
- self.up5 = Upsample(256, 64) #256,64,64 -- 64,128,128
- #128,128,128 -- 3,256,256
- self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
-
- def forward(self, x):
- x1 = self.down1(x)
- x2 = self.down2(x1)
- x3 = self.down3(x2)
- x4 = self.down4(x3)
- x5 = self.down5(x4)
- x6 = self.down6(x5)
-
- x6 = self.up1(x6, is_drop=True)
- x6 = torch.cat([x6, x5], dim=1)
-
- x6 = self.up2(x6, is_drop=True)
- x6 = torch.cat([x6, x4], dim=1)
-
- x6 = self.up3(x6, is_drop=True)
- x6 = torch.cat([x6, x3], dim=1)
-
- x6 = self.up4(x6)
- x6 = torch.cat([x6, x2], dim=1)
-
- x6 = self.up5(x6)
- x6 = torch.cat([x6, x1], dim=1)
-
- x6 = torch.tanh(self.last(x6))
-
- return x6
-
- # 定义判别器
- class Discriminator(nn.Module):
- def __init__(self):
- super(Discriminator, self).__init__()
- self.down1 = Downsample(3, 64)
- self.down2 = Downsample(64, 128)
- self.last = nn.Conv2d(128, 1, 3)
-
- def forward(self, img):
- x = self.down1(img)
- x = self.down2(x)
- x =torch.sigmoid(self.last(x))
- return x
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # 初始化两个生成器
- gen_AB = Generator().to(device)
- gen_BA = Generator().to(device)
-
- # 初始化两个判别器
- dis_A = Discriminator().to(device)
- dis_B = Discriminator().to(device)
-
- # 损失函数 1.gan loss 2.cycle consistance loss 3.identity loss
- bce_loss = torch.nn.BCELoss()
- l1_loss = torch.nn.L1Loss()
-
- # 初始化优化器
- # 对两个生成器同时进行优化, 使用itertools.chain对二者同时进行迭代
- gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()), lr=2e-4, betas=(0.5, 0.999))
-
- # 对两个判别器分别进行优化
- dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
- dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))
-
- # 绘图函数,将每一个epoch中生成器生成的图片绘制
- def gen_img_plot(model, epoch, test_input): # model为gen_AB/gen_BA,test_input
- generate = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
- test_input = test_input.permute(0, 2, 3, 1).cpu().numpy() #1,3,256,256 -- 1,256,256,3
- plt.figure(figsize=(10, 6))
- display_list = [test_input[0], generate[0]]
- title = ['Input image', 'Generate image']
- for i in range(2):
- plt.subplot(1, 2, i + 1)
- plt.title(title[i])
- plt.imshow((display_list[i]+1)/2) #从-1~1 --> 0~1
- plt.axis('off')
- plt.savefig('./image/image_at_{}.png'.format(epoch))
-
- test_batch = next(iter(dog_dataloader_test)) #batch_size,3,256,256
- # 测试输入:选取test_batch中的第一张图片,并添加一个batch_size维度 3,256,256--1,3,256,256
- test_input = torch.unsqueeze(test_batch[0], 0).to(device)
-
- # cycleGAN训练
- D_loss = []
- G_loss = []
- epochs = 50
- for epoch in range(epochs):
- d_epoch_loss = 0
- g_epoch_loss = 0
- for step, (real_A, real_B) in enumerate(zip(dog_dataloader, cat_dataloader)): #取出真实的狗,猫图片
- real_A = real_A.to(device)
- real_B = real_B.to(device)
- #--------------------begin--------------------#
- # 生成器训练
- gen_optimizer.zero_grad() #训练之前梯度清0
- # identity loss
- same_B = gen_AB(real_B) #真实的B经过生成器gen_AB还是要得到真实的B
- identity_B_loss = l1_loss(same_B, real_B)
- same_A = gen_AB(real_A) #真实的A经过生成器gen_BA还是要得到真实的A
- identity_A_loss = l1_loss(same_A, real_A)
- # 对抗损失 gan loss
- fake_B = gen_AB(real_A) #真实A通过生成器生成了B,此时生成器希望判别器将其判别为真
- D_pred_fake_B = dis_B(fake_B)
- gen_loss_AB = bce_loss(D_pred_fake_B, torch.ones_like(D_pred_fake_B, device=device))
- fake_A = gen_BA(real_B) #真实B通过生成器生成了A,此时生成器希望判别器将其判别为真
- D_pred_fake_A = dis_A(fake_A)
- gen_loss_BA = bce_loss(D_pred_fake_A, torch.ones_like(D_pred_fake_A, device=device))
- # 循环一致损失
- recovered_A = gen_BA(fake_B)
- cycle_loss_ABA = l1_loss(recovered_A, real_A)
-
- recovered_B = gen_AB(fake_A)
- cycle_loss_BAB = l1_loss(recovered_B, real_B)
-
- # 生成器总的损失
- g_loss = identity_A_loss + identity_B_loss + gen_loss_AB + gen_loss_BA +cycle_loss_ABA + cycle_loss_BAB
-
- g_loss.backward()
- gen_optimizer.step()
- # --------------------end--------------------#
-
- # --------------------begin--------------------#
- # 判别器训练
- # dis_A训练
- dis_A_optimizer.zero_grad()
- dis_A_real_output = dis_A(real_A) #输入为真,期望判定为真
- dis_A_real_loss = bce_loss(dis_A_real_output, torch.ones_like(dis_A_real_output, device=device))
-
- dis_A_fake_output = dis_A(fake_A.detach()) #输入为假,期望判定为假,梯度截断
- dis_A_fake_loss = bce_loss(dis_A_fake_output, torch.zeros_like(dis_A_fake_output, device=device))
-
- dis_A_loss = dis_A_real_loss + dis_A_fake_loss #生成器A的总损失
- dis_A_loss.backward()
- dis_A_optimizer.step()
-
- # dis_B训练
- dis_B_optimizer.zero_grad()
- dis_B_real_output = dis_B(real_B) #输入为真,期望判定为真
- dis_B_real_loss = bce_loss(dis_B_real_output, torch.ones_like(dis_B_real_output, device=device))
-
- dis_B_fake_output = dis_B(fake_B.detach()) #输入为假,期望判定为假,梯度截断
- dis_B_fake_loss = bce_loss(dis_B_fake_output, torch.zeros_like(dis_B_fake_output, device=device))
-
- dis_B_loss = dis_B_real_loss + dis_B_fake_loss #生成器B的总损失
- dis_B_loss.backward()
- dis_B_optimizer.step()
- # --------------------end--------------------#
-
- with torch.no_grad():
- g_epoch_loss += g_loss.item() #将每一个批次的loss累加
- d_epoch_loss += (dis_A_loss + dis_B_loss).item() # 将每一个批次的loss累加
-
- with torch.no_grad():
- g_epoch_loss /= (step + 1) #求得每一轮的平均loss
- d_epoch_loss /= (step + 1) #求得每一轮的平均loss
- D_loss.append(d_epoch_loss)
- G_loss.append(g_epoch_loss)
- print('epoch:', epoch, 'g_epoch_loss:', g_epoch_loss, 'd_epoch_loss:', d_epoch_loss)
- gen_img_plot(gen_AB, epoch, test_input)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。