当前位置:   article > 正文

总结VAE_vae低维空间p(z)的分布

vae低维空间p(z)的分布

参考链接:

  1. https://blog.csdn.net/weixin_40955254/article/details/101037614
  2. https://blog.csdn.net/weixin_40955254/article/details/82315909
  3. https://www.bilibili.com/video/BV13x411v7US?p=29(李宏毅老师机器学习2017,图多处于其PPT)

概述

VAE(Variational Auto-Encoders,变分自编码器)属于一种生成式模型,希望可以将一个低维向量映射到一个高维的真实数据,常将其与GAN对比以突出后者。VAE属于Auto-Encoder的变体,通过对code引入噪声,使其具备一定的生成能力。

直观理解

在这里插入图片描述
图左展示的是Auto-Encoder(AE),可以看出AE只会把输入编码到低维空间的一个点,然后通过这个点解码出原输入。如输入一张满月图和一张弦月图,AE就会先将它们编码到一个点,然后解码为原图。但对于低维空间中其他点,如图中问号所指的地方,直觉上希望这个点解码出介于满月和弦月之间的图片,但却得不到。换句话说,图像所对应的低维空间的点没有连续性。

图右所示的VAE则通过引入噪声,缓解了这个问题。对于一个输入,VAE不再只将其编码为低维空间中的一个点,而是一块区域,它让解码器从这一块区域都能解码出输入。当VAE将不同输入都编码为一块区域后,低维空间会出现不同区域的交集,如图右中的两条绿色的区域会产生一个交集,解码器希望从中既解码出满月,也解码出弦月,于是会产生一张介于满月和弦月之间的图片。
在这里插入图片描述
更形式化地,如上图所示,VAE的编码器会输出一个原始的code [ m 1 , m 2 , m 3 ] \left [ m_{1},m_{2},m_{3}\right ] [m1,m2,m3](这里作为一个均值)和一个方差 [ σ 1 , σ 2 , σ 3 ] \left [ \sigma _{1},\sigma _{2},\sigma _{3}\right ] [σ1,σ2,σ3],方差 [ σ 1 , σ 2 , σ 3 ] \left [ \sigma _{1},\sigma _{2},\sigma _{3}\right ] [σ1,σ2,σ3]经过 e x p ( ⋅ ) exp\left ( \cdot \right ) exp()后乘上服从标准正态分布的 [ e 1 , e 2 , e 3 ] \left [ e_{1},e_{2},e_{3}\right ] [e1,e2,e3],变成服从 N ( 0 , e 2 σ ) N\left ( 0,e^{2\sigma }\right ) N(0,e2σ)的噪声后,再加上原始code成为服从 N ( m , e 2 σ ) N\left ( m,e^{2\sigma }\right ) N(m,e2σ)的带噪声的code [ c 1 , c 2 , c 3 ] \left [ c_{1},c_{2},c_{3}\right ] [c1,c2,c3],其中 c i = e σ i × e i + m i c_{i}=e^{\sigma _{i}}\times e_{i}+m_{i} ci=eσi×ei+mi。解码器部分则从这些code中还原出原图,对应于第一张图中从绿色部分还原出月亮图。

注:

  1. σ \sigma σ经过 e x p ( ⋅ ) exp\left ( \cdot \right ) exp()是为了保证方差为正。
  2. c i = e σ i × e i + m i c_{i}=e^{\sigma _{i}}\times e_{i}+m_{i} ci=eσi×ei+mi使用了重参技巧(reparametrization trick)。即通过引入一个高斯分布e,将c变成一个确定的值,从而使得反向传播可以得以计算。
  3. ∑ i = 1 3 ( e x p ( σ i ) − ( 1 + σ i ) + ( m i ) 2 ) \sum_{i=1}^{3}\left ( exp\left ( \sigma _{i}\right )-\left ( 1+\sigma _{i}\right )+\left ( m_{i}\right )^{2}\right ) i=13(exp(σi)(1+σi)+(mi)2) e x p ( σ i ) − ( 1 + σ i ) exp\left ( \sigma _{i}\right )-\left ( 1+\sigma _{i}\right ) exp(σi)(1+σi)是为了保证方差 e 2 σ e^{2\sigma } e2σ最小为1,否则收敛后方差会变成0,那么就不会引入噪声,又变成了AE; ( m i ) 2 \left ( m_{i}\right )^{2} (mi)2可以理解为L2正则项。所以 ∑ i = 1 3 ( e x p ( σ i ) − ( 1 + σ i ) + ( m i ) 2 ) \sum_{i=1}^{3}\left ( exp\left ( \sigma _{i}\right )-\left ( 1+\sigma _{i}\right )+\left ( m_{i}\right )^{2}\right ) i=13(exp(σi)(1+σi)+(mi)2)使带噪声的code接近一个标准正态分布。

