当前位置:   article > 正文

【人工智能基础】GAN与WGAN实验

【人工智能基础】GAN与WGAN实验

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记作G(z)

Discriminator的功能是判别一张图片是否真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率:输出为1代表图片是真实的,输出为0代表图片不是真实的。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片区分开来

二、GAN实验环境准备

除了之前使用过的pytorch-nlp、numpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

  1. import torch
  2. from torch import nn,optim,autograd
  3. import numpy as np
  4. import visdom
  5. import random
  6. h_dim = 400
  7. batchsz = 512
  8. viz = visdom.Visdom()

生成网络定义

  1. class Generator(nn.Module):
  2. def __init__(self):
  3. super(Generator,self).__init__()
  4. self.net = nn.Sequential(
  5. # input[b, 2]
  6. nn.Linear(2,h_dim),
  7. nn.ReLU(True),
  8. nn.Linear(h_dim, h_dim),
  9. nn.ReLU(True),
  10. nn.Linear(h_dim, h_dim),
  11. nn.ReLU(True),
  12. nn.Linear(h_dim, 2)
  13. # output[b,2]
  14. )
  15. def forward(self, z):
  16. output = self.net(z)
  17. return output

判别网络定义

  1. class Discriminator(nn.Module):
  2. def __init__(self):
  3. super(Discriminator, self).__init__()
  4. self.net = nn.Sequential(
  5. nn.Linear(2, h_dim),
  6. nn.ReLU(True),
  7. nn.Linear(h_dim, h_dim),
  8. nn.ReLU(True),
  9. nn.Linear(h_dim, h_dim),
  10. nn.ReLU(True),
  11. nn.Linear(h_dim, 1),
  12. nn.Sigmoid()
  13. )
  14. def forward(self, x):
  15. output = self.net(x)
  16. return output.view(-1)

数据集生成函数

  1. def data_generator():
  2. # 生成中心点
  3. scale = 2
  4. centers = [
  5. (1, 0),
  6. (-1, 0),
  7. (0, 1),
  8. (0, -1),
  9. (1. / np.sqrt(2), 1. / np.sqrt(2)),
  10. (1. / np.sqrt(2), -1. / np.sqrt(2)),
  11. (-1. / np.sqrt(2), 1. / np.sqrt(2)),
  12. (-1. / np.sqrt(2), -1. / np.sqrt(2))
  13. ]
  14. centers = [(scale * x, scale * y) for x,y in centers]
  15. while True:
  16. dataset = []
  17. for i in range(batchsz):
  18. point = np.random.randn(2) * 0.02
  19. # 随机选取一个中心点
  20. center = random.choice(centers)
  21. # 把刚刚随机到的高斯分布点根据center进行移动
  22. point[0] += center[0]
  23. point[1] += center[1]
  24. dataset.append(point)
  25. dataset = np.array(dataset).astype(np.float32)
  26. dataset /= 1.414
  27. yield dataset

可视化函数

