当前位置:   article > 正文

ACGAN 生成自己手写数字数据集_acgan训练自己的数据集

acgan训练自己的数据集


前言

由于有可能使用GAN 网络来做一些数据增强,所以这里复现一下GAN 网络,发现这玩意儿还挺好玩。

一、GAN是什么?

GAN (Generative Adversarial Networks)生成对抗网络,用来生成一下不存在的真实数据。应用场景如下:
1.风格迁移:也就是传说中的AI 画家
2.图像超分辨率重建: 让图像更加清晰
3.生成不存在的真实数据:人脸生成等~
在这里插入图片描述
根据训练时带不带标签,GAN 网络是可分为无监督和半监督式的网络。GAN
网络分为两部分,Generator (生成器,图中G)和 Discriminator (判别器,图中D)…
随机生成的噪声,通过生成器,生成我们想要的数据,然后把这个数据和真实数据一起送入到判别器中判断,如果判别器认为输入的是生成数据,那么久训练判别器,如果判别器把生成的数据认为是真的数据,那么就要训练判别器啦~,生成器与判别器两者之间相互博弈,最后让生成器能够成功的欺骗过判别器,那么就可以使用生成器来生成想要的数据啦。

根据前人经验,生成器中的激活函数一般用relu。判别器中的激活函数一般用LeakyReLU

二、ACGAN

1.ACGAN 网络结构

由于ACCGAN 是带有标签的GAN 如果训练得当,应该可以生成想要的数据。看看它的网络结构:在这里插入图片描述

图中,输入到 生成器中的标签 C 和 Z 是随机生成的,但一般都要符合正态分布,生成器生成的假数据,将和真实数据一起输入到判别器中进行判断,真实数据的label 将和判别器输出的label 做损失计算,另一端的输出,只需要判断真假就好。

2.Generator 生成器实现

代码如下:

    def built_generator(self):
        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation='relu', input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        # model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh'))

        model.summary()

        # -----------------
        # 生成噪声
        # -----------------、
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        # print(Embedding(self.num_classes, self.latent_dim)(label).shape)
        model_input = multiply([noise, label_embedding])

        img = model(model_input)

        return Model([noise, label], img)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

关于生成器中的参数设置,首先是全连接 7x7x128, 由于手写数字 图片大小为28x28,初始大小设为7x7 后续会通过2次上采样,就会变成14x14 再由14x14 变为28x28 ,还原图片的大小。

注意:如果要训练自己的图片数据,记得计算好图片大小和上采样的次数,每次上采样,特征图会扩大到原来的两倍

3.Discriminator 判别器实现

    def built_discriminator(self):
        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/530184
推荐阅读
相关标签
  

闽ICP备14008679号