当前位置:   article > 正文

从零使用Python 实现对抗神经网络GAN_生成对抗网络交替训练python代码怎么写

生成对抗网络交替训练python代码怎么写

你好,我是郭震

这篇从零使用Python,实现生成对抗网络(GAN)的基本版本。

GAN使用两套网络,分别是判别器(D)网络和生成器(G)网络,最重要的是弄清楚每套网络的输入和输出分别是什么,两套网络如何结合在一起,及优化的目标即cost function如何定义。

通俗来讲,两套网络结合的方法,就是G会从D的判分中不断提升生成能力,要知道G最开始的输入全部是噪点,这个思想也是文生图**,文生视频的基石**。

下面这段代码展示了使用PyTorch框架进行生成对抗网络(GAN)训练的基本流程。

下面这些解释非常重要:

对于判别器网络而言,它的目标是最大化表达式 log(D(x)) + log(1 - D(G(z))),其中:

  • D(x) 是判别器网络对真实图像 x 的输出,这个值代表判别器认为图像是真实的概率。

  • D(G(z)) 是判别器网络对生成图像 G(z) 的输出,这个值代表判别器认为通过生成器从噪声 z 生成的图像是真实的概率。

  • log(D(x)) 的目标是使得判别器能够尽可能地将真实图像分类为真实(即,使 D(x) 接近于1)。

  • log(1 - D(G(z))) 的目标是使得判别器能够将生成的图像分类为假(即,使 D(G(z)) 接近于0)。

# GAN 训练的基本代码   for epoch in range(num_epochs):       for i, data in enumerate(dataloader, 0):           # 更新判别器网络:maximize log(D(x)) + log(1 - D(G(z)))                      # 在真实图像上训练           netD.zero_grad()           real_cpu = data[0].to(device)           batch_size = real_cpu.size(0)           label = torch.full((batch_size,), 1, dtype=torch.float, device=device)           output = netD(real_cpu).view(-1)           errD_real = criterion(output, label)           errD_real.backward()           D_x = output.mean().item()              # 在假图像上训练           noise = torch.randn(batch_size, nz, 1, 1, device=device)           fake = netG(noise)           label.fill_(0)           output = netD(fake.detach()).view(-1)           errD_fake = criterion(output, label)           errD_fake.backward()           D_G_z1 = output.mean().item()           errD = errD_real + errD_fake           optimizerD.step()              # 更新生成器网络:maximize log(D(G(z)))           netG.zero_grad()           label.fill_(1)  # 假图像的标签对于生成器来说是真的           output = netD(fake).view(-1)           errG = criterion(output, label)           errG.backward()           D_G_z2 = output.mean().item()           optimizerG.step()   
  • 1

在训练过程中,这一目标通过以下步骤来实现:

  1. 对于真实图像
  • 判别器 D 接收一批真实图像 x

  • 计算 D(x),即这些真实图像被识别为真实的概率。

  • 使用 log(D(x)) 计算损失。这个损失会根据真实图像被正确识别的程度(即,D(x) 应该接近于1)来调整。

  1. 对于生成图像
  • 生成器 G 从随机噪声 z 生成一批假图像。

  • 判别器 D 接收这些生成的图像,并计算 D(G(z)),即这些假图像被识别为真实的概率。

  • 使用 log(1 - D(G(z))) 计算损失。这个损失会根据假图像被正确识别的程度(即,D(G(z)) 应该接近于0)来调整。

实现细节

  • 在PyTorch中,损失函数通常是要最小化的。因此,虽然理论目标是最大化 log(D(x)) + log(1 - D(G(z))),实际上我们通过最小化 -log(D(x)) - log(1 - D(G(z))) 来实现这一目标。

  • 使用二元交叉熵损失(Binary Cross-Entropy, BCE)来实现这一目标,因为它直接提供了所需的 -log(x)-log(1-x) 形式的损失。

这个基本训练循环是理解和实现GANs的关键,而且也是后续进行各种变体和改进的基础。

弄清楚GAN的训练过程后,其他代码就比较容易理解。

导入所需的库

    import torch   import torch.nn as nn   import torch.optim as optim   from torchvision import datasets, transforms   from torch.utils.data import DataLoader   import torchvision.utils as vutils    
  • 1

定义生成器(Generator)

    class Generator(nn.Module):       def __init__(self, nz, ngf, nc):           super(Generator, self).__init__()           self.main = nn.Sequential(               # 输入是 Z, 对此进行全连接               nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),               nn.BatchNorm2d(ngf * 8),               nn.ReLU(True),               # 上一步的输出形状: (ngf*8) x 4 x 4               nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),               nn.BatchNorm2d(ngf * 4),               nn.ReLU(True),               # 上一步的输出形状: (ngf*4) x 8 x 8               nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),               nn.BatchNorm2d(ngf * 2),               nn.ReLU(True),               # 上一步的输出形状: (ngf*2) x 16 x 16               nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),               nn.BatchNorm2d(ngf),               nn.ReLU(True),               # 上一步的输出形状: (ngf) x 32 x 32               nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),               nn.Tanh()               # 输出形状: (nc) x 64 x 64           )          def forward(self, input):           return self.main(input)    
  • 1

定义判别器(Discriminator)

    class Discriminator(nn.Module):       def __init__(self, nc, ndf):           super(Discriminator, self).__init__()           self.main = nn.Sequential(               # 输入形状: (nc) x 64 x 64               nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),               nn.LeakyReLU(0.2, inplace=True),               # 输出形状: (ndf) x 32 x 32               nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),               nn.BatchNorm2d(ndf * 2),               nn.LeakyReLU(0.2, inplace=True),               # 输出形状: (ndf*2) x 16 x 16               nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),               nn.BatchNorm2d(ndf * 4),               nn.LeakyReLU(0.2, inplace=True),               # 输出形状: (ndf*4) x 8 x 8               nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),               nn.BatchNorm2d(ndf * 8),               nn.LeakyReLU(0.2, inplace=True),               # 输出形状: (ndf*8) x 4 x 4               nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),               nn.Sigmoid()           )          def forward(self, input):           return self.main(input)    
  • 1

初始化模型、优化器和损失函数

    # 初始化   nz = 100  # 隐藏向量的维度   ngf = 64  # 与生成器的特征图深度相关   ndf = 64  # 与判别器的特征图深度相关   nc = 1    # 输出图像的通道数      # 创建生成器和判别器   netG = Generator(nz, ngf, nc).to(device)   netD = Discriminator(nc, ndf).to(device)      # 初始化权重   def weights_init(m):       classname = m.__class__.__name__       if classname.find('Conv') != -1:           nn.init.normal_(m.weight.data, 0.0, 0.02)       elif classname.find('BatchNorm') != -1:           nn.init.normal_(m.weight.data, 1.0, 0.02)           nn.init.constant_(m.bias.data, 0)      netG.apply(weights_init)   netD.apply(weights_init)      # 设置优化器   optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))   optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))      # 设置损失函数   criterion = nn.BCELoss()    
  • 1

通过这种方式,判别器学习区分真实和生成的图像,同时生成器试图生成越来越难以被判别器区分的图像,从而实现了GAN的训练过程。

点击下方安全链接前往获取

CSDN大礼包:《Python入门&进阶学习资源包》免费分享

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/876125

推荐阅读
相关标签