当前位置:   article > 正文

生成对抗网络消除图像模糊(Keras)_生成对抗网络去模糊

生成对抗网络去模糊

2017年,乌克兰天主教大学、布拉格捷克理工大学和解决方案提供商Eleks联手公布了一篇论文,文章标题为《DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks》
这篇文章中,研究人员提出一种基于条件对抗式生成网络和内容损失(content loss)的端对端学习法DeblurGAN,用来去除图像上因为相机和物体相对运动而产生的模糊。


论文地址:https://arxiv.org/abs/1711.07064
pytorch实现: https://github.com/KupynOrest/DeblurGAN
keras实现: https://github.com/RaphaelMeudec/deblur-gan

 

去模糊效果:

 

这里看一下keras的代码实现。

 

生成器网络

keras代码:

  1. def generator_model():
  2. inputs = Input(shape=image_shape)
  3. x = ReflectionPadding2D((3, 3))(inputs)
  4. x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
  5. x = BatchNormalization()(x)
  6. x = Activation('relu')(x)
  7. n_downsampling = 2
  8. for i in range(n_downsampling):
  9. mult = 2**i
  10. x = Conv2D(filters=ngf*mult*2, kernel_size=(3, 3), strides=2, padding='same')(x)
  11. x = BatchNormalization()(x)
  12. x = Activation('relu')(x)
  13. mult = 2**n_downsampling
  14. for i in range(n_blocks_gen):
  15. x = res_block(x, ngf*mult, use_dropout=True)
  16. for i in range(n_downsampling):
  17. mult = 2**(n_downsampling - i)
  18. x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3, 3), strides=2, padding='same')(x)
  19. x = BatchNormalization()(x)
  20. x = Activation('relu')(x)
  21. x = ReflectionPadding2D((3, 3))(x)
  22. x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)
  23. x = Activation('tanh')(x)
  24. x = Lambda(lambda z: z*2)(x)
  25. outputs = Add()([x, inputs])
  26. outputs = Lambda(lambda z: z/3)(outputs)
  27. model = Model(inputs=inputs, outputs=outputs, name='Generator')
  28. return model

1. 对输入图像做一个边界扩展(宽高各6个像素)
2. 卷积核大小7×7的卷积,方式是valid,之后执行批规范化BN操作,再执行 Relu激活函数
3. 两次下采样操作,每次特征图大小缩小为之前的二分之一,具体操作包括 same 卷积,BN和Relu激活
4. 9个残差模块,每个模块的操作包括边界扩充、卷积、BN、Relu激活、扩充、卷积、BN,其中dropout可选,接着残差模块的是2组卷积、BN和激活操作。
5. 最后是边界扩充、卷积、Tanh激活、Add输入操作,输出结果是一个维度大小跟输入一致的图片。

生成器结构图:

 

判别器网络

keras代码:

  1. def discriminator_model():
  2. """Build discriminator architecture."""
  3. n_layers, use_sigmoid = 3, False
  4. inputs = Input(shape=input_shape_discriminator)
  5. x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)
  6. x = LeakyReLU(0.2)(x)
  7. nf_mult, nf_mult_prev = 1, 1
  8. for n in range(n_layers):
  9. nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
  10. x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
  11. x = BatchNormalization()(x)
  12. x = LeakyReLU(0.2)(x)
  13. nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
  14. x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=1, padding='same')(x)
  15. x = BatchNormalization()(x)
  16. x = LeakyReLU(0.2)(x)
  17. x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)
  18. if use_sigmoid:
  19. x = Activation('sigmoid')(x)
  20. x = Flatten()(x)
  21. x = Dense(1024, activation='tanh')(x)
  22. x = Dense(1, activation='sigmoid')(x)
  23. model = Model(inputs=inputs, outputs=x, name='Discriminator')
  24. return model

1. 卷积+LeakyRelu操作,LeakyRelu第三象限的斜率设为0.2
2. 4组 卷积+BN+LeakyRelu
3. sigmoid激活可选
4. Flatten展平,为全连接层做准备
5. 2个Dense全连接层,最后输出做sigmoid,限制结果到0~1

判别器结构图:

 

损失函数

感知loss(生成器)
使用的是VGG16分别提取生成图片和真实图片的特征,比较的是block3_conv3层的输出,loss是特征差的平方再取均值

  1. def perceptual_loss(y_true, y_pred):
  2. vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
  3. loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
  4. loss_model.trainable = False
  5. return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))


 Wasserstein 损失
对整个模型(G+D)的输出执行的 Wasserstein 损失,它取的是两个图像差异的均值。这种损失函数可以改善生成对抗网络的收敛性。

  1. def wasserstein_loss(y_true, y_pred):
  2. return K.mean(y_true*y_pred)


训练过程

1. 分批次加载模糊图片和清晰图片数据,并随机排序
2. 每一轮迭代中用本batch_size个训练数据先对判别器执行5次优化,5次优化中每次又会使用清晰图片(标签是1)和模糊图片(标签是0)分别对判别器做一次优化,相当于10次优化。loss函数使用的是wasserstein距离。
3. 关闭判别器参数更新,使判别器不可训练,以下训练生成器,生成器的优化标准有两个,一个是跟清晰图像的差异,一个是迷惑判别器的能力。
4. 生成器和判别器的联合网络输出是生成器生成的图片+判别器的判别值(0到1),所以联合网络d_on_g的train_on_batch训练函数的第二个参数(该参数应该传入训练数据的真实标签)含有两个值,一个是真实清晰图像,一个是真实图像的标签(为1);  d_on_g损失函数优化的目标是G生成的图像跟清晰图像的差异越来越小(使用VGG16提取特征并比较),并且该生成图像经过判别器后的输出跟清晰图片经过判别器的输出的差异越来越小(使用wasserstein距离)。 文中作者设置这两个loss的比重为100:1。
5. 重复执行以上训练过程

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/744630
推荐阅读
相关标签
  

闽ICP备14008679号