Source of the code in this article
Takeaway: when studying a network architecture, staring at the concepts alone can make them harder and harder to grasp. At that point, find a reference implementation and write it out yourself once; things often click after that.
#---------- Import the required libraries ----------
import argparse          # for configuring hyperparameters
import os                # for creating directories
import numpy as np       # for array operations
import torchvision.transforms as transforms
# for saving images
from torchvision.utils import save_image
# data loader
from torch.utils.data import DataLoader
# for downloading datasets
from torchvision import datasets
# In older versions of PyTorch, the Variable class wrapped a tensor and
# automatically tracked every operation on it, so that gradients could be
# computed. It was a core component of PyTorch's autograd system.
from torch.autograd import Variable
import torch.nn as nn
import torch
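One caveat worth knowing (not in the original code): since PyTorch 0.4, Variable has been merged into Tensor, so the wrapper is no longer needed; a plain tensor created with requires_grad=True tracks gradients by itself. A minimal sketch:

import torch

# Modern equivalent of Variable: tensors track gradients directly.
x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()
y.backward()       # fills x.grad with dy/dx = 2 * x
print(x.grad)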
#---------- Create directories and configure hyperparameters ----------
# Create output directories
os.makedirs("./images/gan/", exist_ok=True)
os.makedirs("./save/gan/", exist_ok=True)
os.makedirs("./datasets/mnist/", exist_ok=True)

# Hyperparameter configuration -- hyperparameters are the knobs we get to set
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=2, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between image samples")
opt = parser.parse_known_args()[0]
# print(opt) produces:
# Namespace(b1=0.5, b2=0.999, batch_size=2, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=2, n_epochs=50, sample_interval=500)
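A quick aside on parse_known_args()[0]: unlike parse_args(), it tolerates flags it does not recognize, which is why this script also runs inside Jupyter, where the kernel injects its own command-line arguments. A self-contained sketch of the behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=2)
# parse_known_args returns (namespace, leftover) instead of raising an
# error on unrecognized flags such as --foo below.
opt, unknown = parser.parse_known_args(["--batch_size", "64", "--foo", "bar"])
print(opt.batch_size)   # 64
print(unknown)          # ['--foo', 'bar']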
#---------- Download the data and split it into batches ----------
# (1, 28, 28)
img_shape = (opt.channels, opt.img_size, opt.img_size)
print(img_shape)
# np.prod computes the product of all elements: 1 * 28 * 28 = 784
img_area = np.prod(img_shape)
print(img_area)

# Check whether CUDA is available
cuda = True if torch.cuda.is_available() else False

# Download MNIST and apply some preprocessing
mnist = datasets.MNIST(
    root="./datasets", train=True, download=True,
    transform=transforms.Compose(
        [transforms.Resize(opt.img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])]))

# The DataLoader feeds the dataset to the network in batches;
# it yields roughly len(mnist) / batch_size batches per epoch
dataloader = DataLoader(
    mnist,
    batch_size=opt.batch_size,
    shuffle=True)
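Note that Normalize([0.5], [0.5]) shifts ToTensor's [0, 1] pixel values to [-1, 1] via (x - 0.5) / 0.5, which matches the Tanh output range of the generator defined below. A small sketch to confirm the batch shape and value range (this assumes the dataloader object created above):

# Grab one batch and check its shape and value range
imgs, labels = next(iter(dataloader))
print(imgs.shape)                             # torch.Size([2, 1, 28, 28]) with batch_size=2
print(imgs.min().item(), imgs.max().item())   # close to -1.0 and 1.0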
#---------- Discriminator ----------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        # Flatten (batch_size, 1, 28, 28) into (batch_size, 784)
        img_flat = img.view(img.size(0), -1)
        # Sigmoid output in (0, 1): the probability that img is real
        validity = self.model(img_flat)
        return validity
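A quick smoke test of the discriminator's shapes, assuming the class above and img_area = 784 are already defined:

D = Discriminator()
dummy = torch.randn(4, 1, 28, 28)                 # a batch of 4 random "images"
scores = D(dummy)
print(scores.shape)                               # torch.Size([4, 1])
print(float(scores.min()), float(scores.max()))   # both inside (0, 1) thanks to Sigmoid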
#---------- Generator ----------
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # note: the second positional argument of BatchNorm1d is eps, not momentum
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # The * is the unpacking operator: it spreads the layers in each
            # block's list out as individual arguments to Sequential
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, img_area),
            nn.Tanh()
        )

    def forward(self, z):
        imgs = self.model(z)
        # self.model outputs flat vectors of shape (batch_size, img_area),
        # e.g. (2, 784) with batch_size=2; view reshapes them back into
        # images of shape img_shape, i.e. torch.Size([2, 1, 28, 28])
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs
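The *block(...) pattern is ordinary Python argument unpacking: nn.Sequential expects layers as separate positional arguments, so a list of layers has to be spread out. A self-contained sketch of the same idea:

import torch.nn as nn

layers = [nn.Linear(10, 20), nn.LeakyReLU(0.2)]
# The two lines below build identical models:
a = nn.Sequential(nn.Linear(10, 20), nn.LeakyReLU(0.2))
b = nn.Sequential(*layers)   # * spreads the list into positional arguments
print(b)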
# Instantiate the generator and the discriminator
generator = Generator()
discriminator = Discriminator()
# Loss function: binary cross-entropy
criterion = torch.nn.BCELoss()
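BCELoss computes -[y·log(p) + (1-y)·log(1-p)] averaged over the batch, pushing each predicted probability p toward its label y. A tiny self-contained check:

import torch

bce = torch.nn.BCELoss()
p = torch.tensor([0.9, 0.1])   # predicted probabilities
y = torch.tensor([1.0, 0.0])   # targets: real = 1, fake = 0
# By hand: (-log(0.9) - log(0.9)) / 2 ≈ 0.1054
print(bce(p, y).item())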
# Optimizers (Adam for both networks)
optimizer_G = torch.optim.Adam(generator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
# Run on CUDA if it is available
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
#---------- Training ----------
for epoch in range(opt.n_epochs):
    # enumerate returns two values: the batch index and the data
    # (which contains the training images and their labels)
    for i, (imgs, _) in enumerate(dataloader):
        # imgs originally has size torch.Size([2, 1, 28, 28]);
        # flatten each image into a 784-dimensional vector
        imgs = imgs.view(imgs.size(0), -1)
        real_img = Variable(imgs)

        # -------------------- key point --------------------
        # Why is real_label all ones and fake_label all zeros?
        # That is the core idea of a GAN: the discriminator's goal is to
        # score real samples as 1 and generated (fake) samples as 0.
        real_label = Variable(torch.ones(imgs.size(0), 1))
        fake_label = Variable(torch.zeros(imgs.size(0), 1))
        if cuda:  # move inputs and labels to the GPU alongside the models
            real_img = real_img.cuda()
            real_label = real_label.cuda()
            fake_label = fake_label.cuda()

        ## ---------------------
        ## Train Discriminator
        ## Two parts: 1) score real images as real; 2) score fake images as fake
        ## ---------------------
        ## Loss on real images
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out

        ## Loss on fake images
        ## detach(): cut the fake images out of the generator's computation
        ## graph so no gradient flows back into G here; G is not updated
        ## during the discriminator step (see the detach demo after this loop)
        z = Variable(torch.randn(imgs.size(0), opt.latent_dim))
        if cuda:
            z = z.cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out

        # Combine the two losses and take an optimizer step
        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---------------------
        # Train Generator
        # ---------------------
        z = Variable(torch.randn(imgs.size(0), opt.latent_dim))
        if cuda:
            z = z.cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        # The generator wants the discriminator to call its output real,
        # so its loss compares the scores against real_label (all ones)
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        ## Print training progress
        ## item(): extract the value of a single-element tensor as a plain Python number
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(),
                   real_scores.data.mean(), fake_scores.data.mean())
            )

        # Save sample images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(fake_img.data[:25], "./images/gan/%d.png" % batches_done, nrow=5, normalize=True)

# Save the models:
# store the network weights of the generator and the discriminator
torch.save(generator.state_dict(), "./save/gan/generator.pth")
torch.save(discriminator.state_dict(), "./save/gan/discriminator.pth")
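Since detach() is the subtlest step above, here is a small self-contained demo (not part of the original code) of what it does: the detached tensor shares data with the original but is cut out of the computation graph, so backward() cannot reach anything upstream of it.

import torch

# Analogy to the GAN step: w plays the generator's parameters,
# v the discriminator's. detach() keeps D's loss from updating G.
w = torch.tensor(2.0, requires_grad=True)
v = torch.tensor(1.0, requires_grad=True)

y = (w * 3).detach()   # y is now a constant as far as autograd is concerned
loss = y * v           # loss still depends on v
loss.backward()
print(v.grad)          # tensor(6.)  -- the gradient reaches v
print(w.grad)          # None        -- the path back to w was cut

And to reuse the saved weights later, the standard pattern (with the paths used above) is:

generator = Generator()
generator.load_state_dict(torch.load("./save/gan/generator.pth"))
generator.eval()   # put BatchNorm layers into inference mode before sampling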