理论分析

VAE希望有生成的能力,即从低维的服从高斯分布 P ( z ) P\left ( z\right ) P(z)中采样一个点来生成服从高维的真实数据分布 P ( x ) P\left ( x\right ) P(x)的一个点。 P ( x ) P\left ( x\right ) P(x)可以表示为:

P ( x ) = ∫ z P ( z ) P ( x ∣ z ) P\left ( x\right )=\int_{z}^{}P\left ( z\right )P\left ( x|z\right ) P(x)=zP(z)P(xz)

其中, P ( z ) = N ( 0 , I ) P\left ( z\right )=N\left ( 0,I\right ) P(z)=N(0,I) P ( x ∣ z ) = N ( μ ( z ) , σ ( z ) ) P\left ( x|z\right )=N\left (\mu \left ( z\right ), \sigma \left ( z\right )\right ) P(xz)=N(μ(z),σ(z))

μ ( z ) \mu \left ( z\right ) μ(z) σ ( z ) \sigma \left ( z\right ) σ(z)视为 z z z的函数,为需要学习的部分。这部分(即 P ( x ∣ z ) P\left ( x|z\right ) P(xz))实际为VAE的解码器,通过一个神经网络实现:
在这里插入图片描述
而VAE也假设每个真实样本都对应一个相应的高斯分布,即 P ( z ∣ x ) = N ( μ ′ ( x ) , σ ′ ( x ) ) P\left ( z|x\right )=N\left ( {\mu }'\left ( x\right ),{\sigma }'\left ( x\right )\right ) P(zx)=N(μ(x),σ(x))。为此,VAE使用另一个分布 q ( z ∣ x ) q\left ( z|x\right ) q(zx)来逼近 P ( z ∣ x ) P\left ( z|x\right ) P(zx),对应于VAE中的编码器部分:
在这里插入图片描述
VAE的目标函数为最大似然函数:

L = ∑ x l o g P ( x ) L=\sum_{x}^{}logP\left ( x\right ) L=xlogP(x)

为了引入 q ( z ∣ x ) q\left ( z|x\right ) q(zx),对上式做了一些变形:

l o g P ( x ) = ∫ z q ( z ∣ x ) l o g P ( x ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) q ( z ∣ x ) P ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g ( q ( z ∣ x ) P ( z ∣ x ) ) d z logP\left ( x\right )=\int_{z}^{}q\left ( z|x\right )logP\left ( x\right )dz\\ =\int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left ( z,x\right )}{P\left ( z|x\right )}\right )dz\\ =\int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left ( z,x\right )}{q\left ( z|x\right )}\frac{q\left ( z|x\right )}{P\left ( z|x\right )}\right )dz\\ =\int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left ( z,x\right )}{q\left ( z|x\right )}\right )dz+\int_{z}^{}q\left ( z|x\right )log\left ( \frac{q\left ( z|x\right )}{P\left ( z|x\right )}\right )dz logP(x)=zq(zx)logP(x)dz=zq(zx)log(P(zx)P(z,x))dz=zq(zx)log(q(zx)P(z,x)P(zx)q(zx))dz=zq(zx)log(q(zx)P(z,x))dz+zq(zx)log(P(zx)q(zx))dz

上式右边第二项为 q ( z ∣ x ) q\left ( z|x\right ) q(zx) P ( z ∣ x ) P\left ( z|x\right ) P(zx)的KL散度,即 K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) ⩾ 0 KL\left ( q\left ( z|x\right )||P\left ( z|x\right )\right )\geqslant 0 KL(q(zx)P(zx))0

所以有:

l o g P ( x ) ⩾ ∫ z q ( z ∣ x ) l o g ( P ( z , x ) q ( z ∣ x ) ) d z = ∫ z q ( z ∣ x ) l o g ( P ( x ∣ z ) P ( z ) q ( z ∣ x ) ) d z logP\left ( x\right )\geqslant \int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left ( z,x\right )}{q\left ( z|x\right )}\right )dz=\int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left (x|z\right )P\left ( z\right )}{q\left ( z|x\right )}\right )dz logP(x)zq(zx)log(q(zx)P(z,x))dz=zq(zx)log(q(zx)P(xz)P(z))dz

