当前位置:   article > 正文

【图像生成】(三) VAE原理 & pytorch代码实例_vae图像生成

vae图像生成

1.简介

上一篇文章里我们介绍了【图像生成】的GAN及其改进WGAN,还有对应的condition条件生成代码。这篇文章主要介绍另外一种生成网络VAE。


2.原理

VAE相对于GAN来说像是一种相反的存在:GAN是输入latent生成图像,再用生成的图像去修正网络;而VAE是输入图像生成latent,让latent的尽量接近原数据集的分布。这两者是不是有种奇妙的转置的感觉?

让我们从头来理解下VAE的由来和特点。首先我们从AE(Auto Encoder)说起,AE对图像进行encode后,生成一串可以表达图像特征的向量z,我们可以把这个特征向量z输入进decoder来还原出最初的图片。具体的流程图如下:

但是这样有个问题,AE结构只能对图像进行压缩和还原,并不能生成新的图像。那么怎么解决这个问题呢?把latent概率化就好了,这样我们就可以在特定的概率分布中获取一定的随机性。

VAE将latent表达为高斯的概率分布,同时通过网络去自动学习平衡图像生成的精确度和概率分布的拟合度,这两者可以分别用MSE和KL散度来计算。之所以使用高斯分布,是因为高斯分布可以去累加映射得到任何的数据分布,同时高斯分布可以通过参数重整化转换为标准正态分布的线性表达,因此VAE中的latent中包括了高斯分布的均值和标准差。具体的流程图如下:

下图应该可以更加直观的理解,概率分布所表示的意义:

每个属性对应了一种特征,我们可以从每个特征的概率分布中去随机抽取,来得到对应的新生成的图像。

那么怎么具体训练才能得到每个特征具体的分布呢?刚才说了我们需要去平衡图像生成的精确度和概率分布的拟合度,这两者分布用MSE和KL散度来计算。MSE是去为了使latent输入进decoder的图像尽可能的接近输入encoder的真实图像,KL散度是为了让latent中的mean和std更接近于正态分布。

但是在实际过程中,如果直接去通过均值和方差随机生成高斯分布,是没办法进行梯度求导的,所以这里采用了一个技巧:参数重整化(Reparameterization),即将高斯分布表达为均值和标准差的线性组合,如下图所示:

 所以最后的训练流程图如下所示:

左图是直接用均值和标准差生成高斯分布,但这样梯度是没法反向推导的。右图是进行了参数重整化,引入随机的标准正态分布,这样使得训练成为可能(妙啊!)。

最后一个问题!!!可能有些同学还是不理解,为什么要让均值和标准差去逼近于标准正态分布。最开始我也有这个疑问,图像那块的重建loss很好理解,这块确实会比较抽象一点。

思来想去原来是一个很简单的道理:采样!!!

我们test的时候进行采样,不会用其他的均值和标准差,肯定是用的正态分布的latent,输入进decoder中生成图像。这里让均值和标准差尽可能去接近正态分布,为的就是让encoder去学习到数据的一个分布规律,并将它们映射到正态分布中,这样采样时直接用正态分布就可以包含所有的情况。

下图也和我的理解相似,将不同的分布映射到同一个区域,可以便于特征更好的融合和采样插值。


3.代码

接下来我们用pytorch来实现VAE在MNIST数据集上的生成。

3.1模型

