赞
踩
由于有可能使用GAN 网络来做一些数据增强,所以这里复现一下GAN 网络,发现这玩意儿还挺好玩。
GAN (Generative Adversarial Networks)生成对抗网络,用来生成一下不存在的真实数据。应用场景如下:
1.风格迁移:也就是传说中的AI 画家
2.图像超分辨率重建: 让图像更加清晰
3.生成不存在的真实数据:人脸生成等~
根据训练时带不带标签,GAN 网络是可分为无监督和半监督式的网络。GAN
网络分为两部分,Generator (生成器,图中G)和 Discriminator (判别器,图中D)…
随机生成的噪声,通过生成器,生成我们想要的数据,然后把这个数据和真实数据一起送入到判别器中判断,如果判别器认为输入的是生成数据,那么久训练判别器,如果判别器把生成的数据认为是真的数据,那么就要训练判别器啦~,生成器与判别器两者之间相互博弈,最后让生成器能够成功的欺骗过判别器,那么就可以使用生成器来生成想要的数据啦。
根据前人经验,生成器中的激活函数一般用relu。判别器中的激活函数一般用LeakyReLU
由于ACCGAN 是带有标签的GAN 如果训练得当,应该可以生成想要的数据。看看它的网络结构:
图中,输入到 生成器中的标签 C 和 Z 是随机生成的,但一般都要符合正态分布,生成器生成的假数据,将和真实数据一起输入到判别器中进行判断,真实数据的label 将和判别器输出的label 做损失计算,另一端的输出,只需要判断真假就好。
代码如下:
def built_generator(self): model = Sequential() model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim)) model.add(Reshape((7, 7, 128))) model.add(BatchNormalization(momentum=0.8)) model.add(UpSampling2D()) model.add(Conv2D(128, kernel_size=3, padding='same', activation='relu')) model.add(BatchNormalization(momentum=0.8)) model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu')) model.add(BatchNormalization(momentum=0.8)) # model.add(UpSampling2D()) model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu')) model.add(BatchNormalization(momentum=0.8)) model.add(Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh')) model.summary() # ----------------- # 生成噪声 # -----------------、 noise = Input(shape=(self.latent_dim,)) label = Input(shape=(1,), dtype='int32') label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label)) # print(Embedding(self.num_classes, self.latent_dim)(label).shape) model_input = multiply([noise, label_embedding]) img = model(model_input) return Model([noise, label], img)
关于生成器中的参数设置,首先是全连接 7x7x128, 由于手写数字 图片大小为28x28,初始大小设为7x7 后续会通过2次上采样,就会变成14x14 再由14x14 变为28x28 ,还原图片的大小。
注意:如果要训练自己的图片数据,记得计算好图片大小和上采样的次数,每次上采样,特征图会扩大到原来的两倍
def built_discriminator(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。