当前位置:   article > 正文

图像生成--对抗生成模型_生成对抗模型

生成对抗模型

生成模型概述

对抗生成模型

机器学习中的两大主要问题:

  1. 判别
  2. 生成

判别模型的典型代表即为图像分类任务,即给定一个数据,判定他是哪一类。

判别模型学习到的是一个概率(贝叶斯过程)

而生成模型的区别在于,给定一个数据,将其生成为预期数据。

 

在数学上,生成模型与判别模型的区别在于:

给定观测值x:

  • 判别模型旨在判别得到y的概率

  • 生成模型旨在根据指定的y得到x的概率

生成模型的应用

超分辨率

图像生成(风格迁移)

生成模型原理简要说明

在GoodFellow的论文中,以最大似然估计进行举例。

首先需要说明的问题是:

生成模型的本质,在于从训练数据中学习到数据的分布

学习到了分布之后,给定一个随机的噪声

过程可以简单地理解为:这个噪声中,符合指定分布的内容得到加强,不符合指定分布的内容会被削弱

当在迭代过程中,数据逐渐贴合预期的输入,从而看上去更逼真。

方法分类

基于最大似然估计的数据生成,是生成模型的理论基础。

按照不同的形式和流派,大致可以分成下面的类别。

此处不对研究脉络的具体细节进行探究,只是对原理进行比喻式介绍。

  1. Explicit density: 显性密度。也就是说,我们在这类方法中,需要给出分布模型的具体形式(密度函数),通过各种迭代运算,来得到模型的真实参数。

  2. Implicit density:隐性密度。在这类方法中,不指定数据分布密度函数,而是通过数据分布所满足的条件,用拟合能力比较强的模型来寻找合适的模型和分布参数。

GAN则属于隐式密度方法,不需要指定模型的具体分布密度函数,来得到较好的分布拟合。

拓展:生成模型可以视为一种损失函数

该部分内容会在后续进行进一步展开,此处只做简单介绍。

首先,我们通常会采用显式的函数作为损失函数。

这种方式带来很多便利,但并不一定精确(对特定任务来说)。

我们用对抗生成式的模型对网络进行约束,从而能够不使用显性的函数来约束模型。

对于用于约束的网络,我们将一些必须要满足的条件作为约束目标,从而令约束模型进一步地摸索出更好的约束边界。

生成对抗模型GAN

Generative Adversarial Network,GAN是一种深度学习模型,属于一种无监督学习的方法。

其目的在于,从数据中学习分布,来得到足以以假乱成真的数据。

为了达到这个目的,通常包含两个基本模型:生成器和判别器。(generative model, G)和(discriminative model, D)

判别模型学习“分界面(分解曲线)”

在训练过程中,利用合理的结构和设定,令二者满足纳什均衡,来得到最优解。

GAN原理

GAN的过程,离不开两个关键内容:生成与对抗。

Goodfellow的例子如下:

  1. 一个城市中,有一群小偷(生成器)和一群警察(判别器)。
  2. 小偷的目的在于,想方设法地欺骗警察;
  3. 而警察的目的在于,想方设法地不受欺骗。

这样一来,小偷在不断的欺骗和被识破的过程中不断精进技能,从而掌握了更加不易被识破的欺骗技能;

警察则在被欺骗的过程中,不断提高辨识功能,从而对欺骗的细节做出判断,更加接近本质。

生成器 生成器采用随机输入,尝试输出样本数据。根据输入的样本随机产生一个数据,将其送入鉴别器

鉴别器 鉴别器的任务在于,接受两个输入,分别是生成器的输入和真实数据,判别器的目的在于判断生成器的输入是不是真的。

数学表达

上述过程中,希望判别器能够最大程度地判别出真实数据为真,生成数据为假

而生成器则是能够最大程度地令判别器产生误判

训练过程

两阶段训练:

  1. 固定生成器参数,训练判别器

  2. 固定判别器,训练生成器

GAN模型的训练过程是一个非常复杂的训练过程,早期的GAN训练也非常麻烦。

训练难度之所以大,一个重要的原因在于,难以掌控生成器和判别器的能力。

理解:

如果小偷很厉害,则警察无法从中提升判别能力;

如果警察很厉害,小偷则会被一网打尽,无法提升其“造假能力”

理论上,如果判别器过于强大,生成器则会由于步长太大无法找到全局最优解。

一个简单的例子在于,人类现代科技无法从外星人科技中吸收影响,从而无法引发科技进步。

因此,通常是训练多轮生成器,再训练少轮判别器

通俗来说,GAN训练的过程应当是一个循序渐进,相辅相成的过程。如果一开始,通过载入与训练模型令判别器具有很高的能力,往往会令GAN难以有效收敛。

代码实践

参考

In [1]:

  1. import torch
  2. import torchvision
  3. import torch.nn as nn
  4. from torchvision import transforms
  5. from tqdm import tqdm