encoder和decoder均用全连接层来简化,encoder中有两个分支,一个预测均值,一个预测标准差。decoder输入latent得到生成图像。

  1. class VAE(nn.Module):
  2. def __init__(self, input_dim=1, output_dim=1, middle_dim=400, latent_dim=20, class_num=10):
  3. '''
  4. 初始化网络
  5. :param input_dim:输入维度,也是latent维度
  6. :param output_dim:输出维度,表示最终生成图片的通道数
  7. :param class_num:图像种类,代表condition种类
  8. '''
  9. super(VAE, self).__init__()
  10. self.fc1 = nn.Linear(784, middle_dim)
  11. self.fc_mu = nn.Linear(middle_dim, latent_dim)
  12. self.fc_logvar = nn.Linear(middle_dim, latent_dim)
  13. self.fc2 = nn.Linear(latent_dim, middle_dim)
  14. self.fc3 = nn.Linear(middle_dim, 784)
  15. self.recons_loss = nn.BCELoss(reduction='sum')
  16. def encode(self, x):
  17. x = torch.relu(self.fc1(x))
  18. mu = self.fc_mu(x)
  19. logvar = self.fc_logvar(x)
  20. return mu, logvar
  21. def reparametrization(self, mu, logvar):
  22. # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
  23. std = torch.exp(logvar / 2)
  24. eps = torch.randn_like(std)
  25. # N(mu, std^2) = N(0, 1) * std + mu
  26. z = eps * std + mu
  27. return z
  28. def decode(self, z):
  29. x = torch.relu(self.fc2(z))
  30. x = F.sigmoid(self.fc3(x))
  31. return x
  32. def forward(self, x):
  33. mu, logvar = self.encode(x)
  34. z = self.reparametrization(mu, logvar)
  35. x_out = self.decode(z)
  36. loss = self.loss_func(x_out, x, mu, logvar)
  37. return loss
  38. def loss_func(self, x_out, x, mu ,logvar):
  39. reconstruction_loss = self.recons_loss(x_out, x)
  40. KL_divergence = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)
  41. # KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
  42. # KLD = torch.sum(KLD_ele).mul_(-0.5)
  43. return reconstruction_loss + KL_divergence

3.2训练

我们训练时先将图像通过encoder得到预测的均值和标准差,然后通过参数重整化计算得到latent,再将latent输入进decoder得到生成图像,最终计算重建loss和KL散度loss。

注意,这里的loss我们使用的是BCEloss,所以数据集加载进来不能再进行normalise,不然范围就不会在0-1之间。

  1. def train(self):
  2. self.model.train()
  3. print('训练开始!!')
  4. for epoch in range(self.epoch):
  5. self.model.train()
  6. loss_mean = 0
  7. for i, (images, labels) in enumerate(self.train_dataloader):
  8. images, labels = images.to(self.device), labels.to(self.device)
  9. # 将latent和condition拼接后输入网络
  10. loss = self.model(images.view(images.shape[0], -1))
  11. loss_mean += loss.item()
  12. self.optimizer.zero_grad()
  13. loss.backward()
  14. self.optimizer.step()
  15. train_loss = loss_mean / len(self.train_dataloader)
  16. print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
  17. self.visualize_results(epoch)

3.3推理&可视化

在预测的时候就只用使用随机的正态分布latent输入进decoder就可以得到生成图像。

  1. @torch.no_grad()
  2. def visualize_results(self, epoch):
  3. self.model.eval()
  4. # 保存结果路径
  5. output_path = 'results/VAE'
  6. if not os.path.exists(output_path):
  7. os.makedirs(output_path)
  8. tot_num_samples = self.sample_num
  9. image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
  10. # 生成对应sample个condition
  11. z = torch.randn(tot_num_samples, self.latent_dim).to(self.device)
  12. generated_images = self.model.decode(z)
  13. generated_images = generated_images.view(generated_images.shape[0], 1, 28, 28)
  14. save_image(generated_images, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)

 可以看到结果比较模糊,这个是因为KL散度的loss不为0,代表着两个分布不能完全相似,只能得到一个大致的结果,所以才会导致模糊。

