当前位置:   article > 正文

从零使用Python 实现对抗神经网络GAN

python 对抗训练

你好,我是郭震

这篇从零使用Python,实现生成对抗网络(GAN)的基本版本。

GAN使用两套网络,分别是判别器(D)网络和生成器(G)网络,最重要的是弄清楚每套网络的输入和输出分别是什么,两套网络如何结合在一起,及优化的目标即cost function如何定义。

通俗来讲,两套网络结合的方法,就是G会从D的判分中不断提升生成能力,要知道G最开始的输入全部是噪点,这个思想也是文生图文生视频的基石

下面这段代码展示了使用PyTorch框架进行生成对抗网络(GAN)训练的基本流程。

下面这些解释非常重要:

对于判别器网络而言,它的目标是最大化表达式 log(D(x)) + log(1 - D(G(z))),其中:

  • D(x) 是判别器网络对真实图像 x 的输出,这个值代表判别器认为图像是真实的概率。

  • D(G(z)) 是判别器网络对生成图像 G(z) 的输出,这个值代表判别器认为通过生成器从噪声 z 生成的图像是真实的概率。

  • log(D(x)) 的目标是使得判别器能够尽可能地将真实图像分类为真实(即,使 D(x) 接近于1)。

  • log(1 - D(G(z))) 的目标是使得判别器能够将生成的图像分类为假(即,使 D(G(z)) 接近于0)。

  1. # GAN 训练的基本代码
  2. for epoch in range(num_epochs):
  3.     for i, data in enumerate(dataloader, 0):
  4.         # 更新判别器网络:maximize log(D(x)) + log(1 - D(G(z)))
  5.         
  6.         # 在真实图像上训练
  7.         netD.zero_grad()
  8.         real_cpu = data[0].to(device)
  9.         batch_size = real_cpu.size(0)
  10.         label = torch.full((batch_size,), 1, dtype=torch.float, device=device)
  11.         output = netD(real_cpu).view(-1)
  12.         errD_real = criterion(output, label)
  13.         errD_real.backward()
  14.         D_x = output.mean().item()
  15.         # 在假图像上训练
  16.         noise = torch.randn(batch_size, nz, 11, device=device)
  17.         fake = netG(noise)
  18.         label.fill_(0)
  19.         output = netD(fake.detach()).view(-1)
  20.         errD_fake = criterion(output, label)
  21.         errD_fake.backward()
  22.         D_G_z1 = output.mean().item()
  23.         errD = errD_real + errD_fake
  24.         optimizerD.step()
  25.         # 更新生成器网络:maximize log(D(G(z)))
  26.         netG.zero_grad()
  27.         label.fill_(1)  # 假图像的标签对于生成器来说是真的
  28.         output = netD(fake).view(-1)
  29.         errG = criterion(output, label)
  30.         errG.backward()
  31.         D_G_z2 = output.mean().item()
  32.         optimizerG.step()

在训练过程中,这一目标通过以下步骤来实现:

  1. 对于真实图像

  • 判别器 D 接收一批真实图像 x

  • 计算 D(x),即这些真实图像被识别为真实的概率。

  • 使用 log(D(x)) 计算损失。这个损失会根据真实图像被正确识别的程度(即,D(x) 应该接近于1)来调整。

对于生成图像

  • 生成器 G 从随机噪声 z 生成一批假图像。

  • 判别器 D 接收这些生成的图像,并计算 D(G(z)),即这些假图像被识别为真实的概率。

  • 使用 log(1 - D(G(z))) 计算损失。这个损失会根据假图像被正确识别的程度(即,D(G(z)) 应该接近于0)来调整。

实现细节

  • 在PyTorch中,损失函数通常是要最小化的。因此,虽然理论目标是最大化 log(D(x)) + log(1 - D(G(z))),实际上我们通过最小化 -log(D(x)) - log(1 - D(G(z))) 来实现这一目标。

  • 使用二元交叉熵损失(Binary Cross-Entropy, BCE)来实现这一目标,因为它直接提供了所需的 -log(x) 和 -log(1-x) 形式的损失。

这个基本训练循环是理解和实现GANs的关键,而且也是后续进行各种变体和改进的基础。

弄清楚GAN的训练过程后,其他代码就比较容易理解。

导入所需的库

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torchvision import datasets, transforms
  5. from torch.utils.data import DataLoader
  6. import torchvision.utils as vutils

