Preface:
This post implements a Variational Autoencoder (VAE).
Its training results are shown below. [training-result figure not reproduced here]
During training, pay attention to tuning the kld term computed in forward; its scale is a hyperparameter.
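For example, the trade-off can be made explicit with a weight on the KL term (a toy sketch; beta and the stand-in loss values are hypothetical, the training script below hard-codes a weight of 1.0):

import torch

beta = 0.5                    # hypothetical KL weight
recon = torch.tensor(0.05)    # stand-in reconstruction (MSE) loss
kld = torch.tensor(0.02)      # stand-in KL term returned by forward
elbo = -recon - beta * kld    # ELBO to maximize
loss = -elbo                  # recon + beta * kld, the value to minimize
print(loss)                   # tensor(0.0600)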
The whole project consists of two files:
vae.py
main.py
Contents:
1. vae
File: vae.py
Purpose: the Variational Autoencoder (VAE) model
To make the latent space more regular, some constraints are added during training. This leads to the variational autoencoder (VAE), defined as an autoencoder trained with regularization to avoid overfitting, which ensures the latent space has good enough properties to enable content generation.
The architecture of a variational autoencoder is much like that of an autoencoder; the difference is that an input is no longer encoded as a single point, but as a distribution.
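Concretely, the kld term computed in forward below is the closed-form KL divergence between the approximate posterior N(u, sigma^2) and the standard normal prior; in LaTeX:

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,1)\big)
  = \frac{1}{2}\sum_{j}\left(\mu_{j}^{2}+\sigma_{j}^{2}-\log\sigma_{j}^{2}-1\right)

The training script then minimizes the negative ELBO, i.e. the MSE reconstruction loss plus this KL term (weighted by 1.0 in main.py).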
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023
@author: chengxf2
"""

import torch
from torch import nn

# vae: Variational Autoencoder

class VAE(nn.Module):

    def __init__(self, hidden_size=20):

        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
        )
        # encoder output h: [batch_size, hidden_size]; it is later split
        # into a mean u and a std sigma of h_dim dimensions each

        h_dim = int(hidden_size / 2)
        self.hDim = h_dim

        self.decoder = nn.Sequential(
            nn.Linear(in_features=h_dim, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=784),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        param x: [batch, 1, 28, 28]
        return: reconstruction xHat [batch, 1, 28, 28], KL term kld
        '''
        batchSz = x.size(0)
        # flatten
        x = x.view(batchSz, 784)

        # encoder
        h = self.encoder(x)

        # chunk the tensor along the given dimension: the first half of
        # the units is taken as u (mean), the second half as sigma (std)
        u, sigma = h.chunk(2, dim=1)

        # Reparameterization trick:
        # randn_like samples from a standard normal distribution ~ N(0, 1)
        # h.shape: [batchSz, self.hDim]
        h = u + sigma * torch.randn_like(sigma)

        # kld: closed-form KL(N(u, sigma^2) || N(0, 1));
        # the 1e-8 keeps the log finite when sigma^2 is 0
        kld = 0.5 * torch.sum(
            torch.pow(u, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) -
            1
        )

        # MSE loss is a per-element mean, so average kld over the
        # batch and the 28x28 pixels as well
        kld = kld / (batchSz * 28 * 28)
        xHat = self.decoder(h)

        # reshape
        xHat = xHat.view(batchSz, 1, 28, 28)

        return xHat, kld
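A quick shape check (a minimal sketch, not part of the original project): push a random batch through the model and confirm that the reconstruction and the KL term come out as expected.

import torch
from vae import VAE

model = VAE(hidden_size=20)      # u and sigma get 10 dimensions each
x = torch.rand(4, 1, 28, 28)     # a fake batch of 4 images in [0, 1]
x_hat, kld = model(x)
print(x_hat.shape)               # torch.Size([4, 1, 28, 28])
print(kld.item())                # a single per-pixel-averaged KL value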

2. main
File: main.py
Purpose: training and testing on the MNIST dataset
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023
@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim, nn
from vae import VAE
import visdom


def main():

    batchNum = 32
    lr = 1e-3
    epochs = 20
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(1234)
    viz = visdom.Visdom()
    viz.line([0], [-1], win='train_loss', opts=dict(title='train loss'))

    tf = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST('mnist', train=True, transform=tf, download=True)
    train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=tf, download=True)
    test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
    global_step = 0

    model = VAE().to(device)
    criteon = nn.MSELoss().to(device)                  # loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)  # gradient update rule

    print("\n ----main-----")
    for epoch in range(epochs):

        start = time.perf_counter()
        for step, (x, y) in enumerate(train_data):
            # [b, 1, 28, 28]
            x = x.to(device)
            x_hat, kld = model(x)

            loss = criteon(x_hat, x)

            if kld is not None:
                # maximizing the ELBO = minimizing reconstruction loss + KL
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line(Y=[loss.item()], X=[global_step], win='train_loss', update='append')
            global_step += 1

        end = time.perf_counter()
        interval = int(end - start)

        print("epoch: %d" % epoch, "\t training time %d" % interval,
              '\t total loss: %4.7f' % loss.item(),
              "\t KL divergence: %4.7f" % kld.item())

        # visualize reconstructions on one test batch
        x, target = next(iter(test_data))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)

        tip = 'hat' + str(epoch)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title=tip))


if __name__ == '__main__':

    main()
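Note that visdom needs its server running before the script starts (python -m visdom.server), otherwise the viz calls fail to connect. Once training finishes, the decoder alone can also generate new digits by decoding samples drawn from the prior N(0, 1). A minimal sketch, assuming it runs at the end of main() so that model, device and viz are still in scope:

# sample latent codes from the standard normal prior; model.hDim is 10 here
z = torch.randn(32, model.hDim).to(device)
with torch.no_grad():
    generated = model.decoder(z).view(-1, 1, 28, 28)
viz.images(generated, nrow=8, win='generated', opts=dict(title='samples from prior'))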
