
GAN | A Simple Code Implementation of a Generative Adversarial Network (GAN) (PyTorch)


The GAN paper appeared in 2014, and traces of GANs run all the way to today's wildly popular generative AI. Let's implement a simple one!

GANs combine computational graphs and game theory in an innovative way: they show that, given enough modeling capacity, two models competing against each other can be trained jointly with plain old backpropagation.

These models play two distinct (literally adversarial) roles. Given some real dataset R, G is the generator, which tries to create fake data that looks just like the real data, while D is the discriminator, which receives data from either the real set or from G and labels the difference. G is like a forgery machine: through round after round of practice, its paintings come to look like the real thing. D is the team of detectives trying to tell them apart. (Except that in this case the forger G never gets to see the original data, only D's verdicts; G is a forger groping in the dark.)
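Formally, the 2014 paper casts this as a two-player minimax game over a value function:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

In the code below, both networks are trained with nn.BCELoss against hard labels (1 = real, 0 = fake). Note that training G against the label 1 minimizes -log D(G(z)), which is the common "non-saturating" variant of the objective rather than a literal minimization of log(1 - D(G(z))).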

Source

GAN Implementation Code

#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

matplotlib_is_available = True
try:
    from matplotlib import pyplot as plt
except ImportError:
    print("Will skip plotting; matplotlib is not available.")
    matplotlib_is_available = False

# Data params
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
# (name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
# (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
# (name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final

def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)

def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            # 1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params

            # 1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # Train G to pretend it's genuine
            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.show()

train()
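As a quick sanity check after training, one might compare the moments of the generated samples to the target N(4, 1.25). A minimal sketch, assuming train() is modified to end with "return G" (the listing above does not return it):

    # Hypothetical check; assumes train() is changed to return the trained G.
    G = train()
    noise = get_generator_input_sampler()(500, 1)  # 500 uniform noise inputs
    samples = extract(G(noise))
    print("Generated mean/std:", stats(samples))   # should drift toward [4, 1.25]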

Code Output

Every print_interval (100) epochs the script prints D's error on real and fake data, G's error, and the mean/std of the real and generated distributions; at the end it plots a histogram of the generated samples.

Personal Summary

GANs from a programming point of view (purely my own understanding; corrections welcome):

• Randomly generate noise vectors as the generator's input (the code uses torch.rand for this; numpy's random.normal supplies the "real" Gaussian data).

• Build a generator network G to produce fake data.

• Build a discriminator network D to judge real vs. fake.

• The two networks are optimized separately during training, each with its own optimizer.

• First train D to tell real from fake: backpropagate D's error on real samples (target 1), backpropagate its error on fake samples (target 0), then step D's optimizer (1 = real, 0 = fake).

• Then train G on top of D: generate fakes, score them with the current D against target 1, and step G's optimizer, as sketched below.
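To see that alternating schedule in isolation, here is a minimal self-contained sketch. It is not the listing above: it feeds raw 1-D samples straight to D, uses ELU activations and batches of 64, all choices made only for brevity:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # Tiny 1-D GAN: real data ~ N(4, 1.25); raw samples go straight to D.
    G = nn.Sequential(nn.Linear(1, 5), nn.ELU(), nn.Linear(5, 1))
    D = nn.Sequential(nn.Linear(1, 5), nn.ELU(), nn.Linear(5, 1), nn.Sigmoid())
    criterion = nn.BCELoss()
    d_opt = optim.SGD(D.parameters(), lr=1e-3, momentum=0.9)
    g_opt = optim.SGD(G.parameters(), lr=1e-3, momentum=0.9)
    ones, zeros = torch.ones(64, 1), torch.zeros(64, 1)

    for step in range(2000):
        # D step: push real -> 1 and fake -> 0; detach() keeps G frozen here.
        d_opt.zero_grad()
        real = torch.randn(64, 1) * 1.25 + 4
        fake = G(torch.rand(64, 1)).detach()
        (criterion(D(real), ones) + criterion(D(fake), zeros)).backward()
        d_opt.step()
        # G step: push D(G(z)) -> 1, i.e. fool the current D.
        g_opt.zero_grad()
        criterion(D(G(torch.rand(64, 1))), ones).backward()
        g_opt.step()

    print("mean/std of 1000 generated samples:",
          G(torch.rand(1000, 1)).mean().item(),
          G(torch.rand(1000, 1)).std().item())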

*Because the data is generated randomly, the results differ on every run.
