当前位置:   article > 正文

变分自编码器VAE_变分自编码器 matlab

变分自编码器 matlab

1. VAE & GAN

变分自编码器(Variational auto-encoder,VAE)是一类重要的生成模型(generative model)

除了VAEs,还有一类重要的生成模型GANs

VAE 跟 GAN 比较,目标基本是一致的——希望构建一个从隐变量 Z 生成目标数据 X 的模型,但是实现上有所不同。

生成模型的难题就是判断生成分布与真实分布的相似度,因为我们只知道两者的采样结果,不知道它们的分布表达式。 KL 散度是根据两个概率分布的表达式来算它们的相似度的,我们只有样本本身,没有分布表达式,当然也就没有方法算 KL 散度。

GAN 的思路很直接粗犷:既然没有合适的度量,那我干脆把这个度量也用神经网络训练出来吧

与GANs不同的是,VAEs是知道图像的密度函数(PDF)的(或者说,是我们设定的)

2. VAE

2.1 简单引入

观测数据是X,而X由隐变量Z产生,由Z->X是生成模型\theta,就是解码器;

而由x->z是识别模型\phi,类似于自编码器的编码器。

2.2 传统理解

有一批数据样本 {X1,…,Xn},其整体用 X 来描述,如果能得到其分布,那我直接根据 p(X) 来采样,就可以得到所有可能的 X 了,但这是不现实的,因此引入:

p(X|Z) 是一个由 Z 来生成 X的模型,而我们假设 Z 服从标准正态分布,也就是 p(Z)=N(0,I)。如果这个能实现,那么我们就可以先从标准正态分布中采样一个 Z,然后根据 Z 来算一个 X

但观察上图,经过采样出来的Zk,进而生成的Xk不再对应着原来的 Xk,直接最小化 D(X̂ k,Xk)^2是很不科学的,而事实上代码也不是这样实现的

2.3 真正理解

在整个 VAE 模型中,并没有去使用 p(Z)(先验分布)是正态分布的假设,用的是假设 p(Z|X)(后验分布)是正态分布

给定一个真实样本 Xk,假设存在一个专属于 Xk 的分布 p(Z|Xk),服从正态分布;然后生成器X=g(Z),希望能够把从分布 p(Z|Xk) 采样出来的一个 Zk 还原为 Xk

因此,

        有多少个 X 就有多少个正态分布了。参数:均值 μ 和方差 σ^2(多元的话,都是向量)

于是构建两个神经网络 μk=f1(Xk),logσ^2=f2(Xk) 来算它们了。因为 σ^2 总是非负的,需要加激活函数处理,而拟合 logσ^2 不需要加激活函数,因为它可正可负。

但是,如果根据上图训练,模型希望重构 X,也就是最小化 D(X̂k,Xk)^2,但是这个重构过程受到噪声的影响,因为Zk 是通过重新采样过的。不过好在这个噪声强度(也就是方差)通过一个神经网络算出来的,所以最终模型为了重构得更好,肯定会想尽办法让方差为0。

方差为 0 的话,也就没有随机性了,所以采样其实都只是得到确定的结果(也就是均值)

模型会慢慢退化成普通的 AutoEncoder,噪声不再起作用

2.4 进一步理解--->分布标准化

VAE 还让所有的 p(Z|X) 都向标准正态分布看齐,这样就防止了噪声为零

假设  所有的 p(Z|X) 都很接近标准正态分布 N(0,I),那么根据定义:

因此,p(Z) 满足标准正态分布。然后我们就可以放心地从 N(0,I) 中采样来生成图像了。

 2.5 损失

 怎么让所有的 p(Z|X) 都向 N(0,I) 看齐呢?最直接的方法是在重构误差的基础上中加入额外的 loss

因此,将一般(各分量独立的)正态分布与标准正态分布的 KL 散度KL(N(μ,σ^2)‖N(0,I))作为这个额外的 loss,计算结果为:

 2.6 模型实现

我们要从 p(Z|Xk) 中采样一个 Zk 出来,尽管我们知道了 p(Z|Xk) 是正态分布,但是均值方差都是靠模型算出来的,我们要靠这个过程反过来优化均值方差的模型,但是“采样”这个操作是不可导的,而采样的结果是可导的,于是我们利用了一个事实:

 这样一来,“采样”这个操作就不用参与梯度下降了

3. VAE本质

VAE就是在自编码器模型上做进一步变分处理,使得编码器的输出结果能对应到目标分布的均值和方差;因此,它的 Encoder 有两个,一个用来计算均值,一个用来计算方差

本质上就是在常规的自编码器的基础上,对 encoder 的结果(在VAE中对应着计算均值的网络)加上了“高斯噪声”,使得结果 decoder 能够对噪声有鲁棒性;而那个额外的 KL loss(目的是让均值为 0,方差为 1),事实上就是相当于对 encoder 的一个正则项,希望 encoder 出来的东西零均值。

另外一个 encoder(计算方差的网络)是用来动态调节噪声的强度的。当 decoder 还没有训练好时(重构误差远大于 KL loss),就会适当降低噪声(KL loss 增加),使得拟合起来容易一些(重构误差开始下降)。反之,如果 decoder 训练得还不错时(重构误差小于 KL loss),这时候噪声就会增加(KL loss 减少),使得拟合更加困难了(重构误差又开始增加),这时候decoder 就要想办法提高它的生成能力了

重构的过程是希望没噪声的,而 KL loss 则希望有高斯噪声的,两者是对立的。所以,VAE 跟 GAN 一样,内部其实是包含了一个对抗的过程,只不过它们两者是混合起来,共同进化的

4. auto-encoder 和 VAE 对比

Auto-Encoder能够把一个高维的向量(28*28图像)压缩到只有30维,并且解码回的图像具备清楚的辨认度(如下图)。

