当前位置:   article > 正文

细说VAE的来龙去脉 (Variational Autoencoder)_vae中采样为什么 exp(0.5logvar)

vae中采样为什么 exp(0.5logvar)

目标是啥? 假如有下面这样手写数字图片集,我们想构造一个model可以无限生成类似的(同分布)的手写数字图片.

第一反映是做一个d纬度的向量Z,其中z1代表数字,z2代表粗细,z3代表倾斜度,....

完后把Z,输入一个带有参数的model,取训练集合里面符合Z的图片,作为目标,通过Supervised learning学习model的参数就好了.

But,一个大问题是人为手工的给Z的每个纬度z_i构造含义,非常费时费力,容易漏纬度,而且构造出来的纬度之间很难确保不是相互纠缠的.

怎么办?答案是 ,不要手工人为的构造Z人为的赋予每个纬度的含义,而是固定Z的长度为d,通过网络学习出这个Z里面每个纬度的含义.

也就是说Z的前面会有若干层的网络,那么最开始层的输入是什么?输入z是固定的一个常向量?不行,那么只能输出一张稳定的图片了.输入z是一个d维的随机向量?

也不行,因为我们要输出某一个特定分布的手写数字图片,所以z要服从某一个特定分布,那么z选择什么分布合适,大量实验发现标准正态分布N(0;I)就非常有效,

并且数学上证明了输入服从N(0;I)经过复杂的function或NN后,可以转换成任何一种特定的分布,就比如我们要生成的手写数字图片的分布.

下面数学化描述我们的问题.

z~P(z)  is N(0;I)

X 是图片gt数据集的随机变量

想找到一个模型P(X|z;\Theta ),使得左边的P(X)最大(根据maximize the likelihood of training data);

这个积分即全概率公式的积分形式,

 

如果z~P(z) 真的从N(0;I)里面取值,会有大量P(X|z)为零,增大了z的取值范围和计算复杂度。

z从Q(z|X)生成,再把生成的z放入到P(X|z)生成X,增大了P(X|z),缩小了z的取值范围,这时的z也更靠谱了。

那么Q(z|X)是否等于原来的P(z|X)呢?显然不完全等于,Q(z|X)偏小,再者P(z|X)服从于正太分布,一个编码X生成z的网络却可以生成各种不同分布,

所以我们要对Q(z|X)进行约束使其服从与正太分布,且尽可能通过增加参数来扩大它的容量。

所以构造 Q(z|X) = 

上面的u是正太分布的均值,是一个带参数\Theta的固定网络,\tiny \sum是均方差,一个带参数\Theta的固定网络。

那么这样的Q(z|X)有多接近原来的P(z|X)呢?嗯,这个距离是我们后面调参时需要缩小的一个目标。

代到KL divergence里面如下,

下面做一系列神奇的数学变换,

好了到了(5)式,左边:第一项就是我们最后要最大化的目标项,第二项是附加项,为了使Q(z|X)尽可能的接近P(z|X)

也就是说左边就是我们要最大化的目标了。

那右边自然就是我们要构造的含参数的function了,里面的参数通过用梯度下降法,loop训练数据集来调节使得左边的目标最大,问题即求解了。

来看看右边的第一项是个什么?嗯,它的输入是Q生成的z,输出P(X|z),这显然是一个解码器。

右边的第二项是个什么?是两个正太分布的KL divergence,展开看看。

这时,看一眼展开的结果,也就是(7)的右边,仅仅是带参数\small \Theta的固定网络\tiny \sum和u输出的一种组合而已。

初步搭一个网络看看,

要从下往上看,先给X编码,经过带参数的固定网络\tiny \sum和u, 此时通过第一个蓝框KL网络来构造(5)的第二项,嗯

(5)的右边第一项呢?网络\tiny \sum和u的输出在红框部分做sample操作,sample分布为u和\tiny \sum的正太分布,得到z

把z喂给一个解码器网络P,得到f(z), 让它尽可能贴近X,等价于log P(X|z)越大。网络搭建完毕,计算出两个蓝框的loss ,

通过梯度下降使之变小,调节参数,从而得到我们目标的model function。

可是在梯度下降的过程中,这个网络有一个sample节点,也就是那个红色的方框,是不连续的function,loss传回的梯度在这里断掉了?怎么办?

答案,用参数化的网络节点替代掉那个红框。替代后如下,

这里的\small \epsilon从一个标准正太分布随机生成,是一个新的输入,\small \epsilon的上面就是参数化后的网络,完美的替换掉了sample u和\tiny \sum分布的那个操作,

且这里保证了网络的连续性,loss向后传播在这里也不用受阻了。

训练就简单了,把我们数据集里面的图片X依次喂入网络,向前传播计算loss,向后传播,update参数,一直训练下去,直到loss趋于稳定,

即得到我们想要的model 了。

