当前位置:   article > 正文

Keras实现生成对抗网络(GAN)(生成二维平面上服从某一分布的点)_gan拟合二维高斯分布

gan拟合二维高斯分布
GAN原理

相关数学推导可参考 李宏毅https://www.bilibili.com/video/av36779967/?p=4

通俗的比喻:制造假钞(G)和警察(D)对抗的过程。假钞制造者制造假钞,警察识别假钞

假钞制造者制造出当前警察无法识别的假钞
警察提高识别能力识别出假钞

假钞制造者再次制造出当前警察无法识别的假钞
警察再次识别出假钞


以上两个过程循环,生成器就能生成接近真实的数据,实际上是使生成数据与真实数据的有相同的分布。

数据准备

用GAN生成服从一下分布的点两个区域分别为
x∈(0,5),y∈(5,10)
x∈(10,15),y∈(10,15)
在两个区域内均服从均匀分布,每个区域个500个点
数据分布
`

def data_gen():
    # 区域1
    x1 = np.random.uniform(0, 5, 500)
    y1 = np.random.uniform(5, 10, 500)
    # 区域2 
    x2 = np.random.uniform(10, 15, 500)
    y2 = np.random.uniform(10, 15, 500)
    # 拼接
    data_x = np.concatenate((x1,x2))
    data_y = np.concatenate((y1,y2))
    # 拼接x,y以作为网络输入
    data = np.transpose(np.vstack((data_x,data_y)))
    print(data.shape)
    return data
# 测试代码   
if __name__ == '__main__':
    data = data_gen()
    plt.scatter(data[:,0],data[:,1])
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
生成对抗网络
生成器

生成器D接受一个二维向量(采用服从正态分布的向量),生成坐标(x,y)

import keras
import tensorflow as tf
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
from data_gen import data_gen as dg

G_input = keras.Input(shape=(2,)) # 输入一个二维vector
x = layers.Dense(5,activation='relu')(G_input)
x = layers.Dense(5,activation='relu')(x)
x = layers.Dense(2,activation='tanh')(x) # 输出为一堆二维坐标
G = keras.models.Model(G_input,x)
G.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
判别器

一个普通的全连网络

# 判别器
D_input = keras.Input(shape=(2,))
x = layers.Dense(10,activation='relu')(D_input)
x = layers.Dense(10,activation='relu')(x)
x = layers.Dropout(0, 4)(x)
x = layers.Dense(1,activation='sigmoid')(x)
D = keras.models.Model(D_input,x)
D.compile(loss='binary_crossentropy',optimizer='rmsprop')
D.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
GAN

生成器和判别器的连接,D参数设置为不可训练

D.trainable = False
gan_input = keras.Input(shape=(2,))
gan_output = D(G(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy',optimizer='rmsprop')
gan.summary()
  • 1
  • 2
  • 3
  • 4
  • 5
训练
准备数据
real_point = dg()/15.0 # 简单地归一化数据
fig = plt.figure()
# 限定坐标轴位置
plt.xlim(0,1)
plt.ylim(0,1)
ax = fig.add_subplot(1,1,1)
ax.scatter(real_point[:, 0], real_point[:, 1])
plt.ion()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
训练
epochs = 1000 # 训练次数
for step in range(epochs):
    random_input = np.random.normal(size=(1000,2)) # 输入接受的随机向量
    gen_point = G.predict(random_input) # 生成fake points

    # 拼接数据加入标签用于训练判别器
    combined_point = np.concatenate([real_point,gen_point])
    labels = np.concatenate([np.ones((1000,1)),np.zeros((1000,1))])
  
    labels += 0.05*np.random.random(labels.shape)  # 标签加噪声,似乎绝对0和绝对1都对gan训练不利
    d_loss = D.train_on_batch(combined_point, labels) # 训练判别器

    random_input = np.random.normal(size=(1000, 2)) # 随机向量,用于对抗网络训练输入
    mis_targets = np.ones((1000,1)) # 加标签
    a_loss = gan.train_on_batch(random_input, mis_targets)
    
    # 可视化过程
    if step%10==0:
        print('discriminator loss:', d_loss)
        print('adversarial loss:', a_loss)
        try:
            ax.lines.remove(points[0])
        except Exception:
            pass
        gen_point = G.predict(random_input)
        points = ax.plot(gen_point[:, 0], gen_point[:, 1],'ro')
        plt.pause(0.01)
        
plt.pause(10)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
训练过程结果

在这里插入图片描述

Reference

主要参考书籍:python深度学习 p260 - p263

用惯了matlab,matplotlib画图也很让人头疼。动态刷新plot可参考参考文章
matplotlib动态刷新指定曲线 https://blog.csdn.net/omodao1/article/details/81223240

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/744674
推荐阅读
相关标签
  

闽ICP备14008679号