赞
踩
图像翻译 图像补全 数据增广
假设有两个网络,生成网络G(Generator)和判别网络D(Discriminator)。它们的功能分别是:· G负责生成图片,它接收一个随机的噪声z,通过该噪声生成图片,将生成的图片记为G(z)。· D负责判别一张图片是不是“真实的”。它的输入是x, x代表一张图片,输出D(x)表示x为真实图片的概率,如果为1,代表是真实图片的概率为100%,而输出为0,代表不可能是真实的图片。
对于这个损失函数,需要认识下面几点:· 整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。· D(x)表示D网络判断真实图片是否真实的概率(因为x是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片是否真实的概率。· G的目的:G应该希望自己生成的图片“越接近真实越好”。也是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。· D的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。因此D的目的和G不同,D应该希望V(D, G)越大越好。
在实际训练中,使用梯度下降法,对D和G交替做优化即可,详细的步骤为:
第1步:从已知的噪声分布pz(z)中选出一些样本{z(1), z(2), ····, z(m)}。
第2步:从训练数据中选出同样个数的真实图片{x(1), x(2), ····, x(m)}。
第3步:设判别器D的参数为,求出损失关于参数的梯度[插图],对θd更新时加上该梯度。
第4步:设生成器G的参数为θg,求出损失关于参数的梯度[插图],对θg更新时减去该梯度。
model.py
- # -*- coding: utf-8 -*-#
-
- #-------------------------------------------------------------------------------
- # Name: GANmodel
- # Description:
- # Author: Administrator
- # Date: 2020/12/9
- '''
- 参考博客:https://blog.csdn.net/jizhidexiaoming/article/details/96485095
- '''
- #-------------------------------------------------------------------------------
- # coding=utf-8
- import torch.autograd
- import torch.nn as nn
- from torch.autograd import Variable
- from torchvision import transforms
- from torchvision import datasets
- from torchvision.utils import save_image
- import os
- from CreateMyData import MyDataset
- import numpy as np
-
- # 创建文件夹
- if not os.path.exists('./img'):
- os.mkdir('./img')
-
-
- def to_img(x):
- out = 0.5 * (x + 1)
- out = out.clamp(0, 255) # Clamp函数可以将随机变化的数值限制在一个给定的区间[min, max]内:
- out = out.view(-1,3, 256, 256) # view()函数作用是将一个多行的Tensor,拼接成一行,256是图像的大小
- return out
-
- #每次喂入数据的数量
- batch_size =10
- #训练的轮数
- num_epoch = 100
- #噪声的维度
- z_dimension = 100
-
- # 图像预处理
- img_transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.5,), (0.5,)) # (x-mean) / std
- ])
-
- #加载数据集
- train_txt_path='F:\\ClassNetWork\\GANnet\\catmydata\\list.txt'
- train_data=MyDataset(txt_path=train_txt_path,transform=img_transform)
-
- # data loader 数据载入
- dataloader = torch.utils.data.DataLoader( dataset=train_data, batch_size=batch_size, shuffle=True)
-
-
- # 定义判别器 #####Discriminator######使用多层网络来作为判别器
- # 将图片256x256展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
- # 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。
- class discriminator(nn.Module):
- def __init__(self):
- super(discriminator, self).__init__()
- self.dis = nn.Sequential(
- nn.Linear(196608, 256), # 输入特征数为256x256x3=196608,输出为256
- nn.LeakyReLU(0.2), # 进行非线性映射
- nn.Linear(256, 256), # 进行一个线性映射
- nn.LeakyReLU(0.2),
- nn.Linear(256, 1),
- nn.Sigmoid() # 也是一个激活函数,二分类问题中,
- # sigmoid可以班实数映射到【0,1】,作为概率值,
- # 多分类用softmax函数
- )
- #创建对象的时候回自动调用前向传播函数
- def forward(self, x):
- x = self.dis(x)
- return x
-
-
- # ###### 定义生成器 Generator #####
- # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,
- # 然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,
- # 然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布
- # 能够在-1~1之间。
- class generator(nn.Module):
- def __init__(self):
- super(generator, self).__init__()
- self.gen = nn.Sequential(
- nn.Linear(100, 256), # 用线性变换将输入映射到256维,100输入层维度,256输出层维度
- nn.ReLU(True), # relu激活
- nn.Linear(256, 256), # 线性变换
- nn.ReLU(True), # relu激活
- nn.Linear(256, 196608), # 线性变换,输出数据的维度为256*256*3=196608
- nn.Tanh() # Tanh激活使得生成数据分布在【-1,1】之间,因为输入的真实数据的经过transforms之后也是这个分布
- )
- #创建对象的时候回自动调用前向传播函数
- def forward(self, x):
- x = self.gen(x)
- return x
-
-
- # 创建对象
- D = discriminator()
- G = generator()
- #如果cuda可以用,调用使用,后文删去了cuda(显卡)
- if torch.cuda.is_available():
- D = D.cuda()
- G = G.cuda()
-
-
- # 首先需要定义loss的度量方式 (二分类的交叉熵)
- # 其次定义 优化函数,优化函数的学习率为0.0003
- criterion = nn.BCELoss() # 是单目标二分类交叉熵函数
- d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
- g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
-
- # ##########################进入训练##判别器的判断过程#####################
- for epoch in range(num_epoch): # 进行多个epoch的训练
- for i, img in enumerate(dataloader,0):
- print(i)
- #print(type(img))
- img,label=img
- num_img=img
- #print(type(img))
- #print(img.shape)
- # view()函数作用是将一个多行的Tensor,拼接成一行
- # 第一个参数是要拼接的tensor,第二个参数是-1
- # =============================训练判别器==================
- img_nums=img.size(0)
- #print(img.size(0))
- img = img.view(batch_size, -1) # 将图片展开为28*28=784
- #print(img.shape)
- real_img = Variable(img) # 将tensor变成Variable放入计算图中
- real_label = Variable(torch.ones( img_nums)) # 定义真实的图片label为1
- fake_label = Variable(torch.zeros( img_nums)) # 定义假的图片的label为0
-
- # ########判别器训练train#####################
- #------训练判别器时需要计算真图片和假图片两种图片对应的损失值,然后相加一起计算。------
- # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
- # 计算真实图片的损失
- real_out = D(real_img) # 将真实图片放入判别器中
- real_label=real_label.reshape( img_nums,1)#将128变为[128,1],方便和real_out的维度相同
- fake_label =fake_label.reshape( img_nums,1)#将128变为[128,1],方便和fake_out的维度相同
- d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss
- real_scores = real_out # 得到真实图片的判别值,输出的值越接近1越好
- # 计算假的图片的损失
- #从标准正态分布(均值为0,方差为1)中抽取的一组随机数
- z = Variable(torch.randn( img_nums, z_dimension)) # 随机生成一些噪声
- fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
- #print(fake_img.shape)
- fake_out = D(fake_img) # 判别器判断假的图片,
- d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的loss
- fake_scores = fake_out # 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
- # 损失函数和优化
- d_loss = d_loss_real + d_loss_fake # 损失包括判真损失和判假损失
- d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
- d_loss.backward() # 将误差反向传播
- d_optimizer.step() # 更新参数
-
- # ==================训练生成器============================
- # ###############################生成网络的训练###############################
- #------训练生成器时候只需要考虑生成器对应的损失值,让生成器把假的图像当成真的
- # 原理:目的是希望生成的假的图片被判别器判断为真的图片,
- # 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
- # 反向传播更新的参数是生成网络里面的参数,
- # 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的
- # 这样就达到了对抗的目的
- # 计算假的图片的损失
- z = Variable(torch.randn(batch_size, z_dimension)) # 得到随机噪声
- fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
- output = D(fake_img) # 经过判别器得到的结果
- g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
- # bp and optimize
- g_optimizer.zero_grad() # 梯度归0
- g_loss.backward() # 进行反向传播
- g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
-
- # 打印中间的损失
- if (i + 1) % 10 == 0:
- print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
- 'D real: {:.6f},D fake: {:.6f}'.format(
- epoch, num_epoch, d_loss.data.item(), g_loss.data.item(),
- real_scores.data.mean(), fake_scores.data.mean() # 打印的是真实图片的损失均值
- ))
- if epoch == 0:
- real_images = to_img(real_img.cpu().data)
- #print(real_images.shape)
- save_image(real_images, './img/real_images.png')
- fake_images = to_img(fake_img.cpu().data)
- save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))
-
- # 保存模型
- torch.save(G.state_dict(), './generator.pth')
- torch.save(D.state_dict(), './discriminator.pth')
-
-

createdata.py
- # -*- coding: utf-8 -*-#
-
- #-------------------------------------------------------------------------------
- # Name: CreateMyData
- # Description:
- # Author: Administrator
- # Date: 2020/12/9
- #-------------------------------------------------------------------------------
- # coding: utf-8
- from PIL import Image
- from torch.utils.data import Dataset
- class MyDataset(Dataset):
- def __init__(self, txt_path, transform = None, target_transform = None):
- fh = open(txt_path, 'r')
- imgs = []
- for line in fh:
- line = line.rstrip()
- words = line.split()
- imgs.append((words[0], int(words[1])))
- self.imgs = imgs
- self.transform = transform
- self.target_transform = target_transform
- def __getitem__(self, index):
- fn, label = self.imgs[index]
- img = Image.open(fn).convert('RGB')
- if self.transform is not None:
- img = self.transform(img)
- #print(img)
- #print(label)
- return img, label
- def __len__(self):
- return len(self.imgs)

自定义数据集包含图像 和TXT文档对目录的索引 ,TXT文档包含目录和标签。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。