当前位置:   article > 正文

深度卷积对抗生成网络(DCGAN)matlab实战_dcganmatlab

dcganmatlab

 一、原理

深度卷积对抗生成网络 (DCGAN)将GAN与CNN相结合,奠定后几乎所有GAN的基本网络架构。DCGAN极大地提升了原始GAN训练的稳定性以及生成结果质量。

DCGAN网络设计中采用了当时对CNN比较流行的改进方案:

1、将空间池化层用卷积层替代,这种替代只需要将卷积的步长stride设置为大于1的数值。改进的意义是下采样过程不再是固定的抛弃某些位置的像素值,而是可以让网络自己去学习下采样方式。

2、将全连接层去除

3、采用BN层,BN的全称是Batch Normalization,是一种用于常用于卷积层后面的归一化方法,起到帮助网络的收敛等作用。作者实验中发现对所有的层都使用BN会造成采样的震荡(我也不理解什么是采样的震荡,我猜是生成图像趋于同样的模式或者生成图像质量忽高忽低)和网络不稳定。

4、在生成器中除输出层使用Tanh(Sigmoid)激活函数,其余层全部使用ReLu激活函数。

5、在判别器所有层都使用LeakyReLU激活函数,防止梯度稀。

下面是DCGAN的生成器网络架构图。

 二、代码实战

  1. clear all; close all; clc;
  2. %% Deep Convolutional Generative Adversarial Network
  3. %% Load Data
  4. load('mnistAll.mat')
  5. trainX = preprocess(mnist.train_images);
  6. trainY = mnist.train_labels;
  7. testX = preprocess(mnist.test_images);
  8. testY = mnist.test_labels;
  9. %% Settings
  10. settings.latentDim = 100;
  11. settings.batch_size = 32; settings.image_size = [28,28,1];
  12. settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
  13. settings.beta2 = 0.999; settings.maxepochs = 50;
  14. %% Generator
  15. paramsGen.FCW1 = dlarray(initializeGaussian([128*7*7,...
  16. settings.latentDim]));
  17. paramsGen.FCb1 = dlarray(zeros(128*7*7,1,'single'));
  18. paramsGen.TCW1 = dlarray(initializeGaussian([3,3,128,128]));
  19. paramsGen.TCb1 = dlarray(zeros(128,1,'single'));
  20. paramsGen.BNo1 = dlarray(zeros(128,1,'single'));
  21. paramsGen.BNs1 = dlarray(ones(128,1,'single'));
  22. paramsGen.TCW2 = dlarray(initializeGaussian([3,3,64,128]));
  23. paramsGen.TCb2 = dlarray(zeros(64,1,'single'));
  24. paramsGen.BNo2 = dlarray(zeros(64,1,'single'));
  25. paramsGen.BNs2 = dlarray(ones(64,1,'single'));
  26. paramsGen.CNW1 = dlarray(initializeGaussian([3,3,64,1]));
  27. paramsGen.CNb1 = dlarray(zeros(1,1,'single'));
  28. stGen.BN1 = []; stGen.BN2 = [];
  29. %% Discriminator
  30. paramsDis.CNW1 = dlarray(initializeGaussian([3,3,1,32]));
  31. paramsDis.CNb1 = dlarray(zeros(32,1,'single'));
  32. paramsDis.CNW2 = dlarray(initializeGaussian([3,3,32,64]));
  33. paramsDis.CNb2 = dlarray(zeros(64,1,'single'));
  34. paramsDis.BNo1 = dlarray(zeros(64,1,'single'));
  35. paramsDis.BNs1 = dlarray(ones(64,1,'single'));
  36. paramsDis.CNW3 = dlarray(initializeGaussian([3,3,64,128]));
  37. paramsDis.CNb3 = dlarray(zeros(128,1,'single'));
  38. paramsDis.BNo2 = dlarray(zeros(128,1,'single'));
  39. paramsDis.BNs2 = dlarray(ones(128,1,'single'));
  40. paramsDis.CNW4 = dlarray(initializeGaussian([3,3,128,256]));
  41. paramsDis.CNb4 = dlarray(zeros(256,1,'single'));
  42. paramsDis.BNo3 = dlarray(zeros(256,1,'single'));
  43. paramsDis.BNs3 = dlarray(ones(256,1,'single'));
  44. paramsDis.FCW1 = dlarray(initializeGaussian([1,256*4*4]));
  45. paramsDis.FCb1 = dlarray(zeros(1,1,'single'));
  46. stDis.BN1 = []; stDis.BN2 = []; stDis.BN3 = [];
  47. % average Gradient and average Gradient squared holders
  48. avgG.Dis = []; avgGS.Dis = []; avgG.Gen = []; avgGS.Gen = [];
  49. %% Train
  50. numIterations = floor(size(trainX,4)/settings.batch_size);
  51. out = false; epoch = 0; global_iter = 0;
  52. %% modelGradients
  53. function [GradGen,GradDis,stGen,stDis]=modelGradients(x,z,paramsGen,...
  54. paramsDis,stGen,stDis)
  55. [fake_images,stGen] = Generator(z,paramsGen,stGen);
  56. d_output_real = Discriminator(x,paramsDis,stDis);
  57. [d_output_fake,stDis] = Discriminator(fake_images,paramsDis,stDis);
  58. % Loss due to true or not
  59. d_loss = -mean(.9*log(d_output_real+eps)+log(1-d_output_fake+eps));
  60. g_loss = -mean(log(d_output_fake+eps));
  61. % For each network, calculate the gradients with respect to the loss.
  62. GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
  63. GradDis = dlgradient(d_loss,paramsDis);
  64. end
  65. %% progressplot
  66. function progressplot(paramsGen,stGen,settings)
  67. r = 5; c = 5;
  68. noise = gpdl(randn([settings.latentDim,r*c]),'CB');
  69. gen_imgs = Generator(noise,paramsGen,stGen);
  70. gen_imgs = reshape(gen_imgs,28,28,[]);
  71. fig = gcf;
  72. if ~isempty(fig.Children)
  73. delete(fig.Children)
  74. end
  75. I = imtile(gatext(gen_imgs));
  76. I = rescale(I);
  77. imagesc(I)
  78. title("Generated Images")
  79. colormap gray
  80. drawnow;
  81. end
  82. %% dropout
  83. function dly = dropout(dlx,p)
  84. if nargin < 2
  85. p = .3;

实验结果

epoch = 5;

epoch = 6

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/94328
推荐阅读
相关标签
  

闽ICP备14008679号