赞
踩
Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的。
目前感觉Keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接层,卷积层,池化层等等。对于需要对网络本身做创新的实验,Keras可能不是很方便,还是得用tensorflow来搭建。
这篇博客,我想用Keras写一个简单的生成对抗网络。
生成对抗网络的目标是生成手写体数字。
先看看实验的效果:
epoch=1000的时候:
epoch=10000的时候:数字1已经有点像了
epoch=60000,数字1就很清晰了 ,而且其他数字也越来越清晰了
epoch=80000: 生成了5,7 啥的了。
随着训练的加深,生成的数字会越来越真实了。
代码已经开源,项目地址:
https://github.com/jmhIcoding/GAN_MNIST.git
模型原理就不说了,就是使用最基础GAN结构。
模型由一个生成器和一个鉴别器组成。
生成器用于输入噪声,然后生成一个手写体数字图片。
鉴别器用于判断某个输入给它的图片是不是生成器合成的。
生成器的目标是生成让鉴别器判断为非合成的图片。
鉴别器的目标则是以尽量高的正确率分类某种图片是否为合成的。
总的原理就是这些了。
模型的损失函数就是围绕着这两个目标来展开的。
__author__ = 'dk'
# Generator network
import sys
import numpy as np
import keras
from keras import layers
from keras import models
from keras import optimizers
from keras import losses


class Generator:
    """Fully connected generator that maps a noise vector to an image.

    The underlying Keras model is stored in ``self.generator``. It is not
    compiled here; the model summary is printed on construction.
    """

    def __init__(self, height=28, width=28, channel=1, latent_space_dimension=100):
        '''
        :param height: height of the generated image (28 for MNIST)
        :param width: width of the generated image (28 for MNIST)
        :param channel: number of channels of the generated image
                        (1 for grayscale MNIST)
        :param latent_space_dimension: dimension of the input noise vector
        :return:
        '''
        self.latent_space_dimension = latent_space_dimension
        self.height = height
        self.width = width
        self.channel = channel

        self.generator = self.build_model()
        self.generator.summary()

    def build_model(self, block_starting_size=128, num_blocks=4):
        """Build the generator as a stack of Dense blocks.

        :param block_starting_size: number of units of the first Dense layer;
                                    block ``i`` has ``block_starting_size * 2**i`` units
        :param num_blocks: number of Dense blocks before the output layer
        :return: an uncompiled ``Sequential`` model named ``'generator'`` whose
                 output has shape ``(width, height, channel)`` and, because of
                 the final ``tanh``, values in [-1, 1]
        """
        model = models.Sequential(name='generator')
        for i in range(num_blocks):
            if i == 0:
                # The first block is a bare Dense layer that also declares
                # the input shape (the noise vector).
                model.add(layers.Dense(block_starting_size,
                                       input_shape=(self.latent_space_dimension,)))
            else:
                # Every later block doubles the width and is followed by
                # LeakyReLU + BatchNormalization.
                # NOTE(review): the first block has no activation/normalization,
                # which matches the published model summary — confirm this is
                # intended rather than an indentation slip.
                block_size = block_starting_size * (2 ** i)
                model.add(layers.Dense(block_size))
                model.add(layers.LeakyReLU())
                model.add(layers.BatchNormalization(momentum=0.75))

        model.add(layers.Dense(self.height * self.channel * self.width, activation='tanh'))
        # NOTE(review): the reshape target is (width, height, channel); this is
        # only indistinguishable from (height, width, channel) for square images.
        model.add(layers.Reshape((self.width, self.height, self.channel)))
        return model

    def summary(self):
        """Print the summary of the generator model."""
        # Fixed: the model is stored in ``self.generator``; ``self.model`` was
        # never assigned, so this method used to raise AttributeError.
        self.generator.summary()

    def save_model(self):
        """Save the generator model to ``generator.h5`` in the working directory."""
        self.generator.save("generator.h5")
注意:生成器(generator)是作为整个 GAN 模型的一部分一起训练的,因此它本身不需要单独 compile。
__author__ = 'dk'
# Discriminator network
import sys
import os
import keras
from keras import layers
from keras import optimizers
from keras import models
from keras import losses


