当前位置:   article > 正文

CycleGAN_cyclegan数据集

cyclegan数据集

1、数据集

在这里插入图片描述

数据集:Index of /~taesung_park/CycleGAN/datasets

2、模型

CycleGAN一共有4个模型,即2个生成器,2个判别器,训练完成后只使用一个生成模型。

if opt.epoch != 0:
    # Load pretrained models
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (opt.dataset_name, opt.epoch)))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

3、loss

loss 包括3个部分:生成器loss、循环loss、一致性loss.

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
  • 1
  • 2
  • 3
  • 4

4、初始化权重

其中apply是内置函数可以调用括号中自己实现的算子。

    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
  • 1
  • 2
  • 3
  • 4
  • 5

被调用的初始化函数

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

5、优化器与学习率更新策略

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

6、数据预处理

# Image transformations
transforms_ = [
    transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
    transforms.RandomCrop((opt.img_height, opt.img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

7、数据加载 dataloader

# Training data loader
dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
# Test data loader
val_dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

总结

CycleGAN也不是没有问题。CycleGAN: a Master of Steganography (隐写术) [Casey Chu, et al., NIPS workshop, 2017 ]这篇论文就指出,CycleGAN存在一种情况,是它能学会把输入的某些部分藏起来,然后在输出的时候再还原回来。比如下面这张图:

在这里插入图片描述

可以看到,在经过第一个generator的时候,屋顶的黑色斑点不见了,但是在经过第二个generator之后,屋顶的黑色斑点又被还原回来了。这其实意味着,第一个generator并没有遗失掉屋顶有黑色斑点这一讯息,它只是用一种人眼看不出的方式将这一讯息隐藏在输出的图片中(例如黑点数值改得非常小),而第二个generator在训练过程中也学习到了提取这种隐藏讯息的方式。那generator隐藏讯息的目的是什么呢?其实很简单,隐藏掉一些破坏风格相似性的“坏点”会更容易获得discriminator的高分,而从discriminator那拿高分是generator实际上的唯一目的。

参考:编译原理语义分析代码_Cycle GAN原理分析与代码解读

8、数据显示

测试了马到斑马的cyclegan
训练了一下cyclegan,并启动了visdom。

安装visdom

pip install visdom
  • 1

启动visdom

python -m visdom.server
  • 1

服务器终端给出的链接地址

http://localhost:8097
  • 1

在本地网页输入地址

# http://服务器IP:8097
http://服务器IP:8097
  • 1
  • 2

界面窗口如下:
在这里插入图片描述
在trainB下有7张单通道图像,代码不兼容,将7张图移出trainB文件夹。

Epoch 001/200 [0165/1327] -- loss_G: 10.3852 | loss_G_identity: 3.0593 | loss_G_GAN: 0.8252 | loss_G_cycle: 6.5006 | loss_D: 0.5533 -- ETA: 10 days, 16:51:07.615357
  • 1
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号