
PyTorch Generative Adversarial Networks (GAN) Official Beginner Tutorial: DCGAN

Table of Contents

Introduction
Generative Adversarial Networks
What is a GAN?
What is a DCGAN?
Inputs
Data
Implementation
Weight Initialization
Generator
Discriminator
Loss Functions and Optimizers
Training
Results
Where to Go Next


Original tutorial: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

Introduction

This tutorial introduces DCGANs (Deep Convolutional Generative Adversarial Networks) through an example. We will train a generative adversarial network (GAN) to produce new celebrities after showing it many photos of real celebrities. The code here comes from pytorch/examples; this document explains the implementation thoroughly and sheds light on how and why this model works. Don't worry: no prior knowledge of GANs is required, though it may take some time to study what is actually happening under the hood. Also, for the sake of time, having one or two GPUs will help. Let's get started.

Generative Adversarial Networks

What is a GAN?

GANs are a framework for teaching a deep learning (DL) model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models: a generator and a discriminator. The generator's job is to produce images that look like the training images; the discriminator's job is to look at an image and decide whether it is a real training image or a fake produced by the generator. During training, the generator constantly tries to outsmart the discriminator by generating better and better fakes, while the discriminator works to become a better detective and correctly classify real and fake images. The equilibrium of this game is reached when the generator produces fakes that look as if they came directly from the training data, and the discriminator is left guessing, always at 50% confidence, whether a given image is real or fake.

Now, let's define some notation used throughout the tutorial, starting with the discriminator. Let x denote image data. D(x) is the discriminator network, which outputs the (scalar) probability that x came from the training data rather than from the generator. Here, we are dealing with CHW (channel, height, width) images of size 3x64x64. Intuitively, D(x) should be high when x comes from the training data and low when x comes from the generator. You can also think of D(x) as a traditional binary classifier.

For the generator's notation, let z be a latent vector sampled from a standard normal distribution ("latent" carries no deep mystery here; as with the hidden layers of a feed-forward network, it simply denotes a variable or space with no direct physical meaning and generally no interpretability). G(z) represents the generator function that maps the latent vector z to data space. The goal of G is to estimate the distribution that the training data come from (p_data) so it can generate fake samples from that estimated distribution (p_g).

So, D(G(z)) is the (scalar) probability that the output of the generator G is a real image. As described in Goodfellow's paper, D and G play a minimax game in which D tries to maximize the probability of correctly classifying reals and fakes (log D(x)), and G tries to minimize the probability that D predicts its output is fake (log(1 − D(G(z)))). From the paper, the GAN loss function is:

min_G max_D V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 − D(G(z)))]

In theory, the solution to this minimax game is p_g = p_data, with the discriminator guessing randomly whether inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.

What is a DCGAN?

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first proposed by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of convolution layers, batch norm layers, and LeakyReLU activations. Its input is a 3x64x64 image and its output is a scalar probability that the input came from the real data distribution. The generator is composed of convolutional-transpose layers, batch norm layers, and ReLU activations. Its input is a latent vector z drawn from a standard normal distribution, and its output is a 3x64x64 RGB image. The convolutional-transpose layers allow the latent vector to be transformed into a volume with the shape of an image. In the paper, the authors also give some tips about how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.
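As a quick sanity check on that shape transformation, the spatial size of a transposed convolution follows H_out = (H_in − 1)·stride − 2·padding + kernel_size (with dilation 1 and no output padding). A minimal sketch in plain Python, using the kernel/stride/padding values that appear later in this tutorial's generator:

```python
def conv_transpose_out(h_in, kernel_size=4, stride=2, padding=1):
    """Spatial output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (h_in - 1) * stride - 2 * padding + kernel_size

# The generator's five layers grow a 1x1 latent input to a 64x64 image:
sizes = [1]
for stride, padding in [(1, 0), (2, 1), (2, 1), (2, 1), (2, 1)]:
    sizes.append(conv_transpose_out(sizes[-1], 4, stride, padding))
print(sizes)  # [1, 4, 8, 16, 32, 64]
```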

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Output:

Random Seed:  999

Inputs

Let's define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.
  • workers - the number of worker processes for loading the data with the DataLoader.
  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128.
  • image_size - the spatial size of the training images. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details.
  • nc - number of channels in the input images. Here this is 3.
  • nz - length of the latent vector (i.e. the dimension of the latent vector drawn from a standard normal distribution; in other words, the dimension of the Gaussian noise).
  • ngf - relates to the depth of feature maps in the generator (i.e. the in_channels of the final convolutional-transpose layer, whose out_channels is 3).
  • ndf - relates to the depth of feature maps in the discriminator (i.e. the out_channels of the first convolution, whose in_channels is 3).
  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results, but will also take longer.
  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002.
  • beta1 - beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5.
  • ngpu - number of GPUs available. If this is 0, the code will run in CPU mode. If greater than 0, it will run on that number of GPUs.
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded at the linked site or from Google Drive. The dataset downloads as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input from the previous section to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
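Note the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) step: it maps ToTensor's [0, 1] pixel range into [−1, 1], matching the Tanh output range of the generator defined later. The per-channel mapping is just (x − mean) / std; a minimal sketch:

```python
def normalize(x, mean=0.5, std=0.5):
    """The per-channel transform applied by transforms.Normalize."""
    return (x - mean) / std

print(normalize(0.0), normalize(0.5), normalize(1.0))  # -1.0 0.0 1.0
```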

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to a model immediately after initialization. (In this article, "weights" and "parameters" are used interchangeably.)

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
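A quick way to see the effect is to apply the same rule to a single layer and check the resulting weight statistics. A self-contained sketch (tolerances are loose because the weights are random):

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same rule as above: N(0, 0.02) for conv layers, N(1, 0.02) for batch norm
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
layer = nn.Conv2d(3, 64, 4)   # 64*3*4*4 = 3072 weights
layer.apply(weights_init)
print(layer.weight.mean().item(), layer.weight.std().item())  # ~0.0, ~0.02
```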

Generator

The generator, G, is designed to map the latent vector z to data space. Since our data are images, this means creating an RGB image with the same size as the training images (i.e. 3x64x64) from z. In practice, this is accomplished through a series of ConvTranspose2d, BatchNorm2d, and ReLU layers. The output of the generator is fed through a tanh function to map it to the range [−1, 1]. It is worth noting the batch norm layers placed right after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These (batch norm) layers help with the flow of gradients during training. The generator from the DCGAN paper is shown below.

[Figure: the DCGAN generator architecture]

Notice how the inputs we set in the Inputs section (nz, ngf, and nc) influence the generator architecture. nz is the length of the z latent vector, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). The generator code is below:

# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    # A transposed convolution can be understood as the inverse of a convolution.
    # Take the last conv-transpose layer as an example: a regular convolution over
    # a (nc) x 64 x 64 input gives Hout = (Hin + 2*padding - kernel_size)/stride + 1
    #   = (64 + 2*1 - 4)/2 + 1 = 32, i.e. an (out_channels) x 32 x 32 output.
    # The conv-transpose layer swaps that convolution's input and output: its input
    # has out_channels channels at 32 x 32, and its output is (nc) x 64 x 64.

    def forward(self, input):
        return self.main(input)
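A quick shape check confirms the pipeline: a latent vector of length nz comes out as a 3x64x64 image in [−1, 1]. A self-contained sketch that rebuilds the same layer stack inline:

```python
import torch
import torch.nn as nn

nz, ngf, nc = 100, 64, 3
# Same stack as the Generator above, built inline for a quick shape check
netG = nn.Sequential(
    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True),
    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True),
    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh(),
)
with torch.no_grad():
    out = netG(torch.randn(2, nz, 1, 1))
print(out.shape)  # torch.Size([2, 3, 64, 64])
```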

Now, we can instantiate the generator and apply the weights_init function. Print the model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)
# Print the model
print(netG)

Output:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator

As mentioned, the discriminator, D, is a binary classification network that takes an image as input and outputs the scalar probability that the input image is real. Here, D takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation. This architecture can be extended with more layers if necessary. The DCGAN paper mentions that it is good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. The batch norm and leaky ReLU layers also promote healthy gradient flow, which is critical for the learning of both G and D.

Discriminator code:

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
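As with the generator, a quick shape check confirms the design: a 3x64x64 image comes out as a single probability. A self-contained sketch that rebuilds the same layer stack inline:

```python
import torch
import torch.nn as nn

nc, ndf = 3, 64
# Same stack as the Discriminator above, built inline for a quick shape check
netD = nn.Sequential(
    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid(),
)
with torch.no_grad():
    out = netD(torch.randn(2, nc, 64, 64))
print(out.shape)  # torch.Size([2, 1, 1, 1])
```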

Now, as with the generator, we can instantiate the discriminator, apply the weights_init function, and print the model's structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)
# Print the model
print(netD)

Output:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers

With the generator G and discriminator D set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss), which is defined in PyTorch as:

ℓ(x, y) = L = {l_1, …, l_N}^T,  l_n = −[ y_n · log(x_n) + (1 − y_n) · log(1 − x_n) ]

Notice how this function provides both log components of the objective (i.e. log(D(x)) and log(1 − D(G(z)))). We can choose which part of the BCE equation to use simply through the choice of the y label (i.e. the ground-truth label). This will come up in the training loop, but it is important to understand that we can select which log component to compute just by changing y.

Next, we define the real label as 1 and the fake label as 0. These labels are used when calculating the losses of D and G, and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for G and one for D. As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and beta1 = 0.5. To keep track of the generator's learning progression, we also generate a fixed batch of latent vectors drawn from a Gaussian distribution (fixed_noise). In the training loop, we will periodically feed this fixed noise into G, and over the iterations we will see images form out of the noise.

# Initialize BCELoss function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
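To see the label trick concretely: with target 1, BCELoss reduces to −log(x); with target 0, it reduces to −log(1 − x). A small self-contained sketch:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
output = torch.tensor([0.9])                   # a discriminator probability

loss_real = criterion(output, torch.ones(1))   # -log(0.9): the log(x) branch
loss_fake = criterion(output, torch.zeros(1))  # -log(1 - 0.9): the log(1-x) branch
print(loss_real.item(), loss_fake.item())
```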

Training

Finally, now that we have all parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize log D(G(z)). Training is split up into two main parts: Part 1 updates the discriminator and Part 2 updates the generator.

Part 1 - Train the Discriminator

Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow's terms, we wish to "update the discriminator by ascending its stochastic gradient". Practically, we want to maximize log(D(x)) + log(1 − D(G(z))). Due to the separate mini-batch suggestion from ganhacks, we calculate this in two steps. First, we construct a batch of real images from the training set, forward it through D, calculate the loss (log(D(x))), and call backward to compute the gradients. Second, we construct a batch of fake images with the current generator, forward this batch through D, calculate the loss (log(1 − D(G(z)))), and call backward to accumulate the gradients. With the gradients accumulated from both the all-real and the all-fake batches, we call a step of the discriminator's optimizer to update D's parameters once.

Part 2 - Train the Generator

As stated in the original paper, we want to train the generator by minimizing log(1 − D(G(z))) in an effort to generate better fakes. As a fix, we instead wish to maximize log(D(G(z))). We accomplish this as follows: classify the generator output from Part 1 with the discriminator, compute G's loss using the real labels as ground truth (GT), call backward to compute G's gradients, and finally update G's parameters with a step of G's optimizer. It may seem counter-intuitive to use the real labels as GT for the loss function, but this lets us use the log(x) part of BCELoss rather than the log(1 − x) part, which is exactly what we want.

Finally, we will do some statistic reporting, and at the end of each epoch we will push our fixed noise batch through the generator to visually track the progress of G's training. The training statistics reported are:

  • Loss_D - discriminator loss, calculated as the sum of losses for the all-real and all-fake batches (log(D(x)) + log(1 − D(G(z)))).
  • Loss_G - generator loss, calculated as log(D(G(z))).
  • D(x) - the average output (probability) of the discriminator for the all-real batch. This should start close to 1 and, in theory, converge to around 0.5 as G gets better.
  • D(G(z)) - the average output (probability) of the discriminator for the all-fake batch. The first number is before D is updated and the second is after. These numbers should start near 0 and converge to 0.5 as G gets better.

Note: this step may take a while, depending on how many epochs you run and how much data is in your dataset.

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Output:

Starting Training Loop...
[0/5][0/1583] Loss_D: 1.7834 Loss_G: 5.0952 D(x): 0.5564 D(G(z)): 0.5963 / 0.0094
[0/5][50/1583] Loss_D: 0.2582 Loss_G: 28.5604 D(x): 0.8865 D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.9311 Loss_G: 13.3240 D(x): 0.9443 D(G(z)): 0.4966 / 0.0000
[0/5][150/1583] Loss_D: 0.7385 Loss_G: 8.8132 D(x): 0.9581 D(G(z)): 0.4625 / 0.0004
[0/5][200/1583] Loss_D: 0.4796 Loss_G: 6.5862 D(x): 0.9888 D(G(z)): 0.3271 / 0.0047
[0/5][250/1583] Loss_D: 0.7410 Loss_G: 5.4159 D(x): 0.8282 D(G(z)): 0.3274 / 0.0082
[0/5][300/1583] Loss_D: 0.4622 Loss_G: 3.7107 D(x): 0.7776 D(G(z)): 0.1251 / 0.0375
[0/5][350/1583] Loss_D: 1.0642 Loss_G: 6.3149 D(x): 0.9374 D(G(z)): 0.5391 / 0.0061
[0/5][400/1583] Loss_D: 0.3848 Loss_G: 6.3376 D(x): 0.9153 D(G(z)): 0.2209 / 0.0036
[0/5][450/1583] Loss_D: 0.2790 Loss_G: 4.3376 D(x): 0.8896 D(G(z)): 0.1256 / 0.0217
[0/5][500/1583] Loss_D: 1.2478 Loss_G: 8.1121 D(x): 0.9361 D(G(z)): 0.5578 / 0.0016
[0/5][550/1583] Loss_D: 0.3393 Loss_G: 4.0673 D(x): 0.8257 D(G(z)): 0.0496 / 0.0323
[0/5][600/1583] Loss_D: 0.8083 Loss_G: 2.5396 D(x): 0.6232 D(G(z)): 0.0484 / 0.1265
[0/5][650/1583] Loss_D: 0.3682 Loss_G: 4.3142 D(x): 0.8227 D(G(z)): 0.1114 / 0.0217
[0/5][700/1583] Loss_D: 0.4788 Loss_G: 6.2379 D(x): 0.8594 D(G(z)): 0.2307 / 0.0037
[0/5][750/1583] Loss_D: 0.4767 Loss_G: 5.3962 D(x): 0.8935 D(G(z)): 0.2463 / 0.0092
[0/5][800/1583] Loss_D: 0.8085 Loss_G: 2.3573 D(x): 0.5934 D(G(z)): 0.0769 / 0.1357
[0/5][850/1583] Loss_D: 0.3595 Loss_G: 3.9025 D(x): 0.7769 D(G(z)): 0.0563 / 0.0381
[0/5][900/1583] Loss_D: 0.3235 Loss_G: 4.7795 D(x): 0.9224 D(G(z)): 0.1785 / 0.0163
[0/5][950/1583] Loss_D: 0.3426 Loss_G: 3.1228 D(x): 0.8257 D(G(z)): 0.0847 / 0.0795
[0/5][1000/1583] Loss_D: 0.6667 Loss_G: 7.3167 D(x): 0.9556 D(G(z)): 0.3751 / 0.0019
[0/5][1050/1583] Loss_D: 0.2840 Loss_G: 5.0387 D(x): 0.9268 D(G(z)): 0.1642 / 0.0143
[0/5][1100/1583] Loss_D: 0.4534 Loss_G: 3.8780 D(x): 0.7535 D(G(z)): 0.0697 / 0.0391
[0/5][1150/1583] Loss_D: 0.5040 Loss_G: 2.9283 D(x): 0.7452 D(G(z)): 0.1167 / 0.0839
[0/5][1200/1583] Loss_D: 0.6478 Loss_G: 4.0913 D(x): 0.6595 D(G(z)): 0.0263 / 0.0358
[0/5][1250/1583] Loss_D: 1.2299 Loss_G: 7.8236 D(x): 0.9850 D(G(z)): 0.5941 / 0.0013
[0/5][1300/1583] Loss_D: 0.3228 Loss_G: 4.9211 D(x): 0.8882 D(G(z)): 0.1488 / 0.0140
[0/5][1350/1583] Loss_D: 0.4208 Loss_G: 4.1520 D(x): 0.8254 D(G(z)): 0.1638 / 0.0260
[0/5][1400/1583] Loss_D: 0.5751 Loss_G: 3.9585 D(x): 0.7692 D(G(z)): 0.1902 / 0.0329
[0/5][1450/1583] Loss_D: 1.6244 Loss_G: 0.5350 D(x): 0.3037 D(G(z)): 0.0159 / 0.6617
[0/5][1500/1583] Loss_D: 0.3676 Loss_G: 3.2653 D(x): 0.8076 D(G(z)): 0.0825 / 0.0710
[0/5][1550/1583] Loss_D: 0.2759 Loss_G: 4.4156 D(x): 0.9010 D(G(z)): 0.1370 / 0.0178
[1/5][0/1583] Loss_D: 1.0879 Loss_G: 7.8641 D(x): 0.8737 D(G(z)): 0.5376 / 0.0008
[1/5][50/1583] Loss_D: 0.2761 Loss_G: 4.4716 D(x): 0.9008 D(G(z)): 0.1267 / 0.0231
[1/5][100/1583] Loss_D: 0.3438 Loss_G: 4.0343 D(x): 0.8389 D(G(z)): 0.1162 / 0.0308
[1/5][150/1583] Loss_D: 0.4937 Loss_G: 4.8593 D(x): 0.7951 D(G(z)): 0.1819 / 0.0162
[1/5][200/1583] Loss_D: 0.3973 Loss_G: 3.2078 D(x): 0.8671 D(G(z)): 0.1916 / 0.0587
[1/5][250/1583] Loss_D: 0.4521 Loss_G: 4.5155 D(x): 0.9006 D(G(z)): 0.2441 / 0.0222
[1/5][300/1583] Loss_D: 0.4423 Loss_G: 5.3907 D(x): 0.8635 D(G(z)): 0.2039 / 0.0125
[1/5][350/1583] Loss_D: 0.6447 Loss_G: 2.5607 D(x): 0.6177 D(G(z)): 0.0195 / 0.1284
[1/5][400/1583] Loss_D: 0.4079 Loss_G: 4.2563 D(x): 0.8621 D(G(z)): 0.1949 / 0.0268
[1/5][450/1583] Loss_D: 0.9649 Loss_G: 8.0302 D(x): 0.9727 D(G(z)): 0.5302 / 0.0010
[1/5][500/1583] Loss_D: 0.7693 Loss_G: 5.9895 D(x): 0.9070 D(G(z)): 0.4331 / 0.0053
[1/5][550/1583] Loss_D: 0.4522 Loss_G: 2.6169 D(x): 0.7328 D(G(z)): 0.0634 / 0.1113
[1/5][600/1583] Loss_D: 0.4039 Loss_G: 3.4861 D(x): 0.8436 D(G(z)): 0.1738 / 0.0494
[1/5][650/1583] Loss_D: 0.4434 Loss_G: 3.0261 D(x): 0.7756 D(G(z)): 0.1299 / 0.0777
[1/5][700/1583] Loss_D: 1.5401 Loss_G: 8.3636 D(x): 0.9705 D(G(z)): 0.7050 / 0.0011
[1/5][750/1583] Loss_D: 0.3899 Loss_G: 4.3379 D(x): 0.7379 D(G(z)): 0.0231 / 0.0248
[1/5][800/1583] Loss_D: 0.9547 Loss_G: 5.6122 D(x): 0.9520 D(G(z)): 0.5318 / 0.0074
[1/5][850/1583] Loss_D: 0.3714 Loss_G: 3.2116 D(x): 0.7770 D(G(z)): 0.0752 / 0.0700
[1/5][900/1583] Loss_D: 0.2717 Loss_G: 4.0063 D(x): 0.8673 D(G(z)): 0.1058 / 0.0272
[1/5][950/1583] Loss_D: 0.2652 Loss_G: 3.7649 D(x): 0.8381 D(G(z)): 0.0540 / 0.0361
[1/5][1000/1583] Loss_D: 0.9463 Loss_G: 1.6266 D(x): 0.5189 D(G(z)): 0.0913 / 0.2722
[1/5][1050/1583] Loss_D: 0.7117 Loss_G: 3.7363 D(x): 0.8544 D(G(z)): 0.3578 / 0.0397
[1/5][1100/1583] Loss_D: 0.5164 Loss_G: 4.0939 D(x): 0.8865 D(G(z)): 0.2904 / 0.0252
[1/5][1150/1583] Loss_D: 0.3745 Loss_G: 3.1891 D(x): 0.8262 D(G(z)): 0.1358 / 0.0645
[1/5][1200/1583] Loss_D: 0.4583 Loss_G: 2.9545 D(x): 0.7866 D(G(z)): 0.1453 / 0.0778
[1/5][1250/1583] Loss_D: 0.5870 Loss_G: 4.4096 D(x): 0.9473 D(G(z)): 0.3706 / 0.0208
[1/5][1300/1583] Loss_D: 0.5159 Loss_G: 4.1076 D(x): 0.8640 D(G(z)): 0.2738 / 0.0240
[1/5][1350/1583] Loss_D: 0.6005 Loss_G: 1.8590 D(x): 0.6283 D(G(z)): 0.0418 / 0.2032
[1/5][1400/1583] Loss_D: 0.3646 Loss_G: 3.4323 D(x): 0.7712 D(G(z)): 0.0653 / 0.0534
[1/5][1450/1583] Loss_D: 0.6245 Loss_G: 2.2462 D(x): 0.6515 D(G(z)): 0.0905 / 0.1514
[1/5][1500/1583] Loss_D: 0.6055 Loss_G: 1.7674 D(x): 0.7026 D(G(z)): 0.1682 / 0.2169
[1/5][1550/1583] Loss_D: 0.5181 Loss_G: 3.2728 D(x): 0.7926 D(G(z)): 0.2048 / 0.0549
[2/5][0/1583] Loss_D: 0.9580 Loss_G: 5.1154 D(x): 0.9605 D(G(z)): 0.5535 / 0.0105
[2/5][50/1583] Loss_D: 0.9947 Loss_G: 1.7223 D(x): 0.4860 D(G(z)): 0.0563 / 0.2477
[2/5][100/1583] Loss_D: 0.7023 Loss_G: 4.1781 D(x): 0.9083 D(G(z)): 0.4116 / 0.0239
[2/5][150/1583] Loss_D: 0.3496 Loss_G: 2.7264 D(x): 0.8871 D(G(z)): 0.1795 / 0.0982
[2/5][200/1583] Loss_D: 0.6805 Loss_G: 3.8157 D(x): 0.8900 D(G(z)): 0.3851 / 0.0312
[2/5][250/1583] Loss_D: 0.6193 Loss_G: 3.8180 D(x): 0.8557 D(G(z)): 0.3286 / 0.0303
[2/5][300/1583] Loss_D: 0.6480 Loss_G: 1.4683 D(x): 0.6157 D(G(z)): 0.0640 / 0.2844
[2/5][350/1583] Loss_D: 0.7498 Loss_G: 4.1299 D(x): 0.8922 D(G(z)): 0.4244 / 0.0256
[2/5][400/1583] Loss_D: 0.7603 Loss_G: 4.2291 D(x): 0.9512 D(G(z)): 0.4604 / 0.0213
[2/5][450/1583] Loss_D: 0.4833 Loss_G: 4.0068 D(x): 0.9348 D(G(z)): 0.3095 / 0.0257
[2/5][500/1583] Loss_D: 1.2311 Loss_G: 0.7107 D(x): 0.3949 D(G(z)): 0.0496 / 0.5440
[2/5][550/1583] Loss_D: 0.9657 Loss_G: 1.5119 D(x): 0.4513 D(G(z)): 0.0338 / 0.2821
[2/5][600/1583] Loss_D: 0.5351 Loss_G: 3.4546 D(x): 0.8889 D(G(z)): 0.3018 / 0.0449
[2/5][650/1583] Loss_D: 0.8761 Loss_G: 1.2051 D(x): 0.5292 D(G(z)): 0.1193 / 0.3583
[2/5][700/1583] Loss_D: 1.0206 Loss_G: 4.5741 D(x): 0.8599 D(G(z)): 0.5140 / 0.0159
[2/5][750/1583] Loss_D: 1.0886 Loss_G: 5.4749 D(x): 0.9770 D(G(z)): 0.6093 / 0.0067
[2/5][800/1583] Loss_D: 0.6539 Loss_G: 3.5203 D(x): 0.9074 D(G(z)): 0.3962 / 0.0390
[2/5][850/1583] Loss_D: 0.8633 Loss_G: 1.0995 D(x): 0.5701 D(G(z)): 0.1401 / 0.3842
[2/5][900/1583] Loss_D: 0.3703 Loss_G: 2.2482 D(x): 0.8183 D(G(z)): 0.1302 / 0.1329
[2/5][950/1583] Loss_D: 0.6592 Loss_G: 1.6081 D(x): 0.6040 D(G(z)): 0.0818 / 0.2523
[2/5][1000/1583] Loss_D: 0.7449 Loss_G: 1.0548 D(x): 0.5975 D(G(z)): 0.1375 / 0.4085
[2/5][1050/1583] Loss_D: 0.5783 Loss_G: 2.3644 D(x): 0.6435 D(G(z)): 0.0531 / 0.1357
[2/5][1100/1583] Loss_D: 0.6123 Loss_G: 2.2695 D(x): 0.7269 D(G(z)): 0.2083 / 0.1343
[2/5][1150/1583] Loss_D: 0.6263 Loss_G: 1.8714 D(x): 0.6661 D(G(z)): 0.1407 / 0.1914
[2/5][1200/1583] Loss_D: 0.4233 Loss_G: 3.0119 D(x): 0.8533 D(G(z)): 0.2039 / 0.0692
[2/5][1250/1583] Loss_D: 0.8826 Loss_G: 3.3618 D(x): 0.7851 D(G(z)): 0.3971 / 0.0502
[2/5][1300/1583] Loss_D: 0.6201 Loss_G: 2.1584 D(x): 0.6418 D(G(z)): 0.0977 / 0.1536
[2/5][1350/1583] Loss_D: 0.9558 Loss_G: 3.8876 D(x): 0.8561 D(G(z)): 0.5001 / 0.0302
[2/5][1400/1583] Loss_D: 0.4369 Loss_G: 2.3479 D(x): 0.7959 D(G(z)): 0.1588 / 0.1214
[2/5][1450/1583] Loss_D: 0.5086 Loss_G: 2.1034 D(x): 0.6758 D(G(z)): 0.0586 / 0.1575
[2/5][1500/1583] Loss_D: 0.6513 Loss_G: 3.5801 D(x): 0.8535 D(G(z)): 0.3429 / 0.0455
[2/5][1550/1583] Loss_D: 0.6975 Loss_G: 2.5560 D(x): 0.7379 D(G(z)): 0.2784 / 0.1031
[3/5][0/1583] Loss_D: 2.2846 Loss_G: 1.7977 D(x): 0.1771 D(G(z)): 0.0111 / 0.2394
[3/5][50/1583] Loss_D: 1.6111 Loss_G: 5.7904 D(x): 0.9581 D(G(z)): 0.7350 / 0.0063
[3/5][100/1583] Loss_D: 0.8553 Loss_G: 1.0540 D(x): 0.5229 D(G(z)): 0.1020 / 0.3945
[3/5][150/1583] Loss_D: 0.7402 Loss_G: 2.6338 D(x): 0.7668 D(G(z)): 0.3277 / 0.0959
[3/5][200/1583] Loss_D: 0.9278 Loss_G: 2.9689 D(x): 0.8913 D(G(z)): 0.4787 / 0.0769
[3/5][250/1583] Loss_D: 2.6573 Loss_G: 6.4810 D(x): 0.9684 D(G(z)): 0.8799 / 0.0035
[3/5][300/1583] Loss_D: 0.5435 Loss_G: 1.9416 D(x): 0.7118 D(G(z)): 0.1454 / 0.1801
[3/5][350/1583] Loss_D: 1.2350 Loss_G: 4.6877 D(x): 0.9595 D(G(z)): 0.6444 / 0.0147
[3/5][400/1583] Loss_D: 0.9264 Loss_G: 0.9139 D(x): 0.4825 D(G(z)): 0.0715 / 0.4526
[3/5][450/1583] Loss_D: 0.8967 Loss_G: 4.4258 D(x): 0.9155 D(G(z)): 0.5074 / 0.0174
[3/5][500/1583] Loss_D: 0.6874 Loss_G: 2.4529 D(x): 0.7775 D(G(z)): 0.3171 / 0.1097
[3/5][550/1583] Loss_D: 0.5821 Loss_G: 3.0756 D(x): 0.8681 D(G(z)): 0.3161 / 0.0609
[3/5][600/1583] Loss_D: 0.7164 Loss_G: 1.5045 D(x): 0.5652 D(G(z)): 0.0428 / 0.2868
[3/5][650/1583] Loss_D: 0.6290 Loss_G: 2.1863 D(x): 0.7952 D(G(z)): 0.2829 / 0.1442
[3/5][700/1583] Loss_D: 0.6270 Loss_G: 1.2824 D(x): 0.6481 D(G(z)): 0.1184 / 0.3234
[3/5][750/1583] Loss_D: 0.7011 Loss_G: 1.3549 D(x): 0.5861 D(G(z)): 0.0926 / 0.3017
[3/5][800/1583] Loss_D: 0.6912 Loss_G: 1.4927 D(x): 0.5919 D(G(z)): 0.0741 / 0.2728
[3/5][850/1583] Loss_D: 0.6385 Loss_G: 2.9333 D(x): 0.8418 D(G(z)): 0.3338 / 0.0723
[3/5][900/1583] Loss_D: 0.7835 Loss_G: 4.4475 D(x): 0.9290 D(G(z)): 0.4703 / 0.0151
[3/5][950/1583] Loss_D: 0.6294 Loss_G: 2.3463 D(x): 0.7388 D(G(z)): 0.2414 / 0.1202
[3/5][1000/1583] Loss_D: 0.6288 Loss_G: 1.5448 D(x): 0.6575 D(G(z)): 0.1389 / 0.2581
[3/5][1050/1583] Loss_D: 0.6292 Loss_G: 3.4867 D(x): 0.8741 D(G(z)): 0.3549 / 0.0433
[3/5][1100/1583] Loss_D: 0.7644 Loss_G: 1.7661 D(x): 0.5457 D(G(z)): 0.0408 / 0.2076
[3/5][1150/1583] Loss_D: 0.4918 Loss_G: 3.1858 D(x): 0.8576 D(G(z)): 0.2563 / 0.0527
[3/5][1200/1583] Loss_D: 1.1773 Loss_G: 4.5200 D(x): 0.8192 D(G(z)): 0.5536 / 0.0183
[3/5][1250/1583] Loss_D: 0.6889 Loss_G: 1.8073 D(x): 0.6909 D(G(z)): 0.2230 / 0.1969
[3/5][1300/1583] Loss_D: 0.9721 Loss_G: 1.0578 D(x): 0.4541 D(G(z)): 0.0570 / 0.4080
[3/5][1350/1583] Loss_D: 0.5301 Loss_G: 2.3562 D(x): 0.7453 D(G(z)): 0.1670 / 0.1222
[3/5][1400/1583] Loss_D: 0.5464 Loss_G: 2.5304 D(x): 0.8018 D(G(z)): 0.2438 / 0.1020
[3/5][1450/1583] Loss_D: 0.5987 Loss_G: 2.2034 D(x): 0.6195 D(G(z)): 0.0601 / 0.1477
[3/5][1500/1583] Loss_D: 1.4470 Loss_G: 4.2791 D(x): 0.9006 D(G(z)): 0.6537 / 0.0221
[3/5][1550/1583] Loss_D: 0.7917 Loss_G: 3.3235 D(x): 0.8287 D(G(z)): 0.4002 / 0.0489
[4/5][0/1583] Loss_D: 0.7682 Loss_G: 1.2445 D(x): 0.5371 D(G(z)): 0.0538 / 0.3386
[4/5][50/1583] Loss_D: 0.9274 Loss_G: 0.9439 D(x): 0.4905 D(G(z)): 0.1004 / 0.4476
[4/5][100/1583] Loss_D: 0.9571 Loss_G: 0.7391 D(x): 0.4619 D(G(z)): 0.0511 / 0.5431
[4/5][150/1583] Loss_D: 1.4795 Loss_G: 0.7522 D(x): 0.3092 D(G(z)): 0.0387 / 0.5307
[4/5][200/1583] Loss_D: 0.5203 Loss_G: 1.8662 D(x): 0.7279 D(G(z)): 0.1425 / 0.1895
[4/5][250/1583] Loss_D: 0.8140 Loss_G: 1.9120 D(x): 0.5155 D(G(z)): 0.0606 / 0.1939
[4/5][300/1583] Loss_D: 0.5813 Loss_G: 2.5807 D(x): 0.7674 D(G(z)): 0.2255 / 0.1008
[4/5][350/1583] Loss_D: 0.5209 Loss_G: 2.8571 D(x): 0.8125 D(G(z)): 0.2389 / 0.0743
[4/5][400/1583] Loss_D: 0.4505 Loss_G: 2.7965 D(x): 0.8221 D(G(z)): 0.2014 / 0.0805
[4/5][450/1583] Loss_D: 0.4919 Loss_G: 2.4360 D(x): 0.8148 D(G(z)): 0.2163 / 0.1100
[4/5][500/1583] Loss_D: 0.5861 Loss_G: 1.8476 D(x): 0.7139 D(G(z)): 0.1733 / 0.1968
[4/5][550/1583] Loss_D: 0.3823 Loss_G: 2.7134 D(x): 0.8286 D(G(z)): 0.1591 / 0.0833
[4/5][600/1583] Loss_D: 0.8388 Loss_G: 4.0517 D(x): 0.9135 D(G(z)): 0.4704 / 0.0238
[4/5][650/1583] Loss_D: 1.1851 Loss_G: 3.8484 D(x): 0.9364 D(G(z)): 0.6310 / 0.0301
[4/5][700/1583] Loss_D: 0.6797 Loss_G: 1.6355 D(x): 0.6011 D(G(z)): 0.0880 / 0.2444
[4/5][750/1583] Loss_D: 0.6017 Loss_G: 1.8937 D(x): 0.7011 D(G(z)): 0.1684 / 0.1909
[4/5][800/1583] Loss_D: 0.6368 Loss_G: 1.7310 D(x): 0.6652 D(G(z)): 0.1495 / 0.2195
[4/5][850/1583] Loss_D: 0.7758 Loss_G: 0.8409 D(x): 0.5400 D(G(z)): 0.0775 / 0.4691
[4/5][900/1583] Loss_D: 0.5234 Loss_G: 1.7439 D(x): 0.6728 D(G(z)): 0.0839 / 0.2216
[4/5][950/1583] Loss_D: 0.6529 Loss_G: 3.4036 D(x): 0.9078 D(G(z)): 0.3899 / 0.0443
[4/5][1000/1583] Loss_D: 0.6068 Loss_G: 2.1435 D(x): 0.7773 D(G(z)): 0.2603 / 0.1434
[4/5][1050/1583] Loss_D: 0.9208 Loss_G: 2.4387 D(x): 0.7600 D(G(z)): 0.4164 / 0.1163
[4/5][1100/1583] Loss_D: 0.6253 Loss_G: 1.8932 D(x): 0.6321 D(G(z)): 0.0981 / 0.1835
[4/5][1150/1583] Loss_D: 0.6524 Loss_G: 2.7757 D(x): 0.7961 D(G(z)): 0.2996 / 0.0823
[4/5][1200/1583] Loss_D: 0.5320 Loss_G: 2.8334 D(x): 0.8048 D(G(z)): 0.2383 / 0.0781
[4/5][1250/1583] Loss_D: 0.8212 Loss_G: 1.3884 D(x): 0.5531 D(G(z)): 0.1236 / 0.3016
[4/5][1300/1583] Loss_D: 0.4568 Loss_G: 2.6822 D(x): 0.8278 D(G(z)): 0.2067 / 0.0912
[4/5][1350/1583] Loss_D: 0.6665 Loss_G: 1.3834 D(x): 0.6517 D(G(z)): 0.1532 / 0.2904
[4/5][1400/1583] Loss_D: 0.4927 Loss_G: 1.8337 D(x): 0.7101 D(G(z)): 0.1022 / 0.1965
[4/5][1450/1583] Loss_D: 2.2483 Loss_G: 0.2021 D(x): 0.1705 D(G(z)): 0.0452 / 0.8293
[4/5][1500/1583] Loss_D: 0.5997 Loss_G: 2.0054 D(x): 0.6909 D(G(z)): 0.1507 / 0.1733
[4/5][1550/1583] Loss_D: 1.0521 Loss_G: 4.8488 D(x): 0.9193 D(G(z)): 0.5659 / 0.0120

Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D's and G's losses changed during training. Second, we will visualize G's output on the fixed noise batch at the end of every epoch. And third, we will look at a batch of real data next to a batch of fake data produced by G.

Loss versus training iteration

Below is a plot of the generator's and discriminator's losses versus training iterations.
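The plotting code did not survive in this copy of the article; the original tutorial draws these curves roughly as follows (a sketch that assumes the G_losses and D_losses lists filled in by the training loop; short dummy lists stand in here so the snippet runs on its own):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt

# Stand-ins for the lists populated by the training loop
G_losses = [5.0, 4.2, 3.1, 2.7, 2.5]
D_losses = [1.8, 0.9, 0.7, 0.6, 0.6]

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("loss_curves.png")  # in a notebook, use plt.show() instead
```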

Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch during training. Now, we can visualize G's training progression with an animation. Press the play button to start the animation. (Note: to see the animation, run the code in a Jupyter Notebook, since HTML(ani.to_jshtml()) renders the animation inline there.)
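The animation code was also lost in this copy; the original tutorial builds it roughly as follows (a sketch assuming the img_list of image grids saved during training; random arrays stand in so the snippet is self-contained):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Stand-ins for the CxHxW image grids saved in img_list during training
img_list = [np.random.rand(3, 64, 64) for _ in range(3)]

fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
html = ani.to_jshtml()  # in a notebook: from IPython.display import HTML; HTML(html)
```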

Real Images vs. Fake Images

Finally, let's look at some real images and fake images side by side (real on the left, fake on the right).

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how much better the results get
  • Modify this model to take a different dataset, possibly with a different image size or model architecture
  • Check out some other cool GAN projects here
  • Create GANs that generate music