How the variational autoencoder (variational autoencoder, VAE) works: the VAE assumes that the latent code produced by the encoder network follows a Gaussian distribution. A latent feature is sampled from this distribution and passed to the decoder, which is trained to reproduce the original input. The loss is almost the same as a plain autoencoder's (AE's), with one extra regularization term: the KL divergence between the inferred latent distribution and a standard Gaussian. The purpose of this regularizer is clearly to keep the model from degenerating into an ordinary AE: to minimize the reconstruction error during training, the network would otherwise drive the latent variance toward 0, at which point there is no more random sampling noise and the model reduces to a plain AE. (Source: https://www.jianshu.com/p/ffd493e10751)
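For reference, when the encoder outputs a diagonal Gaussian with mean mu and log-variance log_var, the KL term against the standard Gaussian prior has a closed form (derived in Appendix B of the VAE paper, which the code below also cites); this is exactly what the kl_div line in the code computes:

$$D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\bigr) = -\frac{1}{2}\sum_{j}\bigl(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr)$$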
Here the MNIST handwritten-digit dataset is used to train the model and generate reconstructed images.
An important step in the algorithm is fitting the Gaussian distribution in the latent layer: a reparameterize function is built so that the parameters mu and log_var (the log-variance) can be adjusted iteratively by gradient descent (see the sketch after this description).
The overall flow of the model: an image is fed in, the encoding stage performs feature extraction and Gaussian-distribution fitting, and the decoding stage reconstructs the image from the sampled feature.
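A minimal sketch of the reparameterization trick (the same logic appears inside the full script below). Sampling z directly from N(mu, sigma^2) is not differentiable, so the noise is sampled separately and z is written as a deterministic function of mu and log_var, letting gradients flow back into the encoder. The mu and log_var tensors here are placeholders standing in for the encoder's outputs:

import torch

# placeholder encoder outputs for a 20-dimensional latent space
mu = torch.zeros(1, 20, requires_grad=True)
log_var = torch.zeros(1, 20, requires_grad=True)

std = torch.exp(0.5 * log_var)   # sigma = exp(log(sigma^2) / 2)
eps = torch.randn_like(std)      # noise sampled outside the computation graph
z = mu + eps * std               # z is differentiable w.r.t. mu and log_var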
# variational_autoencoder
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset
dataset = torchvision.datasets.MNIST(root='../../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)  # two heads for mu and log_var, respectively

    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))  # F.sigmoid is deprecated; torch.sigmoid is equivalent

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Compute reconstruction loss and KL divergence, both summed over the batch
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')  # size_average=False is deprecated
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        # Save images sampled from the prior z ~ N(0, I)
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Save the reconstructed images side by side with their inputs
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))
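As a possible follow-up (not part of the original script), the trained model can be persisted and reused later to generate new digits without retraining; the file names vae.pth and new-samples.png here are illustrative:

# hypothetical follow-up: save the trained weights, reload them, and sample
torch.save(model.state_dict(), 'vae.pth')

sampler = VAE().to(device)
sampler.load_state_dict(torch.load('vae.pth'))
sampler.eval()

with torch.no_grad():
    z = torch.randn(64, z_dim).to(device)            # draw latents from the prior
    samples = sampler.decode(z).view(-1, 1, 28, 28)  # decode into 28x28 images
    save_image(samples, 'new-samples.png')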
Besides the reconstruction loss and KL divergence printed during training, the program also outputs decoded images, covering both VAE-sampled and AE-style images. Two examples from the final epoch are shown below; the VAE-generated images are of noticeably higher quality.
(1) VAE-generated images
(2) AE-generated images
Epoch[1/15], Step [100/469], Reconst Loss: 22479.4258, KL Div: 1291.5803
Epoch[1/15], Step [200/469], Reconst Loss: 18100.7598, KL Div: 1868.1145
Epoch[1/15], Step [300/469], Reconst Loss: 15728.3721, KL Div: 2310.5405
Epoch[1/15], Step [400/469], Reconst Loss: 14573.3467, KL Div: 2457.8628
Epoch[2/15], Step [100/469], Reconst Loss: 13343.7812, KL Div: 2688.8018
Epoch[2/15], Step [200/469], Reconst Loss: 12820.8320, KL Div: 2670.8801
Epoch[2/15], Step [300/469], Reconst Loss: 12177.3770, KL Div: 2857.9192
Epoch[2/15], Step [400/469], Reconst Loss: 12119.5352, KL Div: 2913.0271
Epoch[3/15], Step [100/469], Reconst Loss: 11455.6797, KL Div: 2936.3635
Epoch[3/15], Step [200/469], Reconst Loss: 11355.9854, KL Div: 2994.1992
Epoch[3/15], Step [300/469], Reconst Loss: 11565.2637, KL Div: 3027.9497
Epoch[3/15], Step [400/469], Reconst Loss: 11625.3447, KL Div: 3047.5901
Epoch[4/15], Step [100/469], Reconst Loss: 11396.5977, KL Div: 3111.6401
Epoch[4/15], Step [200/469], Reconst Loss: 11895.6436, KL Div: 3192.7432
Epoch[4/15], Step [300/469], Reconst Loss: 10787.6719, KL Div: 3120.9729
Epoch[4/15], Step [400/469], Reconst Loss: 10792.5635, KL Div: 3101.8181
Epoch[5/15], Step [100/469], Reconst Loss: 11358.7930, KL Div: 3227.3677
Epoch[5/15], Step [200/469], Reconst Loss: 10595.2998, KL Div: 3087.9536
Epoch[5/15], Step [300/469], Reconst Loss: 11012.8457, KL Div: 3079.9478
Epoch[5/15], Step [400/469], Reconst Loss: 11031.8301, KL Div: 3274.6953
Epoch[6/15], Step [100/469], Reconst Loss: 10727.4932, KL Div: 3074.3291
Epoch[6/15], Step [200/469], Reconst Loss: 10766.1553, KL Div: 3205.7544
Epoch[6/15], Step [300/469], Reconst Loss: 10917.2773, KL Div: 3153.5269
Epoch[6/15], Step [400/469], Reconst Loss: 11135.8389, KL Div: 3166.1350
Epoch[7/15], Step [100/469], Reconst Loss: 10622.8848, KL Div: 3265.8269
Epoch[7/15], Step [200/469], Reconst Loss: 10808.3926, KL Div: 3163.3755
Epoch[7/15], Step [300/469], Reconst Loss: 10255.6533, KL Div: 3148.1663
Epoch[7/15], Step [400/469], Reconst Loss: 10487.1641, KL Div: 3009.9302
Epoch[8/15], Step [100/469], Reconst Loss: 10424.5625, KL Div: 3154.1379
Epoch[8/15], Step [200/469], Reconst Loss: 10814.4883, KL Div: 3221.2366
Epoch[8/15], Step [300/469], Reconst Loss: 10307.5762, KL Div: 3272.5889
Epoch[8/15], Step [400/469], Reconst Loss: 10350.0527, KL Div: 3236.4878
Epoch[9/15], Step [100/469], Reconst Loss: 10028.5371, KL Div: 3131.3210
Epoch[9/15], Step [200/469], Reconst Loss: 10316.7715, KL Div: 3235.4766
Epoch[9/15], Step [300/469], Reconst Loss: 10969.9980, KL Div: 3212.9060
Epoch[9/15], Step [400/469], Reconst Loss: 10779.7207, KL Div: 3261.3821
Epoch[10/15], Step [100/469], Reconst Loss: 10576.8887, KL Div: 3287.4534
Epoch[10/15], Step [200/469], Reconst Loss: 10241.1055, KL Div: 3205.7427
Epoch[10/15], Step [300/469], Reconst Loss: 10066.1045, KL Div: 3187.4636
Epoch[10/15], Step [400/469], Reconst Loss: 10090.7051, KL Div: 3259.8884
Epoch[11/15], Step [100/469], Reconst Loss: 10129.8330, KL Div: 3116.4709
Epoch[11/15], Step [200/469], Reconst Loss: 10742.1025, KL Div: 3320.4324
Epoch[11/15], Step [300/469], Reconst Loss: 9563.3086, KL Div: 3134.5513
Epoch[11/15], Step [400/469], Reconst Loss: 10116.4502, KL Div: 3038.9773
Epoch[12/15], Step [100/469], Reconst Loss: 10564.5547, KL Div: 3168.8032
Epoch[12/15], Step [200/469], Reconst Loss: 10309.9707, KL Div: 3233.3108
Epoch[12/15], Step [300/469], Reconst Loss: 10618.7500, KL Div: 3306.9241
Epoch[12/15], Step [400/469], Reconst Loss: 9806.7266, KL Div: 3111.2397
Epoch[13/15], Step [100/469], Reconst Loss: 10133.7803, KL Div: 3193.5210
Epoch[13/15], Step [200/469], Reconst Loss: 9875.7354, KL Div: 3192.7917
Epoch[13/15], Step [300/469], Reconst Loss: 10075.5908, KL Div: 3207.8071
Epoch[13/15], Step [400/469], Reconst Loss: 10066.5723, KL Div: 3236.6553
Epoch[14/15], Step [100/469], Reconst Loss: 10055.6152, KL Div: 3273.6326
Epoch[14/15], Step [200/469], Reconst Loss: 9802.0449, KL Div: 3094.6011
Epoch[14/15], Step [300/469], Reconst Loss: 10179.3486, KL Div: 3199.3838
Epoch[14/15], Step [400/469], Reconst Loss: 10153.1211, KL Div: 3148.5806
Epoch[15/15], Step [100/469], Reconst Loss: 10289.4541, KL Div: 3192.3494
Epoch[15/15], Step [200/469], Reconst Loss: 10085.9668, KL Div: 3131.0674
Epoch[15/15], Step [300/469], Reconst Loss: 9941.9395, KL Div: 3189.7222
Epoch[15/15], Step [400/469], Reconst Loss: 10195.7031, KL Div: 3279.8584
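Note that these logged values are summed over the whole batch (reduction='sum' with batch_size = 128), so the per-image losses are much smaller. For example, at the last logged step:

per_image_reconst = 10195.7031 / 128  # ~79.65 nats of reconstruction loss per image
per_image_kl = 3279.8584 / 128        # ~25.62 nats of KL divergence per image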