完整代码如下:

  1. import torch, time, os
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torchvision.datasets import MNIST
  6. from torchvision import transforms
  7. from torch.utils.data import DataLoader
  8. from torchvision.utils import save_image
  9. import torch.nn.functional as F
  10. class VAE(nn.Module):
  11. def __init__(self, middle_dim=400, latent_dim=20, class_num=10):
  12. super(VAE, self).__init__()
  13. self.fc1 = nn.Linear(784, middle_dim)
  14. self.fc_mu = nn.Linear(middle_dim, latent_dim)
  15. self.fc_logvar = nn.Linear(middle_dim, latent_dim)
  16. self.fc2 = nn.Linear(latent_dim, middle_dim)
  17. self.fc3 = nn.Linear(middle_dim, 784)
  18. self.recons_loss = nn.BCELoss(reduction='sum')
  19. def encode(self, x):
  20. x = torch.relu(self.fc1(x))
  21. mu = self.fc_mu(x)
  22. logvar = self.fc_logvar(x)
  23. return mu, logvar
  24. def reparametrization(self, mu, logvar):
  25. # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
  26. std = torch.exp(logvar / 2)
  27. eps = torch.randn_like(std)
  28. # N(mu, std^2) = N(0, 1) * std + mu
  29. z = eps * std + mu
  30. return z
  31. def decode(self, z):
  32. x = torch.relu(self.fc2(z))
  33. x = F.sigmoid(self.fc3(x))
  34. return x
  35. def forward(self, x):
  36. mu, logvar = self.encode(x)
  37. z = self.reparametrization(mu, logvar)
  38. x_out = self.decode(z)
  39. loss = self.loss_func(x_out, x, mu, logvar)
  40. return loss
  41. def loss_func(self, x_out, x, mu ,logvar):
  42. reconstruction_loss = self.recons_loss(x_out, x)
  43. KL_divergence = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)
  44. # KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
  45. # KLD = torch.sum(KLD_ele).mul_(-0.5)
  46. return reconstruction_loss + KL_divergence
  47. class ImageGenerator(object):
  48. def __init__(self):
  49. '''
  50. 初始化,定义超参数、数据集、网络结构等
  51. '''
  52. self.epoch = 50
  53. self.sample_num = 100
  54. self.batch_size = 128
  55. self.latent_dim = 20
  56. self.lr = 0.001
  57. self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
  58. self.init_dataloader()
  59. self.model = VAE(latent_dim=self.latent_dim).to(self.device)
  60. self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
  61. def init_dataloader(self):
  62. '''
  63. 初始化数据集和dataloader
  64. '''
  65. tf = transforms.Compose([
  66. transforms.ToTensor(),
  67. # transforms.Normalize((0.1307,), (0.3081,))
  68. ])
  69. train_dataset = MNIST('./data/',
  70. train=True,
  71. download=True,
  72. transform=tf)
  73. self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
  74. val_dataset = MNIST('./data/',
  75. train=False,
  76. download=True,
  77. transform=tf)
  78. self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
  79. def train(self):
  80. self.model.train()
  81. print('训练开始!!')
  82. for epoch in range(self.epoch):
  83. self.model.train()
  84. loss_mean = 0
  85. for i, (images, labels) in enumerate(self.train_dataloader):
  86. images, labels = images.to(self.device), labels.to(self.device)
  87. loss = self.model(images.view(images.shape[0], -1))
  88. loss_mean += loss.item()
  89. self.optimizer.zero_grad()
  90. loss.backward()
  91. self.optimizer.step()
  92. train_loss = loss_mean / len(self.train_dataloader)
  93. print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
  94. self.visualize_results(epoch)
  95. @torch.no_grad()
  96. def visualize_results(self, epoch):
  97. self.model.eval()
  98. # 保存结果路径
  99. output_path = 'results/VAE'
  100. if not os.path.exists(output_path):
  101. os.makedirs(output_path)
  102. tot_num_samples = self.sample_num
  103. image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
  104. # 生成对应sample个condition
  105. z = torch.randn(tot_num_samples, self.latent_dim).to(self.device)
  106. generated_images = self.model.decode(z)
  107. generated_images = generated_images.view(generated_images.shape[0], 1, 28, 28)
  108. save_image(generated_images, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
  109. if __name__ == '__main__':
  110. generator = ImageGenerator()
  111. generator.train()

4. condition代码及结果

如果我们要生成condition条件下的图像,与之前的做法很类似,在encoder中将图像和标签的embedding向量拼接,在decoder中将latent和标签的embedding向量拼接:

  1. import torch, time, os
  2. import numpy as np
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torchvision.datasets import MNIST
  6. from torchvision import transforms
  7. from torch.utils.data import DataLoader
  8. from torchvision.utils import save_image
  9. import torch.nn.functional as F
  10. class VAE(nn.Module):
  11. def __init__(self, middle_dim=400, latent_dim=20, class_num=10):
  12. super(VAE, self).__init__()
  13. self.fc1 = nn.Linear(784 + class_num, middle_dim)
  14. self.fc_mu = nn.Linear(middle_dim, latent_dim)
  15. self.fc_logvar = nn.Linear(middle_dim, latent_dim)
  16. self.fc2 = nn.Linear(latent_dim + class_num, middle_dim)
  17. self.fc3 = nn.Linear(middle_dim, 784)
  18. self.recons_loss = nn.BCELoss(reduction='sum')
  19. def encode(self, x, labels):
  20. x = torch.cat((x, labels), dim=1)
  21. x = torch.relu(self.fc1(x))
  22. mu = self.fc_mu(x)
  23. logvar = self.fc_logvar(x)
  24. return mu, logvar
  25. def reparametrization(self, mu, logvar):
  26. # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
  27. std = torch.exp(logvar / 2)
  28. eps = torch.randn_like(std)
  29. # N(mu, std^2) = N(0, 1) * std + mu
  30. z = eps * std + mu
  31. return z
  32. def decode(self, z, labels):
  33. z = torch.cat((z, labels), dim=1)
  34. x = torch.relu(self.fc2(z))
  35. x = F.sigmoid(self.fc3(x))
  36. return x
  37. def forward(self, x, labels):
  38. mu, logvar = self.encode(x, labels)
  39. z = self.reparametrization(mu, logvar)
  40. x_out = self.decode(z, labels)
  41. loss = self.loss_func(x_out, x, mu, logvar)
  42. return loss
  43. def loss_func(self, x_out, x, mu ,logvar):
  44. reconstruction_loss = self.recons_loss(x_out, x)
  45. KL_divergence = -0.5 * torch.sum(1 + logvar - torch.exp(logvar) - mu ** 2)
  46. # KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
  47. # KLD = torch.sum(KLD_ele).mul_(-0.5)
  48. return reconstruction_loss + KL_divergence
  49. class ImageGenerator(object):
  50. def __init__(self):
  51. '''
  52. 初始化,定义超参数、数据集、网络结构等
  53. '''
  54. self.epoch = 50
  55. self.sample_num = 100
  56. self.batch_size = 128
  57. self.latent_dim = 20
  58. self.lr = 0.001
  59. self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
  60. self.init_dataloader()
  61. self.model = VAE(latent_dim=self.latent_dim, class_num=10).to(self.device)
  62. self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
  63. def init_dataloader(self):
  64. '''
  65. 初始化数据集和dataloader
  66. '''
  67. tf = transforms.Compose([
  68. transforms.ToTensor(),
  69. # transforms.Normalize((0.1307,), (0.3081,))
  70. ])
  71. train_dataset = MNIST('./data/',
  72. train=True,
  73. download=True,
  74. transform=tf)
  75. self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
  76. val_dataset = MNIST('./data/',
  77. train=False,
  78. download=True,
  79. transform=tf)
  80. self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
  81. def train(self):
  82. self.model.train()
  83. print('训练开始!!')
  84. for epoch in range(self.epoch):
  85. self.model.train()
  86. loss_mean = 0
  87. for i, (images, labels) in enumerate(self.train_dataloader):
  88. images, labels = images.to(self.device), labels.to(self.device)
  89. labels = F.one_hot(labels, num_classes=10)
  90. # 将latent和condition拼接后输入网络
  91. loss = self.model(images.view(images.shape[0], -1), labels)
  92. loss_mean += loss.item()
  93. self.optimizer.zero_grad()
  94. loss.backward()
  95. self.optimizer.step()
  96. train_loss = loss_mean / len(self.train_dataloader)
  97. print('epoch:{}, loss:{:.4f}'.format(epoch, train_loss))
  98. self.visualize_results(epoch)
  99. @torch.no_grad()
  100. def visualize_results(self, epoch):
  101. self.model.eval()
  102. # 保存结果路径
  103. output_path = 'results/VAE'
  104. if not os.path.exists(output_path):
  105. os.makedirs(output_path)
  106. tot_num_samples = self.sample_num
  107. image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
  108. # 生成对应sample个condition
  109. z = torch.randn(tot_num_samples, self.latent_dim).to(self.device)
  110. labels = F.one_hot(torch.Tensor(np.repeat(np.arange(10), 10)).to(torch.int64), num_classes=10).to(self.device)
  111. generated_images = self.model.decode(z, labels)
  112. generated_images = generated_images.view(generated_images.shape[0], 1, 28, 28)
  113. save_image(generated_images, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)
  114. if __name__ == '__main__':
  115. generator = ImageGenerator()
  116. generator.train()


业务合作/学习交流+v:lizhiTechnology

  如果想要了解更多图像生成相关知识,可以参考我的专栏和其他相关文章:

图像生成_Lcm_Tech的博客-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

【图像生成】(二) GAN 原理 & pytorch代码实例_gan代码-CSDN博客

【图像生成】(三) VAE原理 & pytorch代码实例_vae算法 是如何生成图的-CSDN博客

【图像生成】(四) Diffusion原理 & pytorch代码实例_diffusion unet-CSDN博客

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

深度学习_Lcm_Tech的博客-CSDN博客

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【diffusers】(一) diffusers库介绍 & 框架代码解析-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/453150
推荐阅读
相关标签
  

闽ICP备14008679号