
Python Study Notes: GAN (gan.py)

gan.py

GAN: a minimal generator/discriminator pair trained on a 2-D eight-Gaussian mixture toy dataset.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: JMS
@file: gan.py
@time: 2023/01/08
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt

h_dim = 400
batchsz = 512
viz = visdom.Visdom()


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)


def data_generator():
    """
    8-Gaussian mixture model: an infinite generator that yields
    one batch of 2-D points per call.
    :return:
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            # small Gaussian noise + (center_x, center_y)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset


def main():
    torch.manual_seed(23)
    np.random.seed(23)
    data_iter = data_generator()
    x = next(data_iter)
    # [b, 2]
    # print(x.shape)
    G = Generator().cuda()
    D = Discriminator().cuda()
    # network structure
    # print(G)
    # print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))
    # core GAN training loop
    for epoch in range(50000):
        # 1. train the discriminator first
        for _ in range(5):
            # 1.1 train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            # maximize predr, i.e. minimize -predr
            lossr = -predr.mean()
            # 1.2 train on fake data
            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # similar to tf.stop_gradient()
            predf = D(xf)
            lossf = predf.mean()
            # aggregate all
            loss_D = lossr + lossf
            # optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
        # 2. train the generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize predf.mean()
        loss_G = -predf.mean()
        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()
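
The script calls generate_image(D, G, xr, epoch), but the function is never defined in the post (it presumably comes from the accompanying course code). Below is a minimal sketch of a plausible implementation that plots the discriminator's output surface as a contour map together with real and generated samples, assuming visdom's viz.matplot for display; the grid resolution N_POINTS and the plotting RANGE are my own choices, not taken from the original:

def generate_image(D, G, xr, epoch):
    """
    Plot D's output over a 2-D grid plus real (xr) and generated samples,
    and push the matplotlib figure to visdom.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()
    # evaluate D on a regular grid over [-RANGE, RANGE]^2
    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # [N_POINTS**2, 2]
    with torch.no_grad():
        disc_map = D(torch.from_numpy(points).cuda()).cpu().numpy()
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=True, fontsize=10)
    # draw real and generated samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda()
        samples = G(z).cpu().numpy()
    xr = xr.cpu().numpy()  # xr arrives as a CUDA tensor in main()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x): %d' % epoch))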

WGAN can alleviate GAN's training problems, in particular training instability.
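
Concretely, the variant below is WGAN-GP (Gulrajani et al., 2017): instead of the original WGAN's weight clipping, it penalizes the critic's gradient norm at points interpolated between real and fake samples, which is exactly what the gradient_penalty function in the script computes. The critic loss is

$$\mathcal{L}_D = \mathbb{E}_{x_f}\big[D(x_f)\big] - \mathbb{E}_{x_r}\big[D(x_r)\big] + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], \qquad \hat{x} = t\,x_r + (1-t)\,x_f,\; t \sim U(0,1)$$

where $\lambda$ is the penalty weight (the WGAN-GP paper uses $\lambda = 10$; the script below uses 0.2).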

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: JMS
@file: wgan.py
@time: 2023/01/09
@desc:
"""
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
import random
from matplotlib import pyplot as plt

# WGAN addresses GAN's training instability
h_dim = 400
batchsz = 512
viz = visdom.Visdom()


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)


def data_generator():
    """
    8-Gaussian mixture model: an infinite generator that yields
    one batch of 2-D points per call.
    :return:
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            # small Gaussian noise + (center_x, center_y)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset


def gradient_penalty(D, xr, xf):
    """
    :param D: the discriminator (critic)
    :param xr: [b, 2] real samples
    :param xf: [b, 2] fake samples
    :return: scalar gradient penalty
    """
    # [b, 1]
    t = torch.rand(batchsz, 1).cuda()
    # [b, 1] => [b, 2]
    t = t.expand_as(xr)
    # interpolation between real and fake samples
    mid = t * xr + (1 - t) * xf
    # set it to require gradient
    mid.requires_grad_()
    pred = D(mid)
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp


def main():
    torch.manual_seed(23)
    np.random.seed(23)
    data_iter = data_generator()
    x = next(data_iter)
    # [b, 2]
    # print(x.shape)
    G = Generator().cuda()
    D = Discriminator().cuda()
    # network structure
    # print(G)
    # print(D)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))
    # core WGAN training loop
    for epoch in range(50000):
        # 1. train the discriminator first
        for _ in range(5):
            # 1.1 train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()
            # [b, 2] => [b]
            predr = D(xr)
            # maximize predr, i.e. minimize -predr
            lossr = -predr.mean()
            # 1.2 train on fake data
            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            xf = G(z).detach()  # similar to tf.stop_gradient()
            predf = D(xf)
            lossf = predf.mean()
            # 1.3 gradient penalty
            gp = gradient_penalty(D, xr, xf.detach())
            # aggregate all; 0.2 is the penalty weight lambda
            # (a hyperparameter; the WGAN-GP paper uses 10)
            loss_D = lossr + lossf + 0.2 * gp
            # optimize
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
        # 2. train the generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize predf.mean()
        loss_G = -predf.mean()
        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()
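
As a side note, both scripts call .cuda() unconditionally and will crash on a machine without an NVIDIA GPU. A device-agnostic variant needs only a few substitutions; a minimal sketch (the device variable is my addition, not part of the original scripts):

# Pick the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

G = Generator().to(device)                          # instead of Generator().cuda()
D = Discriminator().to(device)                      # instead of Discriminator().cuda()
z = torch.randn(batchsz, 2, device=device)          # instead of torch.randn(...).cuda()
xr = torch.from_numpy(next(data_iter)).to(device)   # instead of .cuda()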
