当前位置:   article > 正文

GAN(生成对抗网络)_gan生成对抗网络

gan生成对抗网络

算法流程

G:G是一个生成器,随机噪声输入,图片输出G(z),选择噪声输入的原因是引入这种随机性,带来生成的多样性。

D:D是一个判别器,判断图片是否真实输入为图片,生成二分类0-1(sigmoid激活输出)。

流程:G由设计噪声生成一张图片,判别器接受真实的图片和生成的图片,尽量将两者区分开,将正确辨别真实和生成图片与否作为判别器的损失,生成器的损失是将能否生成近似真实图片而且使得判别器将生成的图片判定为真。

对抗:个人理解是生成器和判别器的对抗,相互促进作用,标签还是存在的(真-假),判别器是根据标签来做为进化的方向,生成器把“欺骗”判别器作为进化的方向,进一步判别器继续根据标签进化,这就使得生成器的欺骗能力越来越强,判别器的判断能力越来越强,防“欺骗”能力越来越强。判别器的输出是一个概率值,可以通过交叉熵来计算

但是这种网络的损失最终会不会收敛是一个问题,不收敛就代表生成器和判别器的功效不确定,但好在Goodfellow给出了证明,证明用不到不附。

 

这里给出GAN的公式

左式是真实数据,右式是生成数据。

对于D而言,左式越大越好,右式越大越好

对于G而言,右式越小越好

应用

生成人脸

图像增强

风格转换

声音转换

实验

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. import torch
  4. import torch.optim as optim
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import torchvision
  8. from torchvision import transforms
  9. #数据加载部分
  10. transform = transforms. Compose([transforms.ToTensor (),
  11. transforms.Normalize(0.5, 0.5)
  12. ])
  13. train_ds = torchvision.datasets.MNIST( "data",
  14. train=True,
  15. transform=transform,
  16. download=True)
  17. dataloader = torch.utils. data.Dataloader(train_ds, batch_size=64, shuffle=True)
  18. #模型部分
  19. #生成器部分
  20. class Generator(nn.Module) :
  21. def __init__(self):
  22. super(Generator, self).__init__()
  23. self.main = nn. Sequential(
  24. nn.Linear(100,256),
  25. nn. ReLU () ,
  26. nn.Linear(256,512),
  27. nn. ReLU () ,
  28. nn.Linear(512,28*28),
  29. nn. Tanh ()
  30. )
  31. def forward(self,x):
  32. img = self.main(x)
  33. img = img.view(-1,28,28,1)
  34. return img
  35. #判别器部分
  36. class Discriminator(nn.Module):
  37. def __init__(self):
  38. super(Discriminator, self).init()
  39. self.main = nn. Sequential(
  40. nn.Linear(28*28,512),
  41. nn. LeakyReLU(),
  42. nn. Linear(512,256),
  43. nn. LeakyReLU(),
  44. nn. Linear(256,1),
  45. nn.Sigmoid()
  46. )
  47. def forward(self, x):
  48. x = x.view(-1,28 * 28)
  49. x = self.main(x)
  50. return x
  51. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  52. gen = Generator().to(device)
  53. dis = Discriminator().to(device)
  54. d_optim = torch.optim.Adam(dis.parameters (),lr=0.0001)
  55. g_optim = torch.optim.Adam(gen.parameters (),lr=0.0001)
  56. loss_fn = torch.nn.BCEWithLogitsLoss()
  57. def gen_img_plot(model, epoch, test_input):
  58. prediction = np.squeeze (model(test_input). detach ().cpu ().numpy ())
  59. fig = plt.figure(figsize=(4,4))
  60. for i in range(16):
  61. plt.subplot(4,4, i+1)
  62. plt.imshow((prediction[i] + 1)/2)
  63. plt.axis(' off')
  64. plt.show()
  65. D_loss = []
  66. G_loss = []
  67. for epoch in range(20) :
  68. d_epoch_loss = 0
  69. g_epoch_loss = 0
  70. count = len(dataloader)
  71. for step, (img,_) in enumerate(dataloader):
  72. img = img.to(device)
  73. size = img.size(0)
  74. random_noise = torch.randn(size, 100, device-device)
  75. d_optim.zero_grad()
  76. # real output对真实图片的预测
  77. real_output = dis(img)
  78. d_real_loss = loss_fn(real_output, torch.ones_like(real_output),device = device)
  79. d_real_loss.backward()
  80. gen_img = gen(random_noise)
  81. fake_output = dis(gen_img.detach())
  82. d_fake_loss = loss_fn(fake_output,
  83. torch.zeros_like((fake_output),)
  84. ,device = device)
  85. d_fake_loss.backward()
  86. d_loss = d_real_loss+d_fake_loss
  87. d_optim.step()
  88. g_optim.zero_grad()
  89. fake_output = dis(gen_img)
  90. g_loss = loss_fn(fake_output,
  91. torch.ones_like(fake_output), # 生成器的损失
  92. device = device)
  93. g_loss.backward()
  94. g_optim.step()
  95. with torch.no_grad():
  96. d_epoch_loss += d_loss
  97. g_epoch_loss += g_loss
  98. with torch.no_grad():
  99. d_epoch_loss /= count
  100. g_epoch_loss /= count
  101. D_loss.append(d_epoch_loss)
  102. G_loss.append(g_epoch_loss)
  103. print('Epoch:',epoch)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/712681
推荐阅读
相关标签
  

闽ICP备14008679号