class Discriminator:
    """Fully connected binary classifier over images.

    The compiled Keras model is stored in ``self.discriminator``; its summary
    is printed on construction.
    """

    def __init__(self, height=28, width=28, channel=1):
        '''
        :param height: height of the input image
        :param width: width of the input image
        :param channel: number of channels of the input image
        :return:
        '''
        self.height = height
        self.width = width
        self.channel = channel

        # Build the model exactly once (it used to be built twice, with the
        # first instance thrown away).
        self.discriminator = self.build_model()
        OPTIMIZER = optimizers.Adam()
        self.discriminator.compile(optimizer=OPTIMIZER,
                                   loss=losses.binary_crossentropy,
                                   metrics=['accuracy'])
        self.discriminator.summary()

    def build_model(self):
        """Build the discriminator network.

        :return: an uncompiled ``Sequential`` model named ``'discriminator'``
                 that flattens a ``(width, height, channel)`` input, applies
                 two Dense + LeakyReLU(0.2) layers and ends with a single
                 sigmoid unit
        """
        flat_size = self.height * self.width * self.channel

        model = models.Sequential(name='discriminator')
        model.add(layers.Flatten(input_shape=(self.width, self.height, self.channel)))
        # The input shape is already declared by the Flatten layer above, so
        # the Dense layer does not repeat it.
        model.add(layers.Dense(flat_size))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(flat_size // 2))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def summary(self):
        """Print the summary of the discriminator model and return its result."""
        return self.discriminator.summary()

    def save_model(self):
        """Save the discriminator model to ``discriminator.h5`` in the working directory."""
        self.discriminator.save("discriminator.h5")
把生成器和鉴别器合并起来
__author__ = 'dk'
# Generative adversarial network: generator and discriminator stacked together
import keras
from keras import layers
from keras import optimizers
from keras import losses
from keras import models
import sys
import os
from Discriminator import Discriminator
from Generator import Generator


class GAN:
    """Combined model that chains the generator into the discriminator.

    ``self.generator`` and ``self.discriminator`` are the wrapper objects from
    the ``Generator`` and ``Discriminator`` modules; ``self.gan`` is the
    compiled stacked Keras model.
    """

    def __init__(self, latent_space_dimension, height, width, channel):
        """Create both sub-networks, freeze the discriminator and compile the stack.

        :param latent_space_dimension: dimension of the generator's noise input
        :param height: image height
        :param width: image width
        :param channel: number of image channels
        """
        self.generator = Generator(height, width, channel, latent_space_dimension)
        self.discriminator = Discriminator(height, width, channel)

        # Within the combined model only the generator is trained; the
        # discriminator is trained through explicit
        # discriminator.train_on_batch calls instead.
        self.discriminator.discriminator.trainable = False

        self.gan = self.build_model()
        self.gan.compile(optimizer=optimizers.Adamax(),
                         loss=losses.binary_crossentropy)
        self.gan.summary()

    def build_model(self):
        """Return a ``Sequential`` model named ``'gan'``: generator, then discriminator."""
        stacked = models.Sequential(name='gan')
        for sub_model in (self.generator.generator, self.discriminator.discriminator):
            stacked.add(sub_model)
        return stacked

    def summary(self):
        """Print the summary of the combined model."""
        self.gan.summary()

    def save_model(self):
        """Save the combined model to ``gan.h5`` in the working directory."""
        self.gan.save("gan.h5")
__author__ = 'dk'
# Dataset helper: a thin wrapper around the MNIST training images
from keras.datasets import mnist
import numpy as np


def sample_latent_space(instances_number, latent_space_dimension):
    """Draw standard-normal noise vectors.

    :param instances_number: number of noise vectors to draw
    :param latent_space_dimension: dimension of each noise vector
    :return: array of shape ``(instances_number, latent_space_dimension)``
             sampled from N(0, 1)
    """
    return np.random.normal(0, 1, (instances_number, latent_space_dimension))


class Dator:
    """Cyclic batch provider over the training images stored in an ``.npz`` file."""

    def __init__(self, batch_size=None, model_type=1, data_path="mnist.npz"):
        '''
        :param batch_size: default batch size for ``next_batch``; when None,
                           the number of selected training samples is used
        :param model_type: -1 selects all digits 0-9; any other value selects
                           only the samples whose label equals it
                           (e.g. model_type=2 keeps only the digit 2)
        :param data_path: path of the ``.npz`` archive holding the ``x_train``
                          and ``y_train`` arrays
        :return:
        '''
        self.model_type = model_type

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point data_path at a trusted file.
        with np.load(data_path, allow_pickle=True) as f:
            X_train, y_train = f['x_train'], f['y_train']

        if model_type != -1:
            X_train = X_train[y_train == model_type]

        self.batch_size = X_train.shape[0] if batch_size is None else batch_size

        # Scale pixel values by (x - 128) / 128 and append a trailing
        # channel axis.
        self.X_train = (np.float32(X_train) - 128) / 128.0
        self.X_train = np.expand_dims(self.X_train, 3)

        self.watch_index = 0  # position of the next sample to hand out
        self.train_size = self.X_train.shape[0]

    def next_batch(self, batch_size=None):
        """Return the next ``batch_size`` samples, wrapping around at the end.

        :param batch_size: number of samples to return; defaults to
                           ``self.batch_size``
        :return: array with exactly ``batch_size`` samples taken cyclically
                 from the training set, starting at the current position
        :raises ValueError: if no training samples were selected
        """
        if batch_size is None:
            batch_size = self.batch_size
        if self.train_size == 0:
            raise ValueError("no training samples available for model_type={0}".format(self.model_type))

        # Modular indices wrap around the end of the data set, so the batch
        # always has the requested length, even when batch_size exceeds the
        # number of available samples.
        indices = (self.watch_index + np.arange(batch_size)) % self.train_size
        X = self.X_train[indices]
        self.watch_index = (self.watch_index + batch_size) % self.train_size
        return X


if __name__ == '__main__':
    print(sample_latent_space(5, 4))
__author__ = 'dk'
# Training script: alternates discriminator and generator updates and
# periodically saves a grid of generated samples.
from GAN import GAN
from data_utils import Dator, sample_latent_space
import numpy as np
from matplotlib import pyplot as plt
import os
import time

epochs = 50000
height = 28
width = 28
channel = 1
latent_space_dimension = 100
batch = 128

dator = Dator(batch_size=batch, model_type=-1)
gan = GAN(latent_space_dimension, height, width, channel)

# The sample grids are written here; create the directory up front so that
# plt.savefig does not fail on a fresh checkout.
os.makedirs("./show_time", exist_ok=True)

for i in range(epochs):
    # ---- discriminator step ----
    real_img = dator.next_batch(batch_size=batch * 2)
    real_label = np.ones(shape=(real_img.shape[0], 1))  # real samples are labelled 1

    noise = sample_latent_space(real_img.shape[0], latent_space_dimension)
    fake_img = gan.generator.generator.predict(noise)
    fake_label = np.zeros(shape=(fake_img.shape[0], 1))  # generated (fake) images are labelled 0

    # Assemble the batch fed to the discriminator.
    x_batch = np.concatenate([real_img, fake_img])
    y_batch = np.concatenate([real_label, fake_label])

    # One update of the discriminator only; the generator is untouched here.
    # train_on_batch returns [loss, accuracy]; keep the loss.
    discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch, y_batch)[0]

    # ---- generator step ----
    # Assemble the data used to train the generator through the combined model.
    noise = sample_latent_space(batch * 2, latent_space_dimension)
    noise_labels = np.ones((batch * 2, 1))  # the generator aims for its images to be classified as 1

    generator_loss = gan.gan.train_on_batch(noise, noise_labels)

    print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(i, discriminator_loss, generator_loss))

    if i != 0 and (i % 50) == 0:
        print('show time')
        # Every 50 iterations, render 16 generated images to inspect progress.
        noise = sample_latent_space(16, latent_space_dimension)
        images = gan.generator.generator.predict(noise)
        plt.figure(figsize=(10, 10))
        plt.suptitle('epoch={0}'.format(i), fontsize=16)
        for index in range(images.shape[0]):
            plt.subplot(4, 4, index + 1)
            image = images[index, :, :, :]
            image = image.reshape(height, width)
            plt.imshow(image, cmap='gray')
        #plt.tight_layout()
        plt.savefig("./show_time/{0}.png".format(time.time()))
        plt.close()