In [2]:

  1. transform = transforms.Compose([
  2. transforms.ToTensor(),
  3. transforms.Normalize(.5, .5)
  4. ])
  5. train_data = torchvision.datasets.MNIST('data',
  6. train=True,
  7. transform=transform,
  8. download=True)
  9. dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
  10. Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
  11. Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/9912422 [00:00<?, ?it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/28881 [00:00<?, ?it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/1648877 [00:00<?, ?it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/4542 [00:00<?, ?it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

In [3]:

  1. # generator
  2. class Gen(nn.Module):
  3. def __init__(self):
  4. super(Gen, self).__init__()
  5. self.gen = nn.Sequential(nn.Linear(100, 256),
  6. nn.ReLU(),
  7. nn.Linear(256, 512),
  8. nn.ReLU(),
  9. nn.Linear(512, 28*28),
  10. nn.Tanh())
  11. def forward(self, x):
  12. img = self.gen(x)
  13. img = img.view(-1, 28, 28)
  14. return img

In [4]:

  1. # discriminator
  2. class Dis(nn.Module):
  3. def __init__(self):
  4. super(Dis, self).__init__()
  5. self.dis = nn.Sequential(nn.Linear(28*28, 512),
  6. nn.LeakyReLU(),
  7. nn.Linear(512, 256),
  8. nn.LeakyReLU(),
  9. nn.Linear(256, 1),
  10. nn.Sigmoid())
  11. def forward(self, x):
  12. x = x.view(-1, 28*28)
  13. x = self.dis(x)
  14. return x

In [5]:

  1. gen = Gen().to('cpu')
  2. dis = Dis().to('cpu')
  3. d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
  4. g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
  5. loss_func = torch.nn.BCELoss()

In [6]:

  1. # train
  2. loss_d = []
  3. loss_g = []
  4. for epoch in range(50):
  5. d_epoch_loss = 0
  6. g_epoch_loss = 0
  7. batch_count = len(dataloader)
  8. for i, (img, _) in enumerate(tqdm(dataloader)):
  9. img = img.to('cpu')
  10. size = img.size(0)
  11. random_noise = torch.randn(size, 100, device='cpu')
  12. d_opt.zero_grad()
  13. real_output = dis(img)
  14. d_real_loss = loss_func(real_output,
  15. torch.ones_like(real_output))
  16. d_real_loss.backward()
  17. gen_img = gen(random_noise)
  18. fake_output = dis(gen_img.detach())
  19. d_fake_loss = loss_func(fake_output,
  20. torch.zeros_like(fake_output))
  21. d_fake_loss.backward()
  22. d_loss = d_real_loss + d_fake_loss
  23. d_optim.step()
  24. g_optim.zero_grad()
  25. fake_output = dis(gen_img)
  26. g_loss = loss_func(fake_output,
  27. torch.ones_like(fake_output))
  28. g_loss.backward()
  29. g_optim.step()
  30. torch.save(gen.state_dict(), str(epoch).zfill(2) + ".pth")
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.29it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.55it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.48it/s]
100%|█████████████████████████████████████████| 938/938 [00:22<00:00, 42.15it/s]
 57%|███████████████████████▍                 | 537/938 [00:12<00:09, 41.52it/s]
---------------------------------------------------------------------------

In [7]:

  1. # show result
  2. def result_show(weight, test_input):
  3. gen = Gen().to('cpu')
  4. gen.load_state_dict(torch.load(weight))
  5. gen.eval()
  6. plot_img(gen, test_input)

In [8]:

  1. # plot image
  2. import matplotlib.pyplot as plt
  3. def plot_img(model, _input):
  4. prediction = model(_input).detach().cpu().numpy()
  5. print(prediction.shape)
  6. fig = plt.figure(figsize=(4, 4))
  7. for i in range(16):
  8. plt.subplot(4, 4, i+1)
  9. plt.imshow((prediction[i]+1)/2)
  10. plt.axis('off')
  11. plt.show()

In [9]:

  1. random_noise = torch.randn(size, 100)
  2. import numpy as np
  3. result_show('./00.pth', random_noise)
(64, 28, 28)

GAN模型进阶

GAN模型的本质

学习训练数据的分布,符合训练数据分布的数据,具有较好的可视化效果;

在分布之外的数据,可视化效果较差。

那么GAN模型的根本问题是:

找一个生成模型G,该模型定义了概率分布

给定一个分布z,找到一个G,可以使分布比较相似。

 具体上,从符合z分布中采样多个点,得到了多个x。

进而,从创造一个D,用于引导采样。

需要说明的是,D的loss值与生成数据和真实数据的内容息息相关。

如果说损失越大,则越说明生成的数据和真实数据越接近。

一个直观的例子

李宏毅推荐的例子

GAN的本质:散度

散度定义(divergence):p(x)和q(x)到底有多不一样

性质1: 散度取值在0-1之间,越接近于0,分布越相似。否则分布区别越大。

那么GAN的本质,在于如何度量散度,即如何设定一个合适的函数f,来得到一个良好的分布拟合。

KL散度:描述数据分布之间的相似性

卡方散度:判断两个样本是否符合相同的分布

关于散度和GAN的关系

散度用于评价分布的相似程度。

常用的KL散度,公式为

但KL散度存在不对称性,在basic gan里,用的是JS散度

使用JS散度存在一个比较大的问题,即如果分布相差较远,则会等于一个恒定的值。不利于模型收敛。

因此,可以灵活地调整散度,来适应不同类型的数据。

如何把散度作为优化目标?

散度可以衡量两个分布,那么如何将散度作为他的优化函数呢?

凸共轭

 

红线部分即共轭函数的曲线,可以看出他也是凸函数。

如何求解一个函数的凸函数?

采用极值求导的方式求解。

例如f(x)=xlog⁡x

一般形式的GAN

回到GAN中,有

那么我们的目的就在于:

直观上的感受:

另一种思路 WGAN

有颜色的色块表示把第i行的分布,修改到第j行。(推土机)

运送路径越多,运送的货物越多,则做的功越大。

那么首先定义运送的功

进而,只需要找到运送功最小的那个方案就可以了

注意,这里需要定义D的函数需要满足1-Lipschitz,即

其中,k=1

这样的作用在于,令y的增长不超过x。也就是限制模型不要更新的太快。

否则,如果取消限制,那么就会令D直接爆炸。

求解

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号