当前位置:   article > 正文

GAN网络的代码实现(学习ing)_gan代码实现

gan代码实现

本文代码出处
感悟:在学习某个网络的时候,一味的看概念越看越不理解,此时我们可以找一份源代码,试着写一遍,也许就会明白了。

#----------引入需要的库----------
import argparse #配置超参数的库
import os#用来创建文件夹
import numpy as np
#对数据进行一些操作
import torchvision.transforms as transforms
#保存图片
from torchvision.utils import save_image
#数据加载器
from torch.utils.data import DataLoader
#加载数据
from torchvision import datasets
#在旧版本的PyTorch中,Variable 类被用于包装张量,
#并自动跟踪对该张量的所有操作,从而支持自动计算梯度。
#它是PyTorch自动微分系统的核心组件。
from torch.autograd import Variable
import torch.nn as nn
import torch
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
#----------创建文件夹和配置一些参数----------
# 创建文件夹
os.makedirs("./images/gan/",exist_ok=True)
os.makedirs("./save/gan/",exist_ok=True)
os.makedirs("./datasets/mnist/",exist_ok=True)

#超参数配置-->超参数就是我们可以设置的
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs",type = int,default = 50,help = "number of epochs of training" )
parser.add_argument("--batch_size",type = int,default = 2,help="size of the batches")
parser.add_argument("--lr",type = float,default = 0.0002,help="adam: learning rate")
parser.add_argument("--b1",type = float,default = 0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2",type = float,default = 0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu",type = int, default = 2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim",type = int,default = 100, help="dimensionality of the latent space")
parser.add_argument("--img_size",type = int,default = 28, help="size of each image dimension")
parser.add_argument("--channels", type = int,default = 1, help="number of image channels")
parser.add_argument("--sample_interval",type = int,default = 500, help="interval betwen image samples")
opt = parser.parse_known_args()[0]

# print(opt)输出结果如下
#Namespace(b1=0.5, b2=0.999, batch_size=2, channels=1, img_size=28, latent_dim=100, lr=0.0002, n_cpu=2, n_epochs=50, sample_interval=500)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
#----------下载数据并将数据份数分好----------
#(1,28,28)
img_shape = (opt.channels,opt.img_size,opt.img_size)
print(img_shape)
#计算数组所有元素的乘积
img_area = np.prod(img_shape)
print(img_area)
#cuda 查看cuda是否可用
cuda = True if torch.cuda.is_available() else False

#mnist 数据集下载并对数据做一些处理
mnist = datasets.MNIST(root = "./datasets",train = True, download = True, transform = transforms.Compose(
                        [transforms.Resize(opt.img_size),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5],[0.5])]))
#加载器 分批次加载数据集,
#将数据分成len(dataloader)/batchsize【不是很严谨】份送入网络
dataloader = DataLoader(
    mnist,
    batch_size = opt.batch_size,
    shuffle = True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
#----------判别器----------
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace = True),
            nn.Linear(256,1),
            nn.Sigmoid()
            
        )
    def forward(self,img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        print(validity)
        return validity
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
#----------生成器----------
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__( )
        def block(in_feat,out_feat,normalize = True):
            layers = [nn.Linear(in_feat,out_feat)]
            if  normalize:
                layers.append(nn.BatchNorm1d(out_feat,0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
        self.model = nn.Sequential(
        	# *操作符被称为解包操作符
        	# 这里可以理解成将block中的层一个一个的写到了Sequential中
            *block(opt.latent_dim,128,normalize = False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
    def forward(self,z):
        imgs = self.model(z)
        #imgs.view的意思,
        #(2, batch_size * channels * height * width)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        imgs = imgs.view(imgs.size(0),*img_shape)
        #imgs.size(0)就是2
        #将imgs从一维向量重新reshape为合适的图像形状img_shape
        return imgs 
        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
# 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()
#loss
criterion = torch.nn.BCELoss()
#youhua
optimizer_G = torch.optim.Adam(generator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr = opt.lr,betas = (opt.b1,opt.b2))

#有cuda就在cuda上运行
if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
# training
for epoch in range(opt.n_epochs):
    # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
    for i,(imgs,_) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0),-1)
        #使用cuda,就在后边加个.cuda()
        #real_img = Variable(imgs).cuda()
        real_img = Variable(imgs)
        # imgs.size()  torch.Size([2, 1, 28, 28])
        #2个三维数组,1个二维数组,28个一维数组
        #real_label = Variable(torch.ones(imgs.size(0),1)).cuda()#全1
        #fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()#全0
        #-----------------重点-----------------
        #-为什么real_label是全1呢?fake_label全为0呢?-
        #------------这就是gan的原理啦------------
        #在生成对抗网络(GAN)中,判别器的目标是将真实样本判别为1,
        #将生成的假样本判别为0。
        real_label = Variable(torch.ones(imgs.size(0),1))
        fake_label = Variable(torch.zeros(imgs.size(0),1))
        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)
        #print(real_img.view(real_img.size(0),-1))
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        #  detach还没很理解
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out
        #损失函数和优化
        loss_D = loss_real_D + loss_fake_D
        
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        # ---------------------
        # Train Generator
        # ---------------------
        #z = Variable(torch.randn(imgs.size(0),opt.latent_dim)).cuda()
        z = Variable(torch.randn(imgs.size(0),opt.latent_dim))
        fake_img = generator(z)
        output = discriminator(fake_img)
        #损失函数和优化
        loss_G = criterion(output,real_label)
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if(i+1) % 100 == 0:
            print(
                "[Epoch %d/%d]  [Batch %d/%d]  [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                %(epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),real_scores.data.mean(),fake_scores.data.mean())
            )
        # 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(fake_img.data[:25],"./images/gan/%d.png"%batches_done,nrow = 5,normalize = True)
# 保存模型
#将生成器和判别器的网络权重保存起来
torch.save(generator.state_dict(),"./save/gan/generator.pth")
torch.save(discriminator.state_dict(),"./save/gan/discriminator.pth")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/876128
推荐阅读
相关标签
  

闽ICP备14008679号