
MindSpore (昇思) 25-Day Study Check-in Camp, Day 15 | Generating Anime Avatars with DCGAN


DCGAN Principles

DCGAN (Deep Convolutional Generative Adversarial Networks) is a direct extension of GAN. The difference is that DCGAN uses convolutional and transposed convolutional layers in the discriminator and generator, respectively.

Data Preparation and Processing

```python
from download import download

url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./faces", kind="zip", replace=True)
```

Data Processing

Define some inputs for the training process:

```python
batch_size = 128   # batch size
image_size = 64    # spatial size of the training images
nc = 3             # number of color channels
nz = 100           # length of the latent vector
ngf = 64           # size of feature maps in the generator
ndf = 64           # size of feature maps in the discriminator
num_epochs = 3     # number of training epochs
lr = 0.0002        # learning rate
beta1 = 0.5        # beta1 hyperparameter for the Adam optimizer
```

Data processing and augmentation:

```python
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

def create_dataset_imagenet(dataset_path):
    """Load and preprocess the dataset."""
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=4,
                                    shuffle=True,
                                    decode=True)
    # data augmentation operations
    transforms = [
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW(),
        lambda x: ((x / 255).astype("float32"))
    ]
    # apply the transforms to the image column
    dataset = dataset.project('image')
    dataset = dataset.map(transforms, 'image')
    # batch the data
    dataset = dataset.batch(batch_size)
    return dataset

dataset = create_dataset_imagenet('./faces')
```
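As a hedged illustration of what the last two transforms in the pipeline do, here is a plain-NumPy sketch using a randomly generated stand-in image rather than a real dataset sample:

```python
import numpy as np

# hypothetical decoded image: height x width x channels, uint8 in [0, 255]
img_hwc = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

# equivalent of vision.HWC2CHW(): move channels first
img_chw = np.transpose(img_hwc, (2, 0, 1))

# equivalent of the lambda: scale to [0, 1] as float32
img_float = (img_chw / 255).astype("float32")

print(img_float.shape, img_float.dtype)  # (3, 64, 64) float32
```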

Create an iterator over the dataset and visualize part of the training data:

```python
import matplotlib.pyplot as plt

def plot_data(data):
    # visualize part of the training data
    plt.figure(figsize=(10, 3), dpi=140)
    for i, image in enumerate(data[0][:30], 1):
        plt.subplot(3, 10, i)
        plt.axis("off")
        plt.imshow(image.transpose(1, 2, 0))
    plt.show()

sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)
```

Building the Networks

Generator

The generator maps the latent vector to an image through a series of Conv2dTranspose (transposed convolution) layers, each paired with a BatchNorm2d layer and a ReLU activation. The output is passed through a tanh function, which maps it into the range [-1, 1].

The DCGAN generator structure is as follows:

[Figure: DCGAN generator architecture]

Implementation:

```python
import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal

weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

class Generator(nn.Cell):
    """DCGAN generator network"""
    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell(
            nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.Tanh()
        )

    def construct(self, x):
        return self.generator(x)

generator = Generator()
```
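As a sanity check (not part of the original code), the spatial size at each generator stage can be verified in pure Python with the standard transposed-convolution output formula, out = (in - 1) * stride - 2 * pad + kernel:

```python
def deconv_out(size, kernel, stride, pad):
    # output size of a transposed convolution (no output padding)
    return (size - 1) * stride - 2 * pad + kernel

size = 1  # the latent vector is treated as a 1x1 "image" with nz channels
sizes = [size]
# (kernel, stride, pad) for each Conv2dTranspose layer above;
# 'valid' corresponds to pad=0, 'pad' with padding=1 to pad=1
for kernel, stride, pad in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = deconv_out(size, kernel, stride, pad)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

So the latent vector is progressively upsampled from 1x1 to the 64x64 output image.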

Discriminator

Downsampling with strided convolutions rather than pooling is a good approach, because it lets the network learn its own pooling function.

The implementation is as follows:

```python
class Discriminator(nn.Cell):
    """DCGAN discriminator network"""
    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell(
            nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ndf * 2, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ndf * 4, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ndf * 8, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
        )
        self.adv_layer = nn.Sigmoid()

    def construct(self, x):
        out = self.discriminator(x)
        out = out.reshape(out.shape[0], -1)
        return self.adv_layer(out)

discriminator = Discriminator()
```

(Note: the BatchNorm2d layers take `ndf * 2`, `ndf * 4`, and `ndf * 8` channels to match the preceding convolution outputs; the original listing wrote `ngf` here, which only worked because `ngf == ndf`.)
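Symmetrically to the generator, the discriminator's spatial sizes can be checked (a small sketch, not part of the original code) with the convolution output formula, out = floor((in + 2 * pad - kernel) / stride) + 1:

```python
def conv_out(size, kernel, stride, pad):
    # output size of a standard strided convolution
    return (size + 2 * pad - kernel) // stride + 1

size = 64
sizes = [size]
# four stride-2 'pad' convolutions, then one 'valid' stride-1 convolution
for kernel, stride, pad in [(4, 2, 1)] * 4 + [(4, 1, 0)]:
    size = conv_out(size, kernel, stride, pad)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

The final 1x1 map with a single channel is why `construct` can flatten the output to one logit per image before the Sigmoid.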

Model Training

Loss Function

Define the binary cross-entropy loss function, BCELoss:

```python
# define the loss function
adversarial_loss = nn.BCELoss(reduction='mean')
```
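For reference, this is what BCELoss with mean reduction computes; a plain-NumPy sketch with made-up example values, not the MindSpore API:

```python
import numpy as np

def bce(pred, target):
    # BCE(p, y) = -mean(y * log(p) + (1 - y) * log(1 - p))
    return float(-np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred)))

pred = np.array([0.9, 0.2])    # hypothetical discriminator outputs
target = np.array([1.0, 0.0])  # labels: real, fake
print(round(bce(pred, target), 4))  # 0.1643
```

Confident predictions with the correct label drive the loss toward 0; confident wrong predictions blow it up.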

Optimizers

```python
# set up optimizers for the generator and the discriminator
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')
```

Training the Model

Training the discriminator

The goal of training the discriminator is to maximize the probability of correctly classifying an image as real or fake. The discriminator is updated by ascending its stochastic gradient, so we maximize log D(x) + log(1 - D(G(z))).

Training the generator

The generator needs to minimize log(1 - D(G(z))). In the implementation below, this is done by labeling generated images as real and minimizing the BCE loss, the commonly used non-saturating variant that maximizes log D(G(z)) instead and gives stronger gradients early in training.
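As a small numeric check (plain NumPy, not part of the original tutorial): feeding all-ones targets into binary cross-entropy reduces it to -mean(log D(G(z))), which is exactly what the generator loss in the training code below minimizes.

```python
import numpy as np

d_of_gz = np.array([0.3, 0.7])  # hypothetical discriminator outputs on fake images
valid = np.ones_like(d_of_gz)   # fakes labeled as "real"

# BCE with all-ones targets: the (1 - y) * log(1 - p) term vanishes
bce_valid = float(-np.mean(valid * np.log(d_of_gz) +
                           (1 - valid) * np.log(1 - d_of_gz)))

print(np.isclose(bce_valid, -np.mean(np.log(d_of_gz))))  # True
```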

Forward logic for training:

```python
def generator_forward(real_imgs, valid):
    # sample noise as the generator's input
    z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
    # generate a batch of images
    gen_imgs = generator(z)
    # the loss measures the generator's ability to fool the discriminator
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    return g_loss, gen_imgs

def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    # measures the discriminator's ability to tell real samples from generated ones
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
    d_loss = (real_loss + fake_loss) / 2
    return d_loss

grad_generator_fn = ms.value_and_grad(generator_forward, None,
                                      optimizer_G.parameters,
                                      has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
                                          optimizer_D.parameters)

@ms.jit
def train_step(imgs):
    valid = ops.ones((imgs.shape[0], 1), ms.float32)
    fake = ops.zeros((imgs.shape[0], 1), ms.float32)

    (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
    optimizer_G(g_grads)
    d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
    optimizer_D(d_grads)

    return g_loss, d_loss, gen_imgs
```

(Note: the original listing referenced `mindspore.float32` before `mindspore` was imported; using the `ms` alias imported above fixes this.)

During training, collect the generator's and discriminator's losses each iteration, and print a training record every 100 iterations:

```python
import mindspore

G_losses = []
D_losses = []
image_list = []

total = dataset.get_dataset_size()
for epoch in range(num_epochs):
    generator.set_train()
    discriminator.set_train()
    # read in data for each epoch
    for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):
        g_loss, d_loss, gen_imgs = train_step(imgs)
        if i % 100 == 0 or i == total - 1:
            # print a training record
            print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (
                epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))
        D_losses.append(d_loss.asnumpy())
        G_losses.append(g_loss.asnumpy())

    # at the end of each epoch, generate a batch of images with the generator
    generator.set_train(False)
    fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
    img = generator(fixed_noise)
    image_list.append(img.transpose(0, 2, 3, 1).asnumpy())

    # save the network parameters as ckpt files
    mindspore.save_checkpoint(generator, "./generator.ckpt")
    mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")
```

Displaying the Results

Plot the generator and discriminator losses over the training iterations:

```python
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
```

Display the generated images:

```python
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def showGif(image_list):
    show_list = []
    fig = plt.figure(figsize=(8, 3), dpi=120)
    for epoch in range(len(image_list)):
        images = []
        for i in range(3):
            row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)
            images.append(row)
        img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
        plt.axis("off")
        show_list.append([plt.imshow(img)])

    ani = animation.ArtistAnimation(fig, show_list, interval=1000,
                                    repeat_delay=1000, blit=True)
    ani.save('./dcgan.gif', writer='pillow', fps=1)

showGif(image_list)
```

[Figure: generated images over the training epochs (dcgan.gif)]

As training progresses, the quality of the generated images visibly improves.

Load the saved generator parameter file and use it to generate images:

```python
# load the saved parameters from file into the network
mindspore.load_checkpoint("./generator.ckpt", generator)

fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()

fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):
    images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()
```

Summary

DCGAN adds convolutional and transposed convolutional layers to the discriminator and generator, and as the number of training iterations increases, the quality of the generated images improves.
