赞
踩
生成式对抗网络(GAN)的框架主要包含两个模型:第一个是生成模型(Generator),记为 G,用来生成数据,它通过学习大量样本,生成能够以假乱真的数据样本;第二个是辨别模型(Discriminator),记为 D,它接收 G 生成的样本数据和真实样本数据,并对二者进行辨别和分类。生成网络 G 接受一个随机噪声 z 并生成图片,记为 G(z);判别网络 D 的作用是判别一张图片 x 是否真实,对于输入 x,D(x) 是 x 为真实图片的概率。G 和 D 相互博弈,通过学习,G 的生成能力和 D 的辨别能力逐渐增强,直到收敛。
先随机生成一个服从某种随机分布的噪声 z,生成器 G 再通过一个复杂的映射关系把它变换成假样本 G(z)
辨别器对于真实样本和假的样本,输出一个0到1之间的值,越大就越有可能是真实样本
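总的目标函数:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$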
代码
- # encoding: utf-8
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- import os
- from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TF tutorial helper, using 'MNIST_data/' as its data directory.
mnist = input_data.read_data_sets('MNIST_data/')

# Hyper-parameters and output location.
batch_size = 100   # number of images (and noise vectors) per training step
z_dim = 100        # length of each noise vector fed to the generator
OUTPUT_DIR = 'samples'   # preview images are written here during training
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# Graph inputs: real 28x28x1 images, generator noise, and the flag that is
# forwarded to every batch-norm layer as `is_training`.
X = tf.placeholder(tf.float32, [None, 28, 28, 1], name='X')
Noise = tf.placeholder(tf.float32, [None, z_dim], name='Noise')
is_training = tf.placeholder(tf.bool, name='is_training')
-
def relu(x, leak=0.2):
    """Leaky ReLU: return the element-wise maximum of ``x`` and ``leak * x``.

    Args:
        x: input tensor.
        leak: multiplier applied to ``x`` for the second operand of the maximum.

    Returns:
        Tensor ``tf.maximum(x, leak * x)``.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
-
def sigmoid_cross_entropy_with_logits(x, y):
    """Thin positional wrapper around ``tf.nn.sigmoid_cross_entropy_with_logits``.

    Args:
        x: logits tensor.
        y: labels tensor.

    Returns:
        The loss tensor produced by the TensorFlow op.
    """
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=x)
    return loss
-
# Discriminator
def discriminator(image, reuse=None, is_training=is_training):
    """Build the convolutional discriminator and score a batch of images.

    The network is four 5x5, stride-2, 'same'-padded convolutions with
    64, 128, 256 and 512 filters, each followed by the leaky ``relu``
    defined in this file. Batch normalisation is applied to every
    convolution except the first. The result is flattened and mapped to a
    single logit by a dense layer.

    Args:
        image: image batch tensor fed to the first convolution.
        reuse: forwarded to ``tf.variable_scope('discriminator', ...)``;
            pass ``True`` on a second call so the weights are shared.
        is_training: forwarded to every batch-norm layer; defaults to the
            module-level ``is_training`` placeholder.

    Returns:
        Tuple ``(probability, logits)``: the sigmoid of the logit and the
        raw logit itself.
    """
    bn_decay = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        # First stage has no batch norm.
        net = relu(tf.layers.conv2d(image, filters=64, kernel_size=5,
                                    strides=2, padding='same'))

        # Remaining stages: conv -> batch norm -> leaky ReLU.
        # Layers are created in the same order as the unrolled original so
        # the auto-generated variable names stay the same.
        for n_filters in (128, 256, 512):
            net = tf.layers.conv2d(net, filters=n_filters, kernel_size=5,
                                   strides=2, padding='same')
            net = relu(tf.contrib.layers.batch_norm(
                net, is_training=is_training, decay=bn_decay))

        flat = tf.contrib.layers.flatten(net)
        logits = tf.layers.dense(flat, units=1)
        return tf.nn.sigmoid(logits), logits
-
-
def generator(z, is_training=is_training):
    """Build the transposed-convolution generator that maps noise to images.

    A dense layer projects ``z`` to a 3x3x512 feature map. Three 5x5,
    stride-2, 'same'-padded transposed convolutions (256, 128, 64 filters)
    follow, each with batch norm and ReLU. A final 5x5, stride-1,
    'valid'-padded transposed convolution named ``'g'`` produces one
    channel with a tanh activation.

    Args:
        z: noise batch tensor.
        is_training: forwarded to every batch-norm layer; defaults to the
            module-level ``is_training`` placeholder.

    Returns:
        The generated single-channel image tensor (tanh output).
    """
    bn_decay = 0.8

    with tf.variable_scope('generator', reuse=None):
        side = 3
        net = tf.layers.dense(z, units=side * side * 512)
        net = tf.reshape(net, shape=[-1, side, side, 512])
        net = tf.nn.relu(tf.contrib.layers.batch_norm(
            net, is_training=is_training, decay=bn_decay))

        # Up-sampling stages, created in the same order as the unrolled
        # original so auto-generated variable names are unchanged.
        for n_filters in (256, 128, 64):
            net = tf.layers.conv2d_transpose(net, filters=n_filters, kernel_size=5,
                                             strides=2, padding='same')
            net = tf.nn.relu(tf.contrib.layers.batch_norm(
                net, is_training=is_training, decay=bn_decay))

        # Output layer: tanh activation, one channel.
        return tf.layers.conv2d_transpose(net, filters=1, kernel_size=5, strides=1,
                                          padding='valid', activation=tf.nn.tanh,
                                          name='g')
-
# Generated images, and discriminator outputs for the real and generated batches.
# The second discriminator call reuses the variables created by the first.
g = generator(Noise)

d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)

# Split the trainable variables by the variable scope that created them.
trainable = tf.trainable_variables()
vars_g = [v for v in trainable if v.name.startswith('generator')]
vars_d = [v for v in trainable if v.name.startswith('discriminator')]

# Discriminator: real images are labelled 1, generated images 0.
loss_d_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real_logits,
                                            labels=tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits,
                                            labels=tf.zeros_like(d_fake)))

# Generator: wants its images to be labelled 1 by the discriminator.
loss_g = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits,
                                            labels=tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake

# Make each training op depend on the ops registered in UPDATE_OPS
# (presumably the batch-norm moving-average updates - confirm against the
# tf.contrib.layers.batch_norm documentation).
updates_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(updates_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(
        loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(
        loss_g, var_list=vars_g)
-
def montage(images):
    """Tile a stack of equally sized 2-D images into one square grid image.

    The images are laid out row by row on an ``n x n`` grid, where ``n`` is
    the ceiling of the square root of the number of images. Tiles are
    separated (and the grid is surrounded) by a one-pixel border. Border
    pixels and unused grid cells have the value 0.5.

    Args:
        images: array-like of shape ``(count, height, width)``, for example
            a list or tuple of 2-D arrays or a single 3-D array.

    Returns:
        A 2-D float array of shape
        ``(height * n + n + 1, width * n + n + 1)``.
    """
    # np.asarray accepts lists, tuples and arrays alike; previously only
    # lists were converted, so a tuple of images failed on `.shape`.
    images = np.asarray(images)

    n_images = images.shape[0]
    image_h = images.shape[1]
    image_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(n_images)))

    # Canvas pre-filled with the border / empty-cell value.
    m = np.full((image_h * n_plots + n_plots + 1,
                 image_w * n_plots + n_plots + 1), 0.5)

    for index in range(n_images):
        row, col = divmod(index, n_plots)
        # Each tile is offset by one border pixel per preceding tile, plus one.
        top = 1 + row * (image_h + 1)
        left = 1 + col * (image_w + 1)
        m[top:top + image_h, left:left + image_w] = images[index]
    return m
-
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Fixed noise batch reused for every preview, so successive snapshots show
# the same latent points.
z_samlpes = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)

samples = []                  # montage image of every preview that was rendered
loss = {'d': [], 'g': []}     # per-step loss history, plotted after training
for i in range(30000):
    # Fresh noise is drawn before the data batch, as in the original ordering.
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    # Rescale pixels; presumably from [0, 1] to [-1, 1] to match the
    # generator's tanh output - confirm against the MNIST loader.
    batch = (batch - 0.5) * 2

    feed = {X: batch, Noise: n, is_training: True}

    # Record the losses before this step's updates are applied.
    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict=feed)
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    # One discriminator update followed by two generator updates per step.
    sess.run(optimizer_d, feed_dict=feed)
    for _ in range(2):
        sess.run(optimizer_g, feed_dict=feed)

    if i % 20 != 0:
        continue

    # Every 20 steps: log the losses and render a preview from the fixed noise.
    print(i, d_ls, g_ls)
    gen_imgs = sess.run(g, feed_dict={Noise: z_samlpes, is_training: False})
    gen_imgs = (gen_imgs + 1) / 2          # tanh range [-1, 1] -> [0, 1]
    imgs = list(gen_imgs[:, :, :, 0])      # drop the channel axis
    gen_imgs = montage(imgs)
    plt.axis('off')
    plt.imshow(gen_imgs, cmap='gray')
    plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
    plt.show()
    samples.append(gen_imgs)

# Loss curves for the whole run.
plt.plot(loss['d'], label='discriminator')
plt.plot(loss['g'], label='generator')
plt.legend(loc='upper right')
plt.show()

# Save the trained variables.
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=30000)
结果
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。