当前位置:   article > 正文

5分钟了解Ai 之 对抗学习 图像生成 (tensorflow2实现gan网络)_生成对抗网络tensorflow代码

生成对抗网络tensorflow代码

对抗学习 GAN

对抗学习,指的是同时训练两个模型,分别为用于捕获数据分布的生成模型和用于判别数据是真实数据还是生成数据的判别模型。
两个模型通过对抗性过程同时训练,生成模型学会创建看起来真实的图像,而判别模型学会区分真实图像和赝品。
在训练过程中,生成器逐渐变得更擅长创建看起来真实的图像,而鉴别器则变得更擅长区分它们。
当鉴别器无法分辨真伪图像时,该学习网络达到平衡。

上代码

开发环境:python3.9
插件包:tensorflow2、opencv、matplotlib、numpy

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import sys

# 使用手写字体或单品样本做训练  这里注意的是 我们只需要训练数据,不需要答案和测试数据集。
(train_images, _), (_, _) = keras.datasets.mnist.load_data()
# (train_images, _), (_, _) = keras.datasets.fashion_mnist.load_data()

# 因为卷积层的需求,增加色深维度
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# 规范化为-1 - +1
train_images = (train_images - 127.5) / 127.5

BUFFER_SIZE = 60000  # 以供60000个样本
BATCH_SIZE = 256  # 256张为一组
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 图片生成模型
def make_generator_model():  # 根据长度为100的随机数组,生成一张28,28,1的矩阵
    model = tf.keras.Sequential()
    # 全联接层,输入纬度为[[100],[n]],  输出为7*7*256 = 12544的节点  use_bias=False不使用偏差
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    # BatchNormalization层:该层在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1
    # 该层作用:(1)加速收敛(2)控制过拟合,可以少用或不用Dropout和正则(3)降低网络对初始化权重不敏感(4)允许使用较大的学习率
    model.add(layers.BatchNormalization())
    # ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率(负数)
    model.add(layers.LeakyReLU())
    # 将平铺的节点转为7*7*256的shape
    model.add(layers.Reshape((7, 7, 256)))
    # 验证图形。非则断点
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size
    # 通俗的讲这个解卷积,也就做反卷积,也叫做转置卷积(最贴切),我们就叫做反卷积吧,它的目的就是卷积的反向操作
    # 个人理解,正常的卷积是提取卷积核特征,反卷积就是用卷积核反向修改图像,风格迁移应该也是这么回事,那么问题来了在这个gan中,卷积特征从哪来?
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)  #
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    #64, (5, 5), strides=(2, 2), 希望得到64个特征核,步长2,2
    #model.output_shape == (None, 14, 14, 64) 输出的节点数64就是上面的特征核,由于padding='same',所以卷积后无变化,
    #14,14 是因为步长 2,2  所以7*2
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)  # 12544
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    #验证上层是否输出一个 28.28.1 总量为784的图像矩阵
    assert model.output_shape == (None, 28, 28, 1)
    return model
generator = make_generator_model()

def make_discriminator_model():  # 原图、生成图辨别网络  这个模型与上一个相反,难道是为了提取特征?
    model = tf.keras.Sequential()
    # 将 28.28.1的图像卷积 输出64个节点
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    # 接着卷积出128个节点
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    #激活函数 为非0的斜率
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    # 平铺 并输出一个数字
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
discriminator = make_discriminator_model()

#这段大概就是个测试
# 随机生成一个向量,用于生成图片
noise = tf.random.normal([1, 100])
# 生成一张,此时模型未经训练,图片为噪点 generator = keras.model
generated_image = generator(noise, training=False)
# plt.imshow(generated_image[0, :, :, 0], cmap='gray')
# 判断结果
decision = discriminator(generated_image)
# 此时的结果应当应当趋近于0,表示为伪造图片
print(decision)


# 交叉熵损失函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 辨别模型损失函数
def discriminator_loss(real_output, fake_output):
    # 样本图希望结果趋近1
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # 自己生成的图希望结果趋近0
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    # 总损失
    total_loss = real_loss + fake_loss
    return total_loss

# 生成模型的损失函数
def generator_loss(fake_output):
    # 生成模型期望最终的结果越来越接近1,也就是真实样本
    return cross_entropy(tf.ones_like(fake_output), fake_output)

#优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 训练结果保存
checkpoint_dir = r'dcgan_training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16

# 初始化16个种子向量,用于生成4x4的图片  seed shape: 16, 100
seed = tf.random.normal([num_examples_to_generate, noise_dim])


# @tf.function表示TensorFlow编译、缓存此函数,用于在训练中快速调用
@tf.function
def train_step(images):  #更新 模型权重数据的核心方法
    # 随机生成一个批次的种子向量 BATCH_SIZE = 256   noise_dim = 100  ,256个长度为100的噪音响亮
    noise = tf.random.normal([BATCH_SIZE, noise_dim]) #noise shape:[256],[100]

    #查看每一次epoch参数更新  这个GradientTape 是每次梯度更新都会调用的,这个取代了model.fit的训练计算
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 生成一个批次的图片
        generated_images = generator(noise, training=True)

        # 辨别一个批次的真实样本
        real_output = discriminator(images, training=True)
        # 辨别一个批次的生成图片
        fake_output = discriminator(generated_images, training=True)

        # 计算两个损失值
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # 根据损失值调整模型的权重参量
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # 计算出的参量应用到模型   梯度修剪,用于改变值, 梯度修剪主要避免训练梯度爆炸和消失问题
    #zIP是个格式转换函数 例如:a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]; zip(*a) = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

#训练
def train(dataset, epochs):  
    for epoch in range(epochs+1):
        start = time.time()

        # 训练
        for image_batch in dataset:
            train_step(image_batch)

        #保存图片
        # 每个训练批次生成一张图片作为阶段成功
        print("=======================================")
        generate_and_save_images(generator, epoch + 1, seed)

        # 保存模型
        # 每20次迭代保存一次训练数据
        # if (epoch + 1) % 20 == 0: # 注销该行每次都保存
        checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

# 生成图片
def generate_and_save_images(model, epoch, test_input):
    # 设置为非训练状态,生成一组图片
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    # 4格x4格拼接
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # 保存为png
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    # plt.show()
    plt.close()

# 遍历所有png图片,汇总为gif动图
def write_gif():
    anim_file = 'dcgan.gif'
    with imageio.get_writer(anim_file, mode='I') as writer:
        filenames = glob.glob('image*.png')
        filenames = sorted(filenames)
        last = -1
        for i, filename in enumerate(filenames):
            frame = 2*(i**0.5)
            if round(frame) > round(last):
                last = frame
            else:
                continue
            image = imageio.imread(filename)
            writer.append_data(image)
        image = imageio.imread(filename)
        writer.append_data(image)

# 生成一张初始状态的4格图片,应当是噪点
generate_and_save_images(generator, 0000, seed)


# 如果使用train参数运行则进入训练模式
TRAIN = True
if len(sys.argv) == 2 and sys.argv[1] == 'train':
    TRAIN = True

if TRAIN:
    # 以训练模式运行,进入训练状态
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) #启用该段,应该可以继续训练
    train(train_dataset, EPOCHS)
    write_gif()
else:
    # 非训练模式,恢复训练数据
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    print("After training:")
    # 显示训练完成后,生成图片的辨别结果
    generated_image = generator(noise, training=False)
    decision = discriminator(generated_image)
    # 结果应当趋近1
    print(decision)
    # 重新生成随机值,生成一组图片保存
    seed = tf.random.normal([num_examples_to_generate, noise_dim])

    generate_and_save_images(generator, 9999, seed)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号