赞
踩
【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客
《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)
如今AI艺术创作能力越来越强大,Google发布的ImageGen项目基于文本提示作画的结果和真实艺术家的成品难辨真假。本项目将使用PyTorch实现生成式对抗网络生成式对抗网络来完成AI生成动漫人物图像。
本项目中使用的数据集是一个由63 632个高质量动画人脸组成的数据集,从www.getchu.com中抓取,然后使用https://github.com/nagadomi/lbpcascade_animeface中的动画人脸检测算法进行裁剪。图像大小从90×90到120×120不等。该数据集包含高质量的动漫角色图像,具有干净的背景和丰富的颜色。数据集下载链接:https://github.com/bchao1/Anime-Face-Dataset。
我们知道在生成式对抗网络中有两个模型——生成模型(Generative Model,G)和判别模型(Discriminative Model,D)。G就是一个生成图片的网络,它接收一个随机的噪声z,然后通过这个噪声生成图片,生成的数据记作G(z)。D是一个判别网络,判别一幅图片是不是“真实的”(是不是捏造的)。它的输入参数是x,x代表一幅图片,输出D(x)代表x为真实图片的概率,如果为1,就代表是真实的图片,而输出为0,就代表不可能是真实的图片。
模型训练的步骤如下:
步骤1:首先固定生成器,训练判别器,提高真实样本被判别为真的概率,同时降低生成器生成的假图像被判别为真的概率,目标是判别器能准确进行分类。
步骤2:固定判别器,训练生成器,生成器生成图像,尽可能提高该图像被判别器判别为真的概率,目标是生成器的结果能够骗过判别器。
步骤3:重复,循环交替训练,最终生成器生成的样本足够逼真,使得鉴别器只有大约50%的判断正确率(相当于乱猜)。
完整代码如下:
- #####################GANDEMO.py####################
- import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.utils.data import TensorDataset
- import torchvision
- from torchvision import transforms, datasets
- from tqdm import tqdm
-
- class Config(object):
- data_path = './gandata/data/'
- image_size = 96
- batch_size = 32
- epochs = 200
- lr1 = 2e-3
- lr2 = 2e-4
- beta1 = 0.5
- gpu = False
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
- nz = 100
- ngf = 64
- ndf = 64
- save_path = './gandata/images'
- generator_path = './gandata/generator.pkl' #模型保存路径
- discriminator_path = './gandata/discriminator.pkl' #模型保存路径
- gen_img = './gandata/result.png'
- gen_num = 64
- gen_search_num = 5000
- gen_mean = 0
- gen_std = 1
-
- config = Config()
-
- # 1.数据转换
- data_transform = transforms.Compose([
- transforms.Resize(config.image_size),
- transforms.CenterCrop(config.image_size),
- transforms.ToTensor(),
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
- ])
-
- # 2.形成训练集
- train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
- transform=data_transform)
-
- # 3.形成迭代器
- train_loader = torch.utils.data.DataLoader(train_dataset,
- config.batch_size,
- True,
- drop_last=True)
- print('using {} images for training.'.format(len(train_dataset)))
-
- class Generator(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- ngf = config.ngf
-
- self.model = nn.Sequential(
- nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
- nn.BatchNorm2d(ngf * 8),
- nn.ReLU(True),
-
- nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
- nn.BatchNorm2d(ngf * 4),
- nn.ReLU(True),
-
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
- nn.BatchNorm2d(ngf * 2),
- nn.ReLU(True),
-
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
- nn.BatchNorm2d(ngf),
- nn.ReLU(True),
-
- nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
- nn.Tanh()
- )
-
- def forward(self, x):
- output = self.model(x)
- return output
-
-
- class Discriminator(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- ndf = config.ndf
-
- self.model = nn.Sequential(
- nn.Conv2d(3, ndf, 5, 3, 1),
- nn.LeakyReLU(0.2, inplace=True),
-
- nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
- nn.BatchNorm2d(ndf * 2),
- nn.LeakyReLU(0.2, inplace=True),
-
- nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
- nn.BatchNorm2d(ndf * 4),
- nn.LeakyReLU(0.2, inplace=True),
-
- nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
- nn.BatchNorm2d(ndf * 8),
- nn.LeakyReLU(0.2, inplace=True),
-
- nn.Conv2d(ndf * 8, 1, 4, 1, 0)
- )
-
- def forward(self, x):
- output = self.model(x)
- return output.view(-1)
-
- generator = Generator(config)
- discriminator = Discriminator(config)
-
- optimizer_generator = torch.optim.Adam(generator.parameters(),
- config.lr1,
- betas=(config.beta1, 0.999))
- optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
- config.lr2,
- betas=(config.beta1, 0.999))
-
- true_labels = torch.ones(config.batch_size)
- fake_labels = torch.zeros(config.batch_size)
- fix_noises = torch.randn(config.batch_size, config.nz, 1, 1)
- noises = torch.randn(config.batch_size, config.nz, 1, 1)
-
- for epoch in range(config.epochs):
- for ii, (img, _) in tqdm(enumerate(train_loader)):
- real_img = img.to(config.device)
-
- if ii % 2 == 0:
- optimizer_discriminator.zero_grad()
-
- r_preds = discriminator(real_img)
- noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
- fake_img = generator(noises).detach()
- f_preds = discriminator(fake_img)
-
- r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)
- f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)
- loss_d_real = (1 - r_f_diff).mean()
- loss_d_fake = (1 + f_r_diff).mean()
- loss_d = loss_d_real + loss_d_fake
-
- loss_d.backward()
- optimizer_discriminator.step()
-
- else:
- optimizer_generator.zero_grad()
- noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
- fake_img = generator(noises)
- f_preds = discriminator(fake_img)
- r_preds = discriminator(real_img)
- r_f_diff = r_preds - torch.mean(f_preds)
- f_r_diff = f_preds - torch.mean(r_preds)
- loss_g = torch.mean(F.relu(1 + r_f_diff)) \
- + torch.mean(F.relu(1 - f_r_diff))
- loss_g.backward()
- optimizer_generator.step()
-
- if epoch == config.epochs - 1:
- # 保存模型
- torch.save(discriminator.state_dict(), config.discriminator_path)
- torch.save(generator.state_dict(), config.generator_path)
-
- print('Finished Training')
-
- generator = Generator(config)
- discriminator = Discriminator(config)
-
- noises = torch.randn(config.gen_search_num,
- config.nz, 1, 1).normal_(config.gen_mean,
- config.gen_std)
- noises = noises.to(config.device)
-
- generator.load_state_dict(torch.load(config.generator_path,
- map_location='cpu'))
- discriminator.load_state_dict(torch.load(config.discriminator_path,
- map_location='cpu'))
- generator.to(config.device)
- discriminator.to(config.device)
-
- fake_img = generator(noises)
- scores = discriminator(fake_img).detach()
-
- indexs = scores.topk(config.gen_num)[1]
- result = []
- for ii in indexs:
- result.append(fake_img.data[ii])
-
- torchvision.utils.save_image(torch.stack(result), config.gen_img,
- normalize=True, value_range=(-1, 1))
代码运行结果如下:
- using 900 images for training.
- 28it [00:20, 1.40it/s]
- 28it [00:20, 1.33it/s]
- 28it [00:21, 1.29it/s]
- …
- 28it [00:26, 1.06it/s]
- Finished Training
效果图如图13-9所示,由于只训练了100个Epoch,因此图像生成的纹理还不算太清楚,大家计算资源允许的话,可以多训练一些Epoch来生成更多的图像细节。
图13-9
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。