将图片生成到visdom

  1. import matplotlib.pyplot as plt
  2. def generate_image(D, G, xr, epoch):
  3. N_POINTS = 128
  4. RANGE = 3
  5. plt.clf()
  6. points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
  7. points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
  8. points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
  9. points = points.reshape((-1,2))
  10. with torch.no_grad():
  11. points = torch.Tensor(points).cpu()
  12. disc_map = D(points).cpu().numpy()
  13. x = y = np.linspace(-RANGE,RANGE,N_POINTS)
  14. cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())
  15. plt.clabel(cs, inline=1,fontsize=10)
  16. with torch.no_grad():
  17. z = torch.randn(batchsz, 2).cpu()
  18. samples = G(z).cpu().numpy()
  19. plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')
  20. plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')
  21. viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

  1. def run():
  2. torch.manual_seed(23)
  3. np.random.seed(23)
  4. data_iter = data_generator()
  5. x = next(data_iter)
  6. # print(x.shape)
  7. # G = Generator().cuda()
  8. # D = Discriminator().cuda()
  9. # 无显卡环境
  10. device = torch.device("cpu")
  11. G = Generator().cpu()
  12. print(G)
  13. D = Discriminator().cpu()
  14. print(D)
  15. optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
  16. optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))
  17. viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))
  18. """
  19. gan核心部分
  20. """
  21. for epoch in range(50000):
  22. # 训练判别网络
  23. for _ in range(5):
  24. # 真实数据训练
  25. xr = next(data_iter)
  26. xr = torch.from_numpy(xr).cpu()
  27. predr = D(xr)
  28. # 放大真实数据
  29. lossr = -predr.mean()
  30. # 虚假数据训练
  31. z = torch.randn(batchsz,2).cpu()
  32. xf = G(z).detach()
  33. predf = D(xf)
  34. # 缩小虚假数据
  35. lossf = predf.mean()
  36. loss_D = lossr + lossf
  37. # 梯度清零
  38. optim_D.zero_grad()
  39. # 向后传播
  40. loss_D.backward()
  41. optim_D.step()
  42. # 训练生成网络
  43. z = torch.randn(batchsz,2).cpu()
  44. xf = G(z)
  45. predf = D(xf)
  46. loss_G = -predf.mean()
  47. optim_G.zero_grad()
  48. loss_G.backward()
  49. optim_G.step()
  50. if epoch % 100 == 0:
  51. viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
  52. print(loss_D.item(), loss_G.item())
  53. generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进。原始WGAN在每次更新后将判别网络的权重强制截断到一定范围内;本实验采用其改进版本WGAN-GP,用梯度惩罚项代替权重截断。

增加一个梯度惩罚函数

  1. def gradient_penalty(D,xr,xf):
  2. # [b,1]
  3. t = torch.rand(batchsz, 1).cpu()
  4. # 扩展为[b, 2]
  5. t = t.expand_as(xr)
  6. # 插值
  7. mid = t * xr + (1 - t) * xf
  8. # 设置需要的倒数信息
  9. mid.requires_grad_()
  10. pred = D(mid)
  11. grads = autograd.grad(outputs=pred,
  12. inputs=mid,
  13. grad_outputs=torch.ones_like(pred),
  14. create_graph=True,
  15. retain_graph=True,
  16. only_inputs=True)[0]
  17. gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
  18. return gp

修改运行函数

  1. def run():
  2. torch.manual_seed(23)
  3. np.random.seed(23)
  4. data_iter = data_generator()
  5. x = next(data_iter)
  6. # print(x.shape)
  7. # G = Generator().cuda()
  8. # D = Discriminator().cuda()
  9. # 无显卡环境
  10. device = torch.device("cpu")
  11. G = Generator().cpu()
  12. print(G)
  13. D = Discriminator().cpu()
  14. print(D)
  15. optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
  16. optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))
  17. viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))
  18. """
  19. gan核心部分
  20. """
  21. for epoch in range(50000):
  22. # 训练判别网络
  23. for _ in range(5):
  24. # 真实数据训练
  25. xr = next(data_iter)
  26. xr = torch.from_numpy(xr).cpu()
  27. predr = D(xr)
  28. # 放大真实数据
  29. lossr = -predr.mean()
  30. # 虚假数据训练
  31. z = torch.randn(batchsz,2).cpu()
  32. xf = G(z).detach()
  33. predf = D(xf)
  34. # 缩小虚假数据
  35. lossf = predf.mean()
  36. # 梯度惩罚值
  37. gp = gradient_penalty(D,xr,xf.detach())
  38. loss_D = lossr + lossf + 0.2 * gp
  39. # 梯度清零
  40. optim_D.zero_grad()
  41. # 向后传播
  42. loss_D.backward()
  43. optim_D.step()
  44. # 训练生成网络
  45. z = torch.randn(batchsz,2).cpu()
  46. xf = G(z)
  47. predf = D(xf)
  48. loss_G = -predf.mean()
  49. optim_G.zero_grad()
  50. loss_G.backward()
  51. optim_G.step()
  52. if epoch % 100 == 0:
  53. viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
  54. print(loss_D.item(), loss_G.item())
  55. generate_image(D, G, xr, epoch)

执行

run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/560891
推荐阅读
相关标签
  

闽ICP备14008679号