赞
踩
数据集:Index of /~taesung_park/CycleGAN/datasets
CycleGAN一共有4个模型,即2个生成器,2个判别器,训练完成后只使用一个生成模型。
if opt.epoch != 0:
# Load pretrained models
G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
loss 包括3个部分:生成器loss、循环loss、一致性loss.
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
其中apply是内置函数可以调用括号中自己实现的算子。
# Initialize weights
G_AB.apply(weights_init_normal)
G_BA.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)
被调用的初始化函数
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
if hasattr(m, "bias") and m.bias is not None:
torch.nn.init.constant_(m.bias.data, 0.0)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
# Optimizers optimizer_G = torch.optim.Adam( itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2) ) optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) # Learning rate update schedulers lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR( optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step ) lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR( optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step ) lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR( optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step )
# Image transformations
transforms_ = [
transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
transforms.RandomCrop((opt.img_height, opt.img_width)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Training data loader
dataloader = DataLoader(
ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
batch_size=5,
shuffle=True,
num_workers=1,
)
CycleGAN也不是没有问题。CycleGAN: a Master of Steganography (隐写术) [Casey Chu, et al., NIPS workshop, 2017 ]这篇论文就指出,CycleGAN存在一种情况,是它能学会把输入的某些部分藏起来,然后在输出的时候再还原回来。比如下面这张图:
可以看到,在经过第一个generator的时候,屋顶的黑色斑点不见了,但是在经过第二个generator之后,屋顶的黑色斑点又被还原回来了。这其实意味着,第一个generator并没有遗失掉屋顶有黑色斑点这一讯息,它只是用一种人眼看不出的方式将这一讯息隐藏在输出的图片中(例如黑点数值改得非常小),而第二个generator在训练过程中也学习到了提取这种隐藏讯息的方式。那generator隐藏讯息的目的是什么呢?其实很简单,隐藏掉一些破坏风格相似性的“坏点”会更容易获得discriminator的高分,而从discriminator那拿高分是generator实际上的唯一目的。
参考:编译原理语义分析代码_Cycle GAN原理分析与代码解读
测试了马到斑马的cyclegan
训练了一下cyclegan,并启动了visdom。
安装visdom
pip install visdom
启动visdom
python -m visdom.server
服务器终端给出的链接地址
http://localhost:8097
在本地网页输入地址
# http://服务器IP:8097
http://服务器IP:8097
界面窗口如下:
在trainB下有7张单通道图像,代码不兼容,将7张图移出trainB文件夹。
Epoch 001/200 [0165/1327] -- loss_G: 10.3852 | loss_G_identity: 3.0593 | loss_G_GAN: 0.8252 | loss_G_cycle: 6.5006 | loss_D: 0.5533 -- ETA: 10 days, 16:51:07.615357
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。