到这里有人要问了,可可,Encoder(Q) ,u(X)和\tiny \sum(X),还有Decoder(P)这些网络到底是什么样子的阿...

其实这些网络模块的具体设计要看你生成的X到底是什么类型的数据拉.

如果是手写数字图片,我们用vanila版本的VAE就可以就决问题,用pytorch表示的网络结构大概如下:

  1. self.encoder
  2. Sequential(
  3. (0): Sequential(
  4. (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  5. (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  6. (2): LeakyReLU(negative_slope=0.01)
  7. )
  8. (1): Sequential(
  9. (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  10. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  11. (2): LeakyReLU(negative_slope=0.01)
  12. )
  13. (2): Sequential(
  14. (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  15. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  16. (2): LeakyReLU(negative_slope=0.01)
  17. )
  18. (3): Sequential(
  19. (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  20. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  21. (2): LeakyReLU(negative_slope=0.01)
  22. )
  23. (4): Sequential(
  24. (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  25. (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  26. (2): LeakyReLU(negative_slope=0.01)
  27. )
  28. )
  29. self.decoder
  30. Sequential(
  31. (0): Sequential(
  32. (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  33. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  34. (2): LeakyReLU(negative_slope=0.01)
  35. )
  36. (1): Sequential(
  37. (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  38. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  39. (2): LeakyReLU(negative_slope=0.01)
  40. )
  41. (2): Sequential(
  42. (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  43. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  44. (2): LeakyReLU(negative_slope=0.01)
  45. )
  46. (3): Sequential(
  47. (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  48. (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  49. (2): LeakyReLU(negative_slope=0.01)
  50. )
  51. )

注意,这里的encoder只是从图片提取了非常丰富的features,比如

144张图片从torch.Size([144, 3, 64, 64])到torch.Size([144, 512, 2, 2]),

拉平变成torch.Size([144, 2048]) ,

完后就要送给 u(X)和\tiny \sum(X)网络,如下:

  1. self.fc_mu
  2. Linear(in_features=2048, out_features=128, bias=True)
  3. self.fc_var
  4. Linear(in_features=2048, out_features=128, bias=True)

输出为

mu.shape = {Size: 2} torch.Size([144, 128]) # mean of the latent Gaussian [B * D]
log_var.shape = {Size: 2} torch.Size([144, 128]) # deviation of the latent Gaussian [B * D]

用参数化模块实现

sample from N(mu, var) from N(0,1)

这个参数化模块什么样子?

  1. std = torch.exp(0.5 * logvar)
  2. eps = torch.randn_like(std)
  3. return eps * std + mu

这里面的关键函数torch.randn_like解释如下:

pytorch官方解释原文:

Returns a tensor with the same size as input that is filled with random numbers from a normal distribution with mean 0 and variance 1. 

 还有个地方要注意,这个版本的VAE 

std = torch.exp(0.5 * logvar)

编码网络没有l直接让它生成std, 生成的是var取log

不管怎样,经过了参数化模块, 输出的服从N(u(X),\tiny \sum(X))的z

shape为torch.Size([144, 128]).

把z喂给decoder P模块之前先作如下变形:

  1. result = Linear(in_features=128, out_features=2048, bias=True)
  2. result = result.view(-1, 512, 2, 2)

此时的tensor为 torch.Size([144, 512, 2, 2]) 

经过decoder生成图片.

  1. Sequential(
  2. (0): Sequential(
  3. (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  4. (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  5. (2): LeakyReLU(negative_slope=0.01)
  6. )
  7. (1): Sequential(
  8. (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  9. (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  10. (2): LeakyReLU(negative_slope=0.01)
  11. )
  12. (2): Sequential(
  13. (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  14. (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  15. (2): LeakyReLU(negative_slope=0.01)
  16. )
  17. (3): Sequential(
  18. (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  19. (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  20. (2): LeakyReLU(negative_slope=0.01)
  21. )
  22. )
  23. Sequential(
  24. (0): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  25. (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  26. (2): LeakyReLU(negative_slope=0.01)
  27. (3): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  28. (4): Tanh()
  29. )

完后要做的是计算loss,

f(z) 和 X 通过  F.mse_loss 计算loss, 非常简单不用戏说.

最后剩下是如何计算N(u(X),\tiny \sum(X))与N(0,I) 的KLd loss呢? 

mu.shape = {Size: 2} torch.Size([144, 128]) # mean of the latent Gaussian [B * D]
log_var.shape = {Size: 2} torch.Size([144, 128]) # deviation of the latent Gaussian [B * D]

引用两个高斯分布的KLd loss公式,如下:

(10)

经过了一系列的转换,精简到了(10)公式,取我们encoder网络生成的mu和log_var 代进去,代码如下:

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

完毕!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/73022
推荐阅读
相关标签
  

闽ICP备14008679号