This series comes from the 365-Day Deep Learning Training Camp. Original author: K同学.
CycleGAN is an unsupervised image-to-image translation model. One of its main application areas is domain transfer: for example, turning an ordinary landscape photo into a Van Gogh-style painting, converting game footage into real-world imagery, or transforming a normally colored horse into a zebra.
The core problem CycleGAN solves is translating images from one domain to another without paired training data. The translation is bidirectional: the model can map from one domain to the other and back again.
Generators: CycleGAN contains two generators, each translating images from one domain to the other. In a horse-to-zebra task, for example, one generator converts horse images into zebra images while the other converts zebra images into horse images. Each generator learns a mapping function from its input domain to the target domain.
Discriminators: CycleGAN contains two discriminators, each of which distinguishes generated images from real ones. One discriminator tells real images in domain A apart from images generated in domain A's style; the other does the same for domain B. The discriminators push the generators toward producing more realistic images.
Loss function: CycleGAN's loss consists of three parts: LossGAN (the adversarial loss, which drives the generator and discriminator to improve against each other so that the generator produces increasingly realistic images), LossCycle (the cycle-consistency loss, which ensures the generator's output differs from its input only in style, not in content), and LossIdentity (a mapping loss: feed a real image from the target domain, say domain A, into the generator mapping into A, and check whether it comes back essentially unchanged).
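For reference, the full generator objective can be written out as follows (a sketch using the notation of the CycleGAN paper; the weights correspond to the lambda_cyc and lambda_id options in the training script below, which additionally averages each pair of terms):

$$\mathcal{L}_G = \mathcal{L}_{GAN} + \lambda_{cyc}\,\mathcal{L}_{cycle} + \lambda_{id}\,\mathcal{L}_{identity}$$

$$\mathcal{L}_{cycle} = \mathbb{E}_{x \sim A}\big[\lVert G_{BA}(G_{AB}(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim B}\big[\lVert G_{AB}(G_{BA}(y)) - y \rVert_1\big]$$

$$\mathcal{L}_{identity} = \mathbb{E}_{y \sim B}\big[\lVert G_{AB}(y) - y \rVert_1\big] + \mathbb{E}_{x \sim A}\big[\lVert G_{BA}(x) - x \rVert_1\big]$$

Here $G_{AB}$ maps domain A to domain B and $G_{BA}$ maps B back to A. The adversarial term $\mathcal{L}_{GAN}$ uses a least-squares (MSE) objective, matching criterion_GAN in the code below.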
Below is a custom PyTorch Dataset class that loads the image dataset and applies preprocessing.
import glob
import random
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image


class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(root, "%sA" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%sB" % mode) + "/*.*"))

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
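As a quick sanity check, the dataset can be exercised on its own (a minimal sketch; the ./data/monet2photo path and the trainA/trainB folder layout are assumptions matching the training script later in this post):

# Minimal usage sketch for ImageDataset (paths are assumptions).
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transforms_ = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
dataset = ImageDataset("./data/monet2photo/", transforms_=transforms_, unaligned=True)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch["A"].shape, batch["B"].shape)  # torch.Size([1, 3, 256, 256]) each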
The function below is applied to every layer of a model to initialize the network's weights: convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02, and BatchNorm2d weights from a normal distribution with mean 1 and standard deviation 0.02, with biases set to zero.
import torch.nn as nn
import torch.nn.functional as F
import torch
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
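For example, Module.apply() visits every submodule recursively, so a single call initializes all matching layers (a minimal sketch; note that the models below use nn.InstanceNorm2d, which is non-affine by default and so has no weights for the BatchNorm2d branch to touch):

# Minimal sketch: .apply() recursively calls weights_init_normal on every submodule.
net = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)
net.apply(weights_init_normal)
print(net[0].weight.std().item())  # close to 0.02 after initialization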
This defines a residual block. Each block contains two convolutional layers: reflection padding is applied first, followed by convolution, instance normalization, and (after the first convolution) a ReLU activation. A skip connection then adds the block's input to its output to produce the final result.
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)
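Since both convolutions use 3x3 kernels preceded by reflection padding of 1, the block preserves both the channel count and the spatial size, which is what makes the skip connection x + self.block(x) well-defined. A quick check (a sketch, not part of the original script):

import torch

block = ResidualBlock(in_features=64)
x = torch.randn(1, 64, 64, 64)
print(block(x).shape)  # torch.Size([1, 64, 64, 64]), same as the input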
This defines a ResNet-based generator. It generates images by stacking residual blocks, convolutional layers, and upsampling layers: an initial convolution block, followed by downsampling, the residual blocks, upsampling, and a final output layer that produces the target image.
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
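With the defaults used later (256x256 RGB input, 9 residual blocks), the two stride-2 convolutions take the resolution from 256 down to 64, the residual blocks operate at 64x64 with 256 channels, and the two upsampling stages restore 256x256. A quick shape check (a sketch, not part of the training script):

import torch

G = GeneratorResNet(input_shape=(3, 256, 256), num_residual_blocks=9)
x = torch.randn(1, 3, 256, 256)
y = G(x)
print(y.shape)                         # torch.Size([1, 3, 256, 256])
print(y.min().item(), y.max().item())  # stays within [-1, 1] due to the final Tanh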
This defines the discriminator. It consists of a series of convolutional layers that progressively shrink the feature map, and it finally outputs a single-channel patch map (a PatchGAN): each element of the map estimates the probability that the corresponding patch of the input image is real.
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        return self.model(img)
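Each of the four stride-2 blocks halves the resolution, so a 256x256 input produces a 16x16 single-channel patch map rather than one scalar; the valid/fake target tensors in the training loop are built from output_shape accordingly. A quick check (a sketch):

import torch

D = Discriminator(input_shape=(3, 256, 256))
img = torch.randn(1, 3, 256, 256)
print(D(img).shape)    # torch.Size([1, 1, 16, 16])
print(D.output_shape)  # (1, 16, 16), used to build the adversarial ground truths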
Utility classes: ReplayBuffer maintains a buffer of previously generated images; during training, the discriminators are updated on a mix of current and historical fake images drawn from this buffer, which helps stabilize training. LambdaLR adjusts the learning rate during training according to a specified linear decay rule.
import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
import numpy as np
from torchvision.utils import save_image


class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))


class LambdaLR:
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
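With the defaults used below (n_epochs=200, offset=0, decay_start_epoch=100), the learning-rate multiplier stays at 1.0 for the first 100 epochs and then decays linearly toward 0 (a quick check, a sketch):

sched = LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100)
for epoch in (0, 50, 100, 150, 199):
    print(epoch, sched.step(epoch))
# 0   1.0
# 50  1.0
# 100 1.0
# 150 0.5
# 199 0.01 (approximately)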
Set the training parameters, including the number of epochs, dataset name, batch size, and learning rate. Then define the models and optimizers: the two generators, the two discriminators, the loss functions, and the optimizers. Finally, load the dataset, apply preprocessing, and set up the training and test data loaders.
import argparse
import itertools

from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from models import *
from datasets import *
from utils import *
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="monet2photo", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving generator outputs")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")
parser.add_argument("--lambda_cyc", type=float, default=10.0, help="cycle loss weight")
parser.add_argument("--lambda_id", type=float, default=5.0, help="identity loss weight")
opt = parser.parse_args()
print(opt)

# Create sample and checkpoint directories
os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

cuda = torch.cuda.is_available()

input_shape = (opt.channels, opt.img_height, opt.img_width)

# Initialize the generators and discriminators
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Image transformations
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Training data loader
dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("./data/%s/" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)


def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (opt.dataset_name, batches_done), normalize=False)


# ----------
#  Training
# ----------

if __name__ == '__main__':
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                )
            )

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints (into the same directory created and loaded from above)
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
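Once checkpoints have been saved, a trained generator can also be used on its own for inference (a minimal sketch; the checkpoint path, epoch number, and input image path are hypothetical and should be adjusted to your own run):

import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
from models import GeneratorResNet

# Hypothetical paths: adjust to your own checkpoint and image.
G_AB = GeneratorResNet((3, 256, 256), num_residual_blocks=9)
G_AB.load_state_dict(torch.load("saved_models/monet2photo/G_AB_199.pth", map_location="cpu"))
G_AB.eval()

preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = preprocess(Image.open("input.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    fake = G_AB(img)
save_image(fake, "translated.png", normalize=True)  # rescale from [-1, 1] for saving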
The hardware used for this experiment was limited and lacked sufficient compute; readers should run the training themselves on a GPU machine.
CycleGAN learns the mapping between two different image domains, making image translation between them possible. Through training, the model learns how to convert an image from one domain to the other without paired training data, which lowers the cost of data collection and annotation. The loss functions it introduces, each attacking the problem from a different angle, are also well worth studying.