
PyTorch in Practice (3): Auto Encoder & Variational Autoencoder


1. Auto Encoder

Overall flow: input ==RESHAPE==> 784 => 256 => 64 => 20 => 64 => 256 => 784 ==RESHAPE==> output (the hidden sizes match the code below)

1. Network

Layers: encoder {[b, 784] => [b, 20]} + decoder {[b, 20] => [b, 784]}

Data flow: input.reshape -> encoder -> decoder -> output.reshape

import torch
from torch import nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()  # squash outputs to 0~1, matching normalized pixels
        )

    def forward(self, x):
        """param x: [b, 1, 28, 28]"""
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)
        return x
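
As a quick sanity check (a sketch of mine, not from the original post; it feeds a random dummy batch rather than real MNIST), the wiring can be verified by confirming that the output shape matches the input:

import torch

model = AE()
x = torch.randn(32, 1, 28, 28)   # dummy batch shaped like MNIST images
x_hat = model(x)
print(x_hat.shape)               # expected: torch.Size([32, 1, 28, 28])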

2. Training & Testing:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main():
    mnist_train = datasets.MNIST('mnist', train=True, download=True,
                                 transform=transforms.ToTensor())
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, 'loss:', loss.item())
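
To eyeball reconstruction quality during training, one option (a sketch, not part of the original post; it assumes torchvision is available, and the filename is arbitrary) is to save the inputs and reconstructions as an image grid at the end of each epoch:

from torchvision.utils import save_image

# after the inner batch loop, reusing the last batch x:
with torch.no_grad():
    x_hat = model(x)
# top rows: originals; bottom rows: reconstructions
save_image(torch.cat([x, x_hat], dim=0), 'recon_%d.png' % epoch, nrow=8)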

2. Variational Autoencoder

 

1. Network

Layers: encoder {[b, 784] => [b, 20]}, then split [b, 20] => μ [b, 10] + σ [b, 10], then decoder {[b, 10] => [b, 784]}

Data flow: input.reshape -> encoder -> reparameterize -> decoder -> compute KL -> output.reshape
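
For reference, the kld term computed in forward below is the standard closed form of the KL divergence between the approximate posterior N(μ, σ²) and the prior N(0, 1), applied per latent dimension:

KL(N(μ, σ²) || N(0, 1)) = ½ (μ² + σ² − log σ² − 1)

The code sums this over the 10 latent dimensions and the batch, then divides by batchsz*28*28 so it sits on the same per-pixel scale as the MSE reconstruction loss.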

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # the 20 outputs are later split into mu: [b, 10] and sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 10] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),  # changed from the AE: the latent code is [b, 10]
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )
        self.criteon = nn.MSELoss()

    def forward(self, x):
        """param x: [b, 1, 28, 28]"""
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder: [b, 20], holding both mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparameterization trick: h = mu + sigma * epsilon, epsilon ~ N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)
        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)
        # KL divergence between N(mu, sigma^2) and N(0, 1), averaged per pixel
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)
        return x_hat, kld

2. Training & Testing:

Training mirrors the AE loop (with model = VAE().to(device)); note the loss computation:

for epoch in range(1000):
    for batchidx, (x, _) in enumerate(mnist_train):
        # [b, 1, 28, 28]
        x = x.to(device)
        x_hat, kld = model(x)
        # reconstruction loss plus weighted KL: maximizing the ELBO
        # is the same as minimizing -ELBO
        loss = criteon(x_hat, x)
        elbo = -loss - 1.0 * kld
        loss = -elbo
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, 'kld loss:', kld.item())
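
Once trained, new digits can be generated from the decoder alone by decoding samples drawn from the prior (a sketch, not from the original post; model and device come from the training code above, and the filename is arbitrary):

from torchvision.utils import save_image

with torch.no_grad():
    h = torch.randn(32, 10).to(device)   # latent codes sampled from N(0, 1)
    x_gen = model.decoder(h)             # [32, 784]
    x_gen = x_gen.view(32, 1, 28, 28)
save_image(x_gen, 'vae_samples.png', nrow=8)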

 
