赞
踩
你好,我是郭震
这篇从零使用Python,实现生成对抗网络(GAN)的基本版本。
GAN使用两套网络,分别是判别器(D)网络和生成器(G)网络,最重要的是弄清楚每套网络的输入和输出分别是什么,两套网络如何结合在一起,及优化的目标即cost function如何定义。
通俗来讲,两套网络结合的方法,就是G会从D的判分中不断提升生成能力,要知道G最开始的输入全部是噪点,这个思想也是文生图**,文生视频的基石**。
下面这段代码展示了使用PyTorch框架进行生成对抗网络(GAN)训练的基本流程。
下面这些解释非常重要:
对于判别器网络而言,它的目标是最大化表达式 log(D(x)) + log(1 - D(G(z)))
,其中:
D(x)
是判别器网络对真实图像 x
的输出,这个值代表判别器认为图像是真实的概率。
D(G(z))
是判别器网络对生成图像 G(z)
的输出,这个值代表判别器认为通过生成器从噪声 z
生成的图像是真实的概率。
log(D(x))
的目标是使得判别器能够尽可能地将真实图像分类为真实(即,使 D(x)
接近于1)。
log(1 - D(G(z)))
的目标是使得判别器能够将生成的图像分类为假(即,使 D(G(z))
接近于0)。
# GAN 训练的基本代码 for epoch in range(num_epochs): for i, data in enumerate(dataloader, 0): # 更新判别器网络:maximize log(D(x)) + log(1 - D(G(z))) # 在真实图像上训练 netD.zero_grad() real_cpu = data[0].to(device) batch_size = real_cpu.size(0) label = torch.full((batch_size,), 1, dtype=torch.float, device=device) output = netD(real_cpu).view(-1) errD_real = criterion(output, label) errD_real.backward() D_x = output.mean().item() # 在假图像上训练 noise = torch.randn(batch_size, nz, 1, 1, device=device) fake = netG(noise) label.fill_(0) output = netD(fake.detach()).view(-1) errD_fake = criterion(output, label) errD_fake.backward() D_G_z1 = output.mean().item() errD = errD_real + errD_fake optimizerD.step() # 更新生成器网络:maximize log(D(G(z))) netG.zero_grad() label.fill_(1) # 假图像的标签对于生成器来说是真的 output = netD(fake).view(-1) errG = criterion(output, label) errG.backward() D_G_z2 = output.mean().item() optimizerG.step()
在训练过程中,这一目标通过以下步骤来实现:
判别器 D
接收一批真实图像 x
。
计算 D(x)
,即这些真实图像被识别为真实的概率。
使用 log(D(x))
计算损失。这个损失会根据真实图像被正确识别的程度(即,D(x)
应该接近于1)来调整。
生成器 G
从随机噪声 z
生成一批假图像。
判别器 D
接收这些生成的图像,并计算 D(G(z))
,即这些假图像被识别为真实的概率。
使用 log(1 - D(G(z)))
计算损失。这个损失会根据假图像被正确识别的程度(即,D(G(z))
应该接近于0)来调整。
实现细节:
在PyTorch中,损失函数通常是要最小化的。因此,虽然理论目标是最大化 log(D(x)) + log(1 - D(G(z)))
,实际上我们通过最小化 -log(D(x)) - log(1 - D(G(z)))
来实现这一目标。
使用二元交叉熵损失(Binary Cross-Entropy, BCE)来实现这一目标,因为它直接提供了所需的 -log(x)
和 -log(1-x)
形式的损失。
这个基本训练循环是理解和实现GANs的关键,而且也是后续进行各种变体和改进的基础。
弄清楚GAN的训练过程后,其他代码就比较容易理解。
导入所需的库
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import torchvision.utils as vutils
定义生成器(Generator)
class Generator(nn.Module): def __init__(self, nz, ngf, nc): super(Generator, self).__init__() self.main = nn.Sequential( # 输入是 Z, 对此进行全连接 nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), # 上一步的输出形状: (ngf*8) x 4 x 4 nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), # 上一步的输出形状: (ngf*4) x 8 x 8 nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), # 上一步的输出形状: (ngf*2) x 16 x 16 nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), # 上一步的输出形状: (ngf) x 32 x 32 nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), nn.Tanh() # 输出形状: (nc) x 64 x 64 ) def forward(self, input): return self.main(input)
定义判别器(Discriminator)
class Discriminator(nn.Module): def __init__(self, nc, ndf): super(Discriminator, self).__init__() self.main = nn.Sequential( # 输入形状: (nc) x 64 x 64 nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), # 输出形状: (ndf) x 32 x 32 nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), # 输出形状: (ndf*2) x 16 x 16 nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), # 输出形状: (ndf*4) x 8 x 8 nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), # 输出形状: (ndf*8) x 4 x 4 nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, input): return self.main(input)
初始化模型、优化器和损失函数
# 初始化 nz = 100 # 隐藏向量的维度 ngf = 64 # 与生成器的特征图深度相关 ndf = 64 # 与判别器的特征图深度相关 nc = 1 # 输出图像的通道数 # 创建生成器和判别器 netG = Generator(nz, ngf, nc).to(device) netD = Discriminator(nc, ndf).to(device) # 初始化权重 def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) netG.apply(weights_init) netD.apply(weights_init) # 设置优化器 optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # 设置损失函数 criterion = nn.BCELoss()
通过这种方式,判别器学习区分真实和生成的图像,同时生成器试图生成越来越难以被判别器区分的图像,从而实现了GAN的训练过程。
点击下方安全链接前往获取
CSDN大礼包:《Python入门&进阶学习资源包》免费分享
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。