定义生成器(Generator)

  1. class Generator(nn.Module):
  2.     def __init__(self, nz, ngf, nc):
  3.         super(Generator, self).__init__()
  4.         self.main = nn.Sequential(
  5.             # 输入是 Z, 对此进行全连接
  6.             nn.ConvTranspose2d(nz, ngf * 8410, bias=False),
  7.             nn.BatchNorm2d(ngf * 8),
  8.             nn.ReLU(True),
  9.             # 上一步的输出形状: (ngf*8) x 4 x 4
  10.             nn.ConvTranspose2d(ngf * 8, ngf * 4421, bias=False),
  11.             nn.BatchNorm2d(ngf * 4),
  12.             nn.ReLU(True),
  13.             # 上一步的输出形状: (ngf*4) x 8 x 8
  14.             nn.ConvTranspose2d( ngf * 4, ngf * 2421, bias=False),
  15.             nn.BatchNorm2d(ngf * 2),
  16.             nn.ReLU(True),
  17.             # 上一步的输出形状: (ngf*2) x 16 x 16
  18.             nn.ConvTranspose2d( ngf * 2, ngf, 421, bias=False),
  19.             nn.BatchNorm2d(ngf),
  20.             nn.ReLU(True),
  21.             # 上一步的输出形状: (ngf) x 32 x 32
  22.             nn.ConvTranspose2d( ngf, nc, 421, bias=False),
  23.             nn.Tanh()
  24.             # 输出形状: (nc) x 64 x 64
  25.         )
  26.     def forward(self, input):
  27.         return self.main(input)

定义判别器(Discriminator)

  1. class Discriminator(nn.Module):
  2.     def __init__(self, nc, ndf):
  3.         super(Discriminator, self).__init__()
  4.         self.main = nn.Sequential(
  5.             # 输入形状: (nc) x 64 x 64
  6.             nn.Conv2d(nc, ndf, 421, bias=False),
  7.             nn.LeakyReLU(0.2, inplace=True),
  8.             # 输出形状: (ndf) x 32 x 32
  9.             nn.Conv2d(ndf, ndf * 2421, bias=False),
  10.             nn.BatchNorm2d(ndf * 2),
  11.             nn.LeakyReLU(0.2, inplace=True),
  12.             # 输出形状: (ndf*2) x 16 x 16
  13.             nn.Conv2d(ndf * 2, ndf * 4421, bias=False),
  14.             nn.BatchNorm2d(ndf * 4),
  15.             nn.LeakyReLU(0.2, inplace=True),
  16.             # 输出形状: (ndf*4) x 8 x 8
  17.             nn.Conv2d(ndf * 4, ndf * 8421, bias=False),
  18.             nn.BatchNorm2d(ndf * 8),
  19.             nn.LeakyReLU(0.2, inplace=True),
  20.             # 输出形状: (ndf*8) x 4 x 4
  21.             nn.Conv2d(ndf * 81410, bias=False),
  22.             nn.Sigmoid()
  23.         )
  24.     def forward(self, input):
  25.         return self.main(input)

初始化模型、优化器和损失函数

  1. # 初始化
  2. nz = 100  # 隐藏向量的维度
  3. ngf = 64  # 与生成器的特征图深度相关
  4. ndf = 64  # 与判别器的特征图深度相关
  5. nc = 1    # 输出图像的通道数
  6. # 创建生成器和判别器
  7. netG = Generator(nz, ngf, nc).to(device)
  8. netD = Discriminator(nc, ndf).to(device)
  9. # 初始化权重
  10. def weights_init(m):
  11.     classname = m.__class__.__name__
  12.     if classname.find('Conv') != -1:
  13.         nn.init.normal_(m.weight.data, 0.00.02)
  14.     elif classname.find('BatchNorm') != -1:
  15.         nn.init.normal_(m.weight.data, 1.00.02)
  16.         nn.init.constant_(m.bias.data, 0)
  17. netG.apply(weights_init)
  18. netD.apply(weights_init)
  19. # 设置优化器
  20. optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.50.999))
  21. optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.50.999))
  22. # 设置损失函数
  23. criterion = nn.BCELoss()

通过这种方式,判别器学习区分真实和生成的图像,同时生成器试图生成越来越难以被判别器区分的图像,从而实现了GAN的训练过程。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/543987
推荐阅读
相关标签
  

闽ICP备14008679号