当前位置:   article > 正文

Pytorch之经典神经网络Generative Model(二) —— VAE (MNIST)_vae神经网络

vae神经网络

      变分编码器(Variational AutoEncoder)是自动编码器的升级版本, 其结构跟自动编码器是类似的, 也由编码器和解码器构成。

      回忆一下, 自动编码器有个问题, 就是并不能任意生成图片, 因为我们没有办法自己去构造隐藏向量, 需要通过一张图片输入编码我们才知道得到的隐含向量是什么, 这时我们就可以通过变分自动编码器来解决这个问题。

      其实原理特别简单, 只需要在编码过程给它增加一些限制, 迫使其生成的隐含向量能够粗略的遵循一个标准正态分布, 这就是其与一般的自动编码器最大的不同。这样我们生成一张新图片就很简单了, 我们只需要给它一个标准正态分布的随机隐含向量, 这样通过解码器就能够生成我们想要的图片, 而不需要给它一张原始图片先编码。

      一般来讲, 我们通过 encoder 得到的隐含向量并不是一个标准的正态分布, 为了衡量两种分布的相似程度, 我们使用 KL divergence, 这是用来衡量两种分布相似程度的统计量,它越小,表示两种概率分布越接近。

       在实际情况中,需要在模型的准确率和encoder得到的隐含向量服从标准正态分布之间做一个权衡,所谓模型的准确率就是指解码器生成的图片与原始图片的相似程度。可以让神经网络自己做这个决定,只需要将两者都做一个loss,然后求和作为总的loss,这样网络就能够自己选择如何做才能使这个总的loss下降。

      为了避免计算 KL divergence 中的积分, 我们使用重参数的技巧, 不是每次产生一个隐含向量, 而是生成两个向量, 一个表示均值, 一个表示标准差, 这里我们默认编码之后的隐含向量服从一个正态分布的之后, 就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布, 最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布, 也就是希望均值为 0, 方差为 1

      所以标准的变分自动编码器VAE如下

  1. import os
  2. import torch
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from torch.utils.data import DataLoader
  6. from torchvision.datasets import MNIST
  7. from torchvision import transforms
  8. from torchvision.utils import save_image
  9. from visdom import Visdom
  10. class VAE(nn.Module):
  11. def __init__(self):
  12. super(VAE, self).__init__()
  13. self.fc1 = nn.Linear(784, 400)
  14. self.fc21 = nn.Linear(400, 20) # mean 均值
  15. self.fc22 = nn.Linear(400, 20) # var 标准差
  16. self.fc3 = nn.Linear(20, 400)
  17. self.fc4 = nn.Linear(400, 784)
  18. def encode(self, x):
  19. x = self.fc1(x)
  20. h1 = F.relu(x)
  21. mean = self.fc21(h1)
  22. var = self.fc22(h1)
  23. return mean, var
  24. #重参数化
  25. def reparametrize(self, mean, logvar):
  26. std = logvar.mul(0.5).exp_()
  27. normal = torch.FloatTensor(std.size()).normal_() #生成标准正态分布
  28. if torch.cuda.is_available():
  29. normal = torch.tensor(normal.cuda())
  30. else:
  31. normal = torch.tensor(normal)
  32. return normal.mul(std).add_(mean) #标准正态分布乘上标准差再加上均值
  33. #这里返回的结果就是我们encoder得到的编码,也就是我们decoder要decode的编码
  34. def decode(self, z):
  35. z = self.fc3(z)
  36. z = F.relu(z)
  37. z = self.fc4(z)
  38. z = torch.tanh(z)
  39. return z
  40. def forward(self, x):
  41. mean, logvar = self.encode(x) # 编码
  42. z = self.reparametrize(mean, logvar) # 重新参数化成正态分布
  43. return self.decode(z), mean, logvar # 解码, 同时输出均值方差
  44. def loss_function(recon_image, image, mean, logvar):
  45. """
  46. recon_x: generating images
  47. x: origin images
  48. mu: latent mean
  49. logvar: latent log variance
  50. """
  51. reconstruction_function = nn.MSELoss(reduction='sum')
  52. MSE = reconstruction_function(recon_image, image)
  53. # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
  54. KLD_element = mean.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
  55. KLD = torch.sum(KLD_element).mul_(-0.5)
  56. # KL divergence
  57. return MSE + KLD
  58. def to_img(x):
  59. '''
  60. 定义一个函数将最后的结果转换回图片
  61. '''
  62. x = 0.5 * (x + 1.)
  63. x = x.clamp(0, 1)
  64. x = x.view(x.shape[0], 1, 28, 28)
  65. return x
  66. img_transforms = transforms.Compose([
  67. transforms.ToTensor(),
  68. transforms.Normalize([0.5], [0.5]) # 标准化
  69. ])
  70. train_set = MNIST(
  71. root='dataset/',
  72. transform=img_transforms
  73. )
  74. train_data = DataLoader(
  75. dataset=train_set,
  76. batch_size=128,
  77. shuffle=True
  78. )
  79. net = VAE() # 实例化网络
  80. if torch.cuda.is_available():
  81. net = net.cuda()
  82. optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
  83. viz = Visdom()
  84. viz.line([0.], [0.], win='loss', opts=dict(title='loss'))
  85. for epoch in range(100):
  86. for image, _ in train_data:
  87. image = image.view(image.shape[0], -1)
  88. image = torch.tensor(image)
  89. if torch.cuda.is_available():
  90. image = image.cuda()
  91. recon_image, mean, logvar = net(image)
  92. loss = loss_function(recon_image, image, mean, logvar) / image.shape[0] # 将 loss 平均
  93. optimizer.zero_grad()
  94. loss.backward()
  95. optimizer.step()
  96. print('epoch: {}, Loss: {:.4f}'.format(epoch, loss.item()))
  97. save = to_img(recon_image.cpu().data)
  98. if not os.path.exists('./vae_img'):
  99. os.mkdir('./vae_img')
  100. save_image(save, './vae_img/image_{}.png'.format(epoch))
  101. viz.line([loss.item()], [epoch], win='loss', update='append')

运行100个eopch之后,可以看出来结果比自动编码器清晰一点,本质上VAE就是在encoder的结果添加了高斯噪声,通过训练要使得decoder对噪声有一定的鲁棒性,这样的话我们生成一张图片就没有必须用一张图片先做编码了,可以想象,我们只需要利用训练好的encoder对一张图片编码得到其分布后,符合这个分布的隐含向量理论上都可以通过decoder得到类似这张图片的图片。

KL越小,噪声越大(可以这麽理解,我们强行让z的分布符合正态分布,其和N(0,1)越接近,KL越小,相当于我们添加的噪声越大),所以直觉上来想loss合并后的训练过程:

  • 当 decoder 还没有训练好时(重构误差远大于 KL loss),就会适当降低噪声(KL loss 增加),使得拟合起来容易一些(重构误差开始下降);
  • 反之,如果 decoder 训练得还不错时(重构误差小于 KL loss),这时候噪声就会增加(KL loss 减少),使得拟合更加困难了(重构误差又开始增加),这时候 decoder 就要想办法提高它的生成能力了。

      变分自动编码器虽然比一般的自动编码器效果要好, 而且也限制了其输出的编码(code) 的概率分布, 但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss, 这个方式并不好。

在之后生成对抗网络中, 我们会讲一讲这种方式计算 loss 的局限性, 然后会介绍一种新的训练办法, 就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

变分自编码器VAE:原来是这么一回事 | 附开源代码 - 知乎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/73057
推荐阅读
相关标签
  

闽ICP备14008679号