赞
踩
近期看论文要用到VAE,看了很多资料,有这样一种感觉,要么过度过于偏向数学原理,要么只是讲了讲网络结构。本文将两者结合,以简洁易懂的语言结合代码实现来介绍VAE。
VAE是变分推断(variational inference )以及自编码器(Auto-encoder)的组合,是一种非监督的生成模型,其概率图模型和深度学习有机结合,近年来比较火热。VAE可以用于但不局限于降维信息检索等任务,我看文献遇到的是一篇做配准的论文,也用到了VAE。
先从神经网络的角度去看VAE,VAE实际上是在Atuo-encoder(AE)的变种,其基本架构也如Atuo-encoder,包含两部分encoder(编码器)和decoder(解码器)。大概如下:
图片摘自:github
以最开始的自编码器为例,其loss函数一般是输入和输出的MSE,通过调整Encoder输出层的节点数(需要低于输入的维度),我们可以从低维度的数据(code)通过Decoder重建出输入。
自编码器存在这样的问题,倘若模型过完备(中间层维度大于输入),模型会直接复制模型的输入作为输出。
在一般实际使用中,我们往往会添加正则项。
除此之外,还有VAE,可以学习出高容量且过完备(中间层维度大于输入)模型。VAE的网络结构如下:
本图摘自:李宏毅2020深度学习课程
可以看到VAE和AE的区别在于两方面:
1.中间层引入了一个noise;
2.loss函数的改变多了:
∑
i
=
1
3
(
x
e
p
(
σ
i
−
(
1
+
σ
i
)
+
(
m
i
2
)
)
)
\sum_{i=1}^{3}(xep(\sigma_{i}-(1+\sigma_{i})+(m_{i}^{2})))
∑i=13(xep(σi−(1+σi)+(mi2)))这一项。根据上面的结构,我们基本可以很容易地代码实现。
那么如何直观地理解上面地改变呢?
1.为什么要引入noise?首先直观理解,就算有一个noise也要尽量输入和输出相似,这样的decoder更加鲁棒。另一个直观理解就是,在引入noise之前,我们的decoder的输入和输出地映射可以看作是离散的,但是在加入noise之后,可以看作把不连续地变成连续的了。
2.为什么更改loss?首先假设没有更改,其实最好的情况肯定是方差为0,即
σ
=
0
\sigma=0
σ=0。这就回到了AE的形式。直观地说,多加地这一项避免了这一点。那么怎么要这样更改loss呢?首先公式地前两项的值域大于等于0,最后一项可以看作一个L2正则项。
上面的理解都是直观的,感性的认识。下面一部分将从变分推断的角度推导出所谓的引入noise其实是重参数化技巧的结果,而loss的改变也是推导所得,本质是一个KL散度。如果读者到目前为止还有求知欲,就可以继续往下看。
设:x:为观测数据,可以看作是样本
y:为隐变量,包含但不限于模型的参数
首先变分推断的核心思想是:因为一般情况下后验概率
p
(
z
∣
x
)
p(z|x)
p(z∣x)是不可求解的,所以变分推断采用了一种迂回的策略,即使用
q
(
z
)
q(z)
q(z)去近似
p
(
z
∣
x
)
p(z|x)
p(z∣x)。如果读者对生成模型不是很理解,可以把
p
(
z
∣
x
)
p(z|x)
p(z∣x)看作是AE中的编码器,
p
(
x
∣
z
)
p(x|z)
p(x∣z)看作是AE中的解码器。
VAE是典型的生成模型,那就从下面公式开始:
log
p
(
x
)
=
log
p
(
x
,
z
)
p
(
z
∣
x
)
=
log
p
(
x
,
z
)
q
(
z
)
−
log
p
(
z
∣
x
)
q
(
z
)
两边关于
q
(
z
)
q(z)
q(z)同时求期望,则:
左边
=
∫
z
q
(
z
)
log
p
(
x
)
d
z
=
log
p
(
x
)
左边=\int_{z}q(z)\log p(x) dz=\log p(x)
左边=∫zq(z)logp(x)dz=logp(x)
右边
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
d
z
−
∫
z
q
(
z
∣
x
)
log
p
(
z
∣
x
)
q
(
z
)
d
z
=
L
(
q
)
+
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
对
L
(
q
)
\mathcal{L}(q)
L(q)做进一步化简,
L
(
q
)
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
d
z
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
q
(
z
)
p
(
z
)
−
∫
z
q
(
z
∣
x
)
log
q
(
z
)
p
(
z
)
=
E
q
(
z
)
log
p
(
x
∣
z
)
−
K
L
(
q
(
z
)
∣
∣
p
(
z
)
)
在此之前我们把q(z)表示为
q
(
z
∣
x
)
q(z|x)
q(z∣x),这时通常要使用重参数化技巧对上式进一步变形,现在想要把
q
(
z
∣
x
)
q(z|x)
q(z∣x)中的x的成分消去。那么上面是重参数化技巧呢?先举个例子吧,一个随机变量a服从概率分布N(0,1),那么对于随机变量b=a+m,服从高斯分布N(m,1)。现在我们采样这个b的时候采用这样的策略:
1.从高斯分布N(0,1)中采样得a。
2.取b = a+m。
其实这就是重采样技巧。对于我们的
q
(
z
∣
x
)
q(z|x)
q(z∣x),我们假设
z
=
g
Φ
(
x
,
ϵ
)
z=g_{\Phi}(x,\epsilon)
z=gΦ(x,ϵ),然后
ϵ
\epsilon
ϵ服从某个分布,记为
p
(
ϵ
)
p(\epsilon)
p(ϵ),一般我们假设其服从标准正态分布。那么我们采样
q
(
z
∣
x
)
q(z|x)
q(z∣x)就变成了,先根据
p
(
ϵ
)
p(\epsilon)
p(ϵ)采样一个
ϵ
i
\epsilon^{i}
ϵi,再根据
z
=
g
Φ
(
x
,
ϵ
)
z=g_{\Phi}(x,\epsilon)
z=gΦ(x,ϵ)计算出z。根据重采样技巧,我们忘掉之前的结果,重新推导KL(q||p):
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
=
log
p
(
x
)
−
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
∣
x
)
q
(
z
∣
x
)
p
(
z
)
+
∫
z
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
z
)
=
log
p
(
x
)
−
∫
z
q
(
z
)
log
p
(
x
∣
z
)
d
z
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
log
p
(
x
)
−
∫
ϵ
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
d
ϵ
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
log
p
(
x
)
−
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
其实整个VAE的构建就是根据上面的等式
g
Φ
(
x
,
ϵ
)
g_{\Phi}(x,\epsilon)
gΦ(x,ϵ)不知道是什么,那就用一个神经网络代替。
p
(
x
∣
z
)
p(x|z)
p(x∣z)不知道是什么,也用一个神经网络代替。下面文字叙述一下VAE的前向传播。
1.先从假设的
p
(
ϵ
)
p(\epsilon)
p(ϵ)中采样一个
ϵ
\epsilon
ϵ,即上一节网络图中的
e
e
e。
2.从假设的encoder中输入x以及
ϵ
\epsilon
ϵ,输出隐变量
z
z
z,即上一节网络图中的
c
c
c。
3.将隐变量z输入decoder,输出
x
^
\hat{x}
x^。
而这个前向的过程表示在上面公式里,就是
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
d
ϵ
\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))d\epsilon
Ep(ϵ)logp(x∣gΦ(x,ϵ))dϵ,显然优化这个网络,我们要让
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
KL(q(z|x)||p(z|x))
KL(q(z∣x)∣∣p(z∣x))最小,
log
p
(
x
)
−
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
\log p(x)-\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))
logp(x)−Ep(ϵ)logp(x∣gΦ(x,ϵ))就用MSE表示,
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
KL(q(z|x)||p(z))
KL(q(z∣x)∣∣p(z))是可以求出来的。具体的推导也不难,如果假设是高斯分布,即根据多维高斯分布的KL散度,结合我们重采样的q(z|x),推导出来最后的结果就是第一节中图中的公式,这里就省略推导了。
能看到这里的宝贝都很厉害,毕竟我感觉自己也写的不是很清楚,才疏学浅了。不过最困难的部分也过去了,我们不妨看看VAE的pytorch代码实现,看看自己理解的是不是对的。
本文的代码来自GITHUB
__author__ = 'SherlockLiao' import torch import torchvision from torch import nn from torch import optim import torch.nn.functional as F from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import transforms from torchvision.utils import save_image from torchvision.datasets import MNIST import os if not os.path.exists('./vae_img'): os.mkdir('./vae_img') def to_img(x): x = x.clamp(0, 1) x = x.view(x.size(0), 1, 28, 28) return x num_epochs = 100 batch_size = 128 learning_rate = 1e-3 img_transform = transforms.Compose([ transforms.ToTensor() # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = MNIST('./data', transform=img_transform, download=True) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) self.fc22 = nn.Linear(400, 20) self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() if torch.cuda.is_available(): eps = torch.cuda.FloatTensor(std.size()).normal_() else: eps = torch.FloatTensor(std.size()).normal_() eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h3 = F.relu(self.fc3(z)) return F.sigmoid(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x) z = self.reparametrize(mu, logvar) return self.decode(z), mu, logvar model = VAE() if torch.cuda.is_available(): model.cuda() reconstruction_function = nn.MSELoss(size_average=False) def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ BCE = reconstruction_function(recon_x, x) # mse loss # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return BCE + KLD optimizer = optim.Adam(model.parameters(), lr=1e-3) for epoch in range(num_epochs): model.train() train_loss = 0 for batch_idx, data in enumerate(dataloader): img, _ = data img = img.view(img.size(0), -1) img = Variable(img) if torch.cuda.is_available(): img = img.cuda() optimizer.zero_grad() recon_batch, mu, logvar = model(img) loss = loss_function(recon_batch, img, mu, logvar) loss.backward() train_loss += loss.data[0] optimizer.step() if batch_idx % 100 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(img), len(dataloader.dataset), 100. * batch_idx / len(dataloader), loss.data[0] / len(img))) print('====> Epoch: {} Average loss: {:.4f}'.format( epoch, train_loss / len(dataloader.dataset))) if epoch % 10 == 0: save = to_img(recon_batch.cpu().data) save_image(save, './vae_img/image_{}.png'.format(epoch)) torch.save(model.state_dict(), './vae.pth')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。