赞
踩
相关数学推导可参考 李宏毅https://www.bilibili.com/video/av36779967/?p=4
通俗的比喻:制造假钞(G)和警察(D)对抗的过程。假钞制造者制造假钞,警察识别假钞
假钞制造者制造出当前警察无法识别的假钞
警察提高识别能力识别出假钞
假钞制造者再次制造出当前警察无法识别的假钞
警察再次识别出假钞
…
…
以上两个过程循环,生成器就能生成接近真实的数据,实际上是使生成数据与真实数据的有相同的分布。
用GAN生成服从一下分布的点两个区域分别为
x∈(0,5),y∈(5,10)
x∈(10,15),y∈(10,15)
在两个区域内均服从均匀分布,每个区域个500个点
`
def data_gen(): # 区域1 x1 = np.random.uniform(0, 5, 500) y1 = np.random.uniform(5, 10, 500) # 区域2 x2 = np.random.uniform(10, 15, 500) y2 = np.random.uniform(10, 15, 500) # 拼接 data_x = np.concatenate((x1,x2)) data_y = np.concatenate((y1,y2)) # 拼接x,y以作为网络输入 data = np.transpose(np.vstack((data_x,data_y))) print(data.shape) return data # 测试代码 if __name__ == '__main__': data = data_gen() plt.scatter(data[:,0],data[:,1]) plt.show()
生成器D接受一个二维向量(采用服从正态分布的向量),生成坐标(x,y)
import keras
import tensorflow as tf
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
from data_gen import data_gen as dg
G_input = keras.Input(shape=(2,)) # 输入一个二维vector
x = layers.Dense(5,activation='relu')(G_input)
x = layers.Dense(5,activation='relu')(x)
x = layers.Dense(2,activation='tanh')(x) # 输出为一堆二维坐标
G = keras.models.Model(G_input,x)
G.summary()
一个普通的全连网络
# 判别器
D_input = keras.Input(shape=(2,))
x = layers.Dense(10,activation='relu')(D_input)
x = layers.Dense(10,activation='relu')(x)
x = layers.Dropout(0, 4)(x)
x = layers.Dense(1,activation='sigmoid')(x)
D = keras.models.Model(D_input,x)
D.compile(loss='binary_crossentropy',optimizer='rmsprop')
D.summary()
生成器和判别器的连接,D参数设置为不可训练
D.trainable = False
gan_input = keras.Input(shape=(2,))
gan_output = D(G(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy',optimizer='rmsprop')
gan.summary()
real_point = dg()/15.0 # 简单地归一化数据
fig = plt.figure()
# 限定坐标轴位置
plt.xlim(0,1)
plt.ylim(0,1)
ax = fig.add_subplot(1,1,1)
ax.scatter(real_point[:, 0], real_point[:, 1])
plt.ion()
epochs = 1000 # 训练次数 for step in range(epochs): random_input = np.random.normal(size=(1000,2)) # 输入接受的随机向量 gen_point = G.predict(random_input) # 生成fake points # 拼接数据加入标签用于训练判别器 combined_point = np.concatenate([real_point,gen_point]) labels = np.concatenate([np.ones((1000,1)),np.zeros((1000,1))]) labels += 0.05*np.random.random(labels.shape) # 标签加噪声,似乎绝对0和绝对1都对gan训练不利 d_loss = D.train_on_batch(combined_point, labels) # 训练判别器 random_input = np.random.normal(size=(1000, 2)) # 随机向量,用于对抗网络训练输入 mis_targets = np.ones((1000,1)) # 加标签 a_loss = gan.train_on_batch(random_input, mis_targets) # 可视化过程 if step%10==0: print('discriminator loss:', d_loss) print('adversarial loss:', a_loss) try: ax.lines.remove(points[0]) except Exception: pass gen_point = G.predict(random_input) points = ax.plot(gen_point[:, 0], gen_point[:, 1],'ro') plt.pause(0.01) plt.pause(10)
主要参考书籍:python深度学习 p260 - p263
用惯了matlab,matplotlib画图也很让人头疼。动态刷新plot可参考参考文章
matplotlib动态刷新指定曲线 https://blog.csdn.net/omodao1/article/details/81223240
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。