python3 train.py
即可。
输出:
Model: "generator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 128) 12928 _________________________________________________________________ dense_2 (Dense) (None, 256) 33024 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 256) 0 _________________________________________________________________ batch_normalization_1 (Batch (None, 256) 1024 _________________________________________________________________ dense_3 (Dense) (None, 512) 131584 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 512) 0 _________________________________________________________________ batch_normalization_2 (Batch (None, 512) 2048 _________________________________________________________________ dense_4 (Dense) (None, 1024) 525312 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 1024) 0 _________________________________________________________________ batch_normalization_3 (Batch (None, 1024) 4096 _________________________________________________________________ dense_5 (Dense) (None, 784) 803600 _________________________________________________________________ reshape_1 (Reshape) (None, 28, 28, 1) 0 ================================================================= Total params: 1,513,616 Trainable params: 1,510,032 Non-trainable params: 3,584 _________________________________________________________________ Model: "discriminator" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten_2 (Flatten) (None, 784) 0 _________________________________________________________________ dense_9 (Dense) (None, 784) 615440 _________________________________________________________________ 
leaky_re_lu_6 (LeakyReLU) (None, 784) 0 _________________________________________________________________ dense_10 (Dense) (None, 392) 307720 _________________________________________________________________ leaky_re_lu_7 (LeakyReLU) (None, 392) 0 _________________________________________________________________ dense_11 (Dense) (None, 1) 393 ================================================================= Total params: 923,553 Trainable params: 923,553 Non-trainable params: 0 _________________________________________________________________ Model: "gan" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= generator (Sequential) (None, 28, 28, 1) 1513616 _________________________________________________________________ discriminator (Sequential) (None, 1) 923553 ================================================================= Total params: 2,437,169 Trainable params: 1,510,032 Non-trainable params: 927,137 _________________________________________________________________ ···· ··· ·· Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583] Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736] Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086] Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211] Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916] Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456] Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664] Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626] Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501] Epoch : 117763, [Discriminator 
Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378] Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154] Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504] Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018] Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523] Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023] Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632] Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176] Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036] Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504] Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963] Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994] Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209] Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123] Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873] Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048] Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316] Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297] Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574] Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。