
[PyTorch][Chapter 54][Variational Auto-Encoder in Practice]

Preface:

This post implements a Variational Autoencoder (VAE). Its training results are shown below.

[Figure: sample reconstructions produced during training]

During training, pay attention to tuning the weight on the kld term returned by forward(); this coefficient is a hyperparameter that needs hand tuning. One common approach is sketched below.
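For example, the KL weight can be annealed from 0 to 1 over the first training steps (KL annealing). This is a minimal sketch of that technique, not taken from the original code; beta and warmup_steps are illustrative names:

def weighted_loss(recon_loss, kld, global_step, warmup_steps=1000):
    # Linearly raise the KL weight beta from 0 to 1 over warmup_steps,
    # then hold it at 1; a larger beta enforces a more regular latent space.
    beta = min(1.0, global_step / warmup_steps)
    return recon_loss + beta * kld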

The project consists of two files:

    vae.py

    main.py

Contents:

  1. vae
  2. main

1. vae

  File: vae.py

  Purpose: defines the Variational Autoencoder (VAE) model

During training we add constraints so that the latent space becomes more regular. This leads to the variational autoencoder (VAE), defined as an autoencoder whose training is regularized to avoid overfitting, ensuring that the latent space has good properties and thereby enabling content generation.
The architecture of a VAE is similar to that of a plain autoencoder; the difference is that the input is no longer encoded as a single point but as a distribution.
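Concretely, the encoder predicts a mean and a standard deviation for each input; a latent code is then sampled with the reparameterization trick, and a KL penalty pulls the latent distribution toward a standard normal. For Gaussian distributions this KL term has a closed form, which is exactly what forward() in vae.py computes:

$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,1)\big) = \frac{1}{2}\sum_{i}\left(\mu_i^{2} + \sigma_i^{2} - \log \sigma_i^{2} - 1\right)$$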

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023
@author: chengxf2
"""
import torch
from torch import nn

# VAE: Variational Autoencoder
class VAE(nn.Module):

    def __init__(self, hidden_size=20):
        super(VAE, self).__init__()
        # The encoder outputs hidden_size units: the first half is read as the
        # mean u, the second half as the standard deviation sigma.
        # (Note: the trailing ReLU constrains u and sigma to be non-negative.)
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
        )
        # hidden [batch_size, 10]: the latent dimension is half the encoder output
        h_dim = hidden_size // 2
        self.hDim = h_dim
        self.decoder = nn.Sequential(
            nn.Linear(in_features=h_dim, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=784),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        param x: [batch, 1, 28, 28]
        return: reconstruction xHat and KL divergence kld
        '''
        batchSz = x.size(0)
        # flatten
        x = x.view(batchSz, 784)
        # encoder
        h = self.encoder(x)
        # Split the tensor along the given dimension: the first half of the
        # units is read as u, the second half as sigma.
        u, sigma = h.chunk(2, dim=1)
        # Reparameterization trick:
        # randn_like draws from a standard normal distribution N(0, 1)
        # h.shape: [batchSz, self.hDim]
        h = u + sigma * torch.randn_like(sigma)
        # kld: the 1e-8 prevents log(0) when sigma^2 is 0
        kld = 0.5 * torch.sum(
            torch.pow(u, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) -
            1
        )
        # MSE loss is a mean, so kld is averaged the same way
        # (28*28 = 784 pixels per image)
        kld = kld / (batchSz * 28 * 28)
        xHat = self.decoder(h)
        # reshape
        xHat = xHat.view(batchSz, 1, 28, 28)
        return xHat, kld
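To sanity-check the model in isolation, one can run a dummy batch through it and inspect the output shapes. A minimal smoke test (illustrative; not part of the original project):

import torch
from vae import VAE

model = VAE(hidden_size=20)
x = torch.rand(4, 1, 28, 28)   # dummy batch of four 28x28 "images"
x_hat, kld = model(x)
print(x_hat.shape)             # torch.Size([4, 1, 28, 28])
print(kld.item())              # scalar KL penalty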

2. main

File: main.py

Purpose: trains the model and visualizes results on the test set

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023
@author: chengxf2
"""
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim, nn
from vae import VAE
import visdom

def main():
    batchNum = 32
    lr = 1e-3
    epochs = 20
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(1234)

    viz = visdom.Visdom()
    viz.line([0], [-1], win='train_loss', opts=dict(title='train loss'))

    tf = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST('mnist', train=True, transform=tf, download=True)
    train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
    mnist_test = datasets.MNIST('mnist', train=False, transform=tf, download=True)
    test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)

    global_step = 0
    model = VAE().to(device)
    criteon = nn.MSELoss().to(device)                  # reconstruction loss
    optimizer = optim.Adam(model.parameters(), lr=lr)  # update rule
    print("\n ----main-----")

    for epoch in range(epochs):
        start = time.perf_counter()
        for step, (x, y) in enumerate(train_data):
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
            if kld is not None:
                # Maximize the ELBO = -(reconstruction loss) - (KL weight) * kld,
                # i.e. minimize loss = reconstruction loss + 1.0 * kld;
                # the 1.0 factor is the KL weight mentioned in the preface.
                elbo = -loss - 1.0 * kld
                loss = -elbo
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line(Y=[loss.item()], X=[global_step], win='train_loss', update='append')
            global_step += 1
        end = time.perf_counter()
        interval = int(end - start)
        print("epoch: %d" % epoch, "\t training time %d s" % interval,
              '\t total loss: %4.7f' % loss.item(),
              "\t KL divergence: %4.7f" % kld.item())

        x, target = next(iter(test_data))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        tip = 'hat' + str(epoch)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title=tip))

if __name__ == '__main__':
    main()
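Because the KL penalty pulls the latent codes toward N(0, I), a trained model can also generate new digits by sampling latent vectors from the prior and running only the decoder. A minimal sketch, assuming it is appended at the end of main() where model, device, and viz are in scope (not part of the original script):

# Sample new images from the prior N(0, I); model.hDim is 10 by default.
with torch.no_grad():
    z = torch.randn(64, model.hDim).to(device)
    samples = model.decoder(z).view(-1, 1, 28, 28)
viz.images(samples, nrow=8, win='samples', opts=dict(title='samples'))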

Reference:

课时118 变分Auto-Encoder实战-2, bilibili (Lesson 118: Variational Auto-Encoder in Practice, part 2)