上式称为 l o g P ( x ) logP\left ( x\right ) logP(x)的下界 L b L_b Lb

于是,在保证 K L ( q ( z ∣ x ) ∣ ∣ P ( z ∣ x ) ) KL\left ( q\left ( z|x\right )||P\left ( z|x\right )\right ) KL(q(zx)P(zx))很小的时候(即前面说的希望 q ( z ∣ x ) q\left ( z|x\right ) q(zx)逼近 P ( z ∣ x ) P\left ( z|x\right ) P(zx)),我们可以找到 P ( x ∣ z ) P\left (x|z\right ) P(xz) q ( z ∣ x ) q\left ( z|x\right ) q(zx)最大化下界,达到最大化 l o g P ( x ) logP\left ( x\right ) logP(x)的目的。

L b = ∫ z q ( z ∣ x ) l o g ( P ( z ) q ( z ∣ x ) ) d z + ∫ z q ( z ∣ x ) l o g P ( x ∣ z ) d z L_{b}=\int_{z}^{}q\left ( z|x\right )log\left ( \frac{P\left ( z\right )}{q\left ( z|x\right )}\right )dz+\int_{z}^{}q\left ( z|x\right )logP\left ( x|z\right )dz Lb=zq(zx)log(q(zx)P(z))dz+zq(zx)logP(xz)dz

等号右边第一项为 − K L ( q ( z ∣ x ) ∣ ∣ P ( z ) ) -KL\left ( q\left ( z|x\right )||P\left ( z\right )\right ) KL(q(zx)P(z)),最大化该项意为让 q ( z ∣ x ) q\left ( z|x\right ) q(zx)逼近一个标准正态分布,等价于上面的 ∑ i = 1 d ( e x p ( σ i ) − ( 1 + σ i ) + ( m i ) 2 ) \sum_{i=1}^{d}\left ( exp\left ( \sigma _{i}\right )-\left ( 1+\sigma _{i}\right )+\left ( m_{i}\right )^{2}\right ) i=1d(exp(σi)(1+σi)+(mi)2)

等号右边第二项为 E q ( z ∣ x ) [ l o g P ( x ∣ z ) ] E_{q\left ( z|x\right )}\left [ logP\left ( x|z\right )\right ] Eq(zx)[logP(xz)],最大化该项意为最大化 q ( z ∣ x ) q\left ( z|x\right ) q(zx)下, l o g P ( x ∣ z ) logP\left ( x|z\right ) logP(xz)的期望,相当于AE的损失。
在这里插入图片描述

前沿应用

VAE在2016年被用在VGAE(《Variational Graph Auto-Encoders》)。VGAE用于完成图重构任务,这里的图指由边和点构成的图(graph),而非图片(picture)。

VGAE的编码器依然输出均值和方差,只是它们都由图卷积网络(GCN)产生:

μ = G C N μ ( X , A ) \mu =GCN_{\mu }\left ( X,A\right ) μ=GCNμ(X,A)

l o g σ = G C N σ ( X , A ) log\sigma =GCN_{\sigma }\left ( X,A\right ) logσ=GCNσ(X,A)

其中, X X X为节点的特征矩阵, A A A为邻接矩阵。
而解码器利用隐变量的内积来重构邻接矩阵:

p ( A ∣ Z ) = ∏ i = 1 N ∏ j = 1 N p ( A i j ∣ z i , z j ) p\left ( A|Z\right )=\prod_{i=1}^{N}\prod_{j=1}^{N}p\left ( A_{ij}|z_{i},z_{j}\right ) p(AZ)=i=1Nj=1Np(Aijzi,zj)

损失函数也包含两部分:

L = E q ( Z ∣ X , A ) [ l o g   p ( A ∣ Z ) ] − K L [ q ( Z ∣ X , A ) ∣ ∣ p ( Z ) ] L=E_{q\left ( Z|X,A\right )}\left [ log\ p\left ( A|Z\right )\right ]-KL\left [ q\left ( Z|X,A\right )||p\left ( Z\right )\right ] L=Eq(ZX,A)[log p(AZ)]KL[q(ZX,A)p(Z)]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/318105
推荐阅读
相关标签
  

闽ICP备14008679号