当前位置:   article > 正文

1 李宏毅生成对抗网络学习———GAN的直观了解_gan李宏毅代码

gan李宏毅代码

https://www.bilibili.com/video/av24011528/?p=1
ppt地址:http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS17.html
csdn下载:
https://download.csdn.net/download/qq_35608277/10773478

1 定义

GAN 主要包括了两个部分,即生成器 generator 与判别器 discriminator。生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对接收的图片进行真假判别。

形象解释:

造假币的团伙相当于生成器,他们想通过伪造金钱来骗过银行,使得假币能够正常交易,而银行相当于判别器,需要判断进来的钱是真钱还是假币。因此假币团伙的目的是要造出银行识别不出的假币而骗过银行,银行则是要想办法准确地识别出假币。

love and peace 版本:
小学生学画画,从一年级到五年级,学生画画水平越来越高,老师要求也越来越高。
宿敌关系------佐助和鸣人,进藤光和塔矢亮…

总之是不断相互促进的效果。

回到GAN

对于给定的真实图片(real image),判别器要为其打上标签 1;
对于给定的生成图片(fake image),判别器要为其打上标签 0;(训练D的阶段,D的判别更强,对上一代G能够区分)

对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签 1。(训练G的阶段,对上一代D判别器可以做到欺骗)

那么GAN是如何来做的呢?首先,我们有一个第一代的Generator,然后他产生一些图片,然后我们把这些图片和一些真实的图片丢到第一代的Discriminator里面去学习,让第一代的Discriminator能够真实的分辨生成的图片和真实的图片,然后我们又有了第二代的Generator,第二代的Generator产生的图片,能够骗过第一代的Discriminator,此时,我们在训练第二代的Discriminator,依次类推。

2 对应训练算法

在这里插入图片描述

3 问题及解答

  1. 为什么不能用生成器(NN)或者AE(自编码)直接生成,反而需要借助判别器?

生成器属于bottom-up架构,注重由部件到组成整体,缺乏整体的“大局观”。
以生成图片为例,相对独立的生成像素,对于相邻像素之间的关联性考虑不足。通过增加网络层数可以改善(增加关联),但同等规模的网络,GAN效果更好。

  1. 可不可以用判别器直接生成图像,而不使用生成器?(如某些名嘴,可以指点江山擅长批评,却无法提出创造性建议)

原则上是可以的。通过枚举的方式,找到打分最高的那张图,就是判别器生成的图。
判别器更擅长评估整体,是top-down的架构,比如生成图片中孤立的像素点更为敏感,能够指出其不合理。

实际上,问题的关键就在于如何产生真实的假图。如果只有数据集的真图,训练出来的判别器只会打高分。因为他遇到的所有图都是高分,所以形成所遇即高分的判断。

因此,通过生成器,来去逼近真实的假图分布,近似求解argmaxD(X)问题。

4 code

MNIST数据集比较好找。
结构
在这里插入图片描述

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

#读入数据
mnist = input_data.read_data_sets('./mnist', one_hot=True)#代码和数据集文件夹放在同一目录下

#从正态分布输出随机值
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
	 

#判别模型的输入和参数初始化
X = tf.placeholder(tf.float32, shape=[None, 784])

D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

#生成模型的输入和参数初始化
Z = tf.placeholder(tf.float32, shape=[None, 100])

G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

#随机噪声采样函数
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

#生成模型
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob

#判别模型
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit

#画图函数
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

#喂入数据
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# 计算losses:
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) 
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) 
D_loss = D_loss_real + D_loss_fake

#label是真,意思是新的G可以骗过上一代D,最小化
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) 

#为什么最小化,不应该是最大化吗???解释:使用的交叉熵做评价,就是比较labels和logits的相近程度,越小越好
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)

G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

mb_size = 128
Z_dim = 100


sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('out/'):
    os.makedirs('out/')

i = 0

#开始训练
axis_x=[]
axis_D_loss_curr=[]
axis_G_loss_curr=[]

for it in range(10000):#1000 000
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})##feed 真实图和噪声。
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})###feed噪声,通过G产生生成图

    if it % 1000 == 0:
        axis_x.append(it)
        axis_D_loss_curr.append(D_loss_curr)
        axis_G_loss_curr.append(G_loss_curr)
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
       

plt.plot(axis_x, axis_D_loss_curr, label='D_loss_curr')
plt.plot(axis_x, axis_G_loss_curr, label='G_loss_curr')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码学习可以看:
https://study.163.com/course/courseLearn.htm?courseId=1005703030#/learn/video?lessonId=1052988014&courseId=1005703030

ref
https://www.jianshu.com/p/40feb1aa642a

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/318158
推荐阅读
相关标签
  

闽ICP备14008679号