但是这并没有达到我们真正想要构造的生成模型的标准,因为,对于一个生成模型而言,解码器部分应该是单独能够提取出来的,并且对于在规定维度下任意采样的一个编码,都应该能通过解码器产生一张清晰且真实的图片。

auto-encoder无法达到这一标准的原因:

 

 如上图所示,假设有两张训练图片,经过训练自编码器模型已经能无损地还原这两张图片。接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。一个比较合理的解释是,因为编码和解码的过程使用了深度神经网络,这是一个非线性的变换过程,所以在code空间上点与点之间的迁移是非常没有规律的。

为了解决这个问题,我们可以引入噪声(VAE),使得图片的编码区域得到扩大,从而掩盖掉失真的空白编码点。

如上图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内,于是在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。

 由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点。为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

 

5. pytorch代码

  1. import torch
  2. import torchvision
  3. from torch import nn
  4. from torch import optim
  5. import torch.nn.functional as F
  6. from torch.autograd import Variable
  7. from torch.utils.data import DataLoader
  8. from torchvision import transforms
  9. from torchvision.utils import save_image
  10. from torchvision.datasets import MNIST
  11. import os
  12. import datetime
  13. if not os.path.exists('./vae_img'):
  14. os.mkdir('./vae_img')
  15. def to_img(x):
  16. x = x.clamp(0, 1)
  17. x = x.view(x.size(0), 1, 28, 28)
  18. return x
  19. num_epochs = 100
  20. batch_size = 128
  21. learning_rate = 1e-3
  22. img_transform = transforms.Compose([
  23. transforms.ToTensor()
  24. # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  25. ])
  26. dataset = MNIST('./data', transform=img_transform, download=True)
  27. dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  28. class VAE(nn.Module):
  29. def __init__(self):
  30. super(VAE, self).__init__()
  31. self.fc1 = nn.Linear(784, 400)
  32. self.fc21 = nn.Linear(400, 20)
  33. self.fc22 = nn.Linear(400, 20)
  34. self.fc3 = nn.Linear(20, 400)
  35. self.fc4 = nn.Linear(400, 784)
  36. def encode(self, x):
  37. h1 = F.relu(self.fc1(x))
  38. return self.fc21(h1), self.fc22(h1)
  39. def reparametrize(self, mu, logvar):
  40. std = logvar.mul(0.5).exp_()
  41. if torch.cuda.is_available():
  42. eps = torch.cuda.FloatTensor(std.size()).normal_()
  43. else:
  44. eps = torch.FloatTensor(std.size()).normal_()
  45. eps = Variable(eps)
  46. return eps.mul(std).add_(mu)
  47. def decode(self, z):
  48. h3 = F.relu(self.fc3(z))
  49. # return F.sigmoid(self.fc4(h3))
  50. return torch.sigmoid(self.fc4(h3))
  51. def forward(self, x):
  52. mu, logvar = self.encode(x)
  53. z = self.reparametrize(mu, logvar)
  54. return self.decode(z), mu, logvar
  55. strattime = datetime.datetime.now()
  56. model = VAE()
  57. if torch.cuda.is_available():
  58. # model.cuda()
  59. print('cuda is OK!')
  60. model = model.to('cuda')
  61. else:
  62. print('cuda is NO!')
  63. reconstruction_function = nn.MSELoss(size_average=False)
  64. # reconstruction_function = nn.MSELoss(reduction=sum)
  65. def loss_function(recon_x, x, mu, logvar):
  66. """
  67. recon_x: generating images
  68. x: origin images
  69. mu: latent mean
  70. logvar: latent log variance
  71. """
  72. BCE = reconstruction_function(recon_x, x) # mse loss
  73. # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
  74. KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
  75. KLD = torch.sum(KLD_element).mul_(-0.5)
  76. # KL divergence
  77. return BCE + KLD
  78. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  79. for epoch in range(num_epochs):
  80. model.train()
  81. train_loss = 0
  82. for batch_idx, data in enumerate(dataloader):
  83. img, _ = data
  84. img = img.view(img.size(0), -1)
  85. img = Variable(img)
  86. img = (img.cuda() if torch.cuda.is_available() else img)
  87. optimizer.zero_grad()
  88. recon_batch, mu, logvar = model(img)
  89. loss = loss_function(recon_batch, img, mu, logvar)
  90. loss.backward()
  91. # train_loss += loss.data[0]
  92. train_loss += loss.item()
  93. optimizer.step()
  94. if batch_idx % 100 == 0:
  95. endtime = datetime.datetime.now()
  96. print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} time:{:.2f}s'.format(
  97. epoch,
  98. batch_idx * len(img),
  99. len(dataloader.dataset),
  100. 100. * batch_idx / len(dataloader),
  101. loss.item() / len(img),
  102. (endtime-strattime).seconds))
  103. print('====> Epoch: {} Average loss: {:.4f}'.format(
  104. epoch, train_loss / len(dataloader.dataset)))
  105. if epoch % 10 == 0:
  106. # 生成图像
  107. z = torch.randn(batch_size, 20).to(device)
  108. out = model.decode(z).view(-1, 1, 28, 28)
  109. save_image(out, './vae_img/sampled-{}.png'.format(epoch))
  110. # 重构图像
  111. save = to_img(recon_batch.cpu().data)
  112. save_image(save, './vae_img/image_{}.png'.format(epoch))
  113. torch.save(model.state_dict(), './vae.pth')

Reference:

https://zhuanlan.zhihu.com/p/34998569

http://www.gwylab.com/note-vae.html

https://blog.csdn.net/weixin_36815313/article/details/107728274

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/926016
推荐阅读
相关标签
  

闽ICP备14008679号