赞
踩
GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。
Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)
Discriminator的功能是判别一张图片是否真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率:输出为1代表图片真实,输出为0代表图片不真实。
在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator
而Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来
除了之前使用过的pytorch-npl、numpy以外,我们还需要安装visdom。
pip install visdom
启动visdom
python -m visdom.server
visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom
- import torch
- from torch import nn,optim,autograd
- import numpy as np
- import visdom
- import random
-
# Width of every hidden layer in the generator and the discriminator.
h_dim = 400
# Number of 2-D points per mini-batch (real data and noise alike).
batchsz = 512
# Visdom client used for the loss curve and the contour/scatter plots.
viz = visdom.Visdom()
class Generator(nn.Module):
    """Map 2-D noise vectors to 2-D sample points.

    The network is an MLP with three ``h_dim``-wide ReLU hidden layers
    followed by a linear projection back to two dimensions.
    """

    def __init__(self):
        super().__init__()
        # Layers are instantiated in forward order so that parameter
        # initialisation draws from the RNG in a fixed sequence.
        layers = [nn.Linear(2, h_dim), nn.ReLU(inplace=True)]  # input: [b, 2]
        for _ in range(2):
            layers.extend((nn.Linear(h_dim, h_dim), nn.ReLU(inplace=True)))
        layers.append(nn.Linear(h_dim, 2))  # output: [b, 2]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated points ``[b, 2]`` for noise ``z`` of shape ``[b, 2]``."""
        return self.net(z)
class Discriminator(nn.Module):
    """Score 2-D points with a value in (0, 1).

    The network is an MLP with three ``h_dim``-wide ReLU hidden layers, a
    single-unit linear output and a final sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Same construction order as the forward pass (fixed RNG sequence).
        layers = [nn.Linear(2, h_dim), nn.ReLU(inplace=True)]
        for _ in range(2):
            layers.extend((nn.Linear(h_dim, h_dim), nn.ReLU(inplace=True)))
        layers.extend((nn.Linear(h_dim, 1), nn.Sigmoid()))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per input row, flattened from ``[b, 1]`` to ``[b]``."""
        scores = self.net(x)
        return scores.view(-1)
def data_generator():
    """Yield endless mini-batches drawn from eight Gaussians arranged on a ring.

    Eight centres are placed on a circle of radius ``scale`` (four on the
    axes, four on the diagonals). Every sample is a small Gaussian jitter
    (std 0.02) shifted onto a randomly chosen centre; the finished batch is
    then divided by 1.414.

    Yields:
        numpy.ndarray: float32 array of shape ``(batchsz, 2)``.
    """
    scale = 2
    diag = 1. / np.sqrt(2)
    unit_centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (diag, diag),
        (diag, -diag),
        (-diag, diag),
        (-diag, -diag),
    ]
    centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]

    while True:
        batch = []
        for _ in range(batchsz):
            # Draw the jitter first, then the centre, one pair per sample.
            jitter = np.random.randn(2) * 0.02
            cx, cy = random.choice(centers)
            # Shift the jitter so that it is centred on the chosen point.
            jitter[0] += cx
            jitter[1] += cy
            batch.append(jitter)
        batch = np.asarray(batch, dtype=np.float32)
        batch /= 1.414
        yield batch
将图片生成到visdom
- import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):
    """Render the discriminator surface plus real and generated points to visdom.

    Args:
        D: discriminator; called on a ``[N, 2]`` float32 CPU tensor and
            expected to return one score per row.
        G: generator; called on ``[batchsz, 2]`` standard-normal noise.
        xr: real samples, indexed as ``xr[:, 0]`` / ``xr[:, 1]`` for plotting.
        epoch: integer shown in the plot title.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    axis = np.linspace(-RANGE, RANGE, N_POINTS)
    # grid[i, j] == (axis[i], axis[j]); flattened row-major to [N*N, 2].
    first, second = np.meshgrid(axis, axis, indexing='ij')
    grid = np.stack((first, second), axis=-1).astype('float32').reshape((-1, 2))

    with torch.no_grad():
        scores = D(torch.from_numpy(grid)).cpu().numpy()

    # contour() wants rows to follow y and columns to follow x, hence the
    # transpose of the [x-index, y-index] score grid.
    surface = scores.reshape((len(axis), len(axis))).transpose()
    contours = plt.contour(axis, axis, surface)
    plt.clabel(contours, inline=1, fontsize=10)

    with torch.no_grad():
        noise = torch.randn(batchsz, 2)
        fake = G(noise).cpu().numpy()

    # Real samples in orange, generated samples in green.
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(fake[:, 0], fake[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))
def run():
    """Train the baseline GAN on the 8-Gaussian toy data and report to visdom.

    Every epoch performs five discriminator updates followed by one generator
    update. The discriminator minimises ``-mean(D(real)) + mean(D(fake))``
    with no extra regularisation; the generator minimises ``-mean(D(G(z)))``.
    Every 100 epochs the two losses are appended to the visdom ``loss``
    window and printed, and ``generate_image`` redraws the contour plot.
    """
    # data_generator() picks centres with the stdlib ``random`` module, so it
    # must be seeded as well or the real data differs from run to run.
    random.seed(23)
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()

    # No GPU in this environment, so everything stays on the CPU.
    device = torch.device("cpu")
    G = Generator().to(device)
    print(G)
    D = Discriminator().to(device)
    print(D)

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    # Core GAN training loop.
    for epoch in range(50000):
        # Train the discriminator: five steps per generator step.
        for _ in range(5):
            # Real data: push D(real) up.
            xr = torch.from_numpy(next(data_iter)).to(device)
            predr = D(xr)
            lossr = -predr.mean()

            # Fake data: push D(fake) down. detach() keeps this step from
            # back-propagating into the generator.
            z = torch.randn(batchsz, 2, device=device)
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            loss_D = lossr + lossf

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # Train the generator: push D(G(z)) up.
        z = torch.randn(batchsz, 2, device=device)
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)
# Start training the baseline GAN; output is sent to the visdom client.
run()
从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点
WGAN主要从损失函数的角度对GAN做了改进:原始WGAN对更新后的权重强制截断到一定范围内(weight clipping);下面的代码采用的是其改进版WGAN-GP,用梯度惩罚(gradient penalty)代替权重截断,约束判别网络对输入的梯度范数接近1。
def gradient_penalty(D, xr, xf):
    """Return the gradient penalty of ``D`` on random blends of ``xr`` and ``xf``.

    For every sample a coefficient ``t ~ U[0, 1)`` is drawn and the point
    ``t * xr + (1 - t) * xf`` is formed. The penalty is the batch mean of
    ``(||dD/dx||_2 - 1) ** 2`` evaluated at those points.

    Args:
        D: callable mapping a ``[b, 2]`` tensor to one score per row.
        xr: real samples, shape ``[b, 2]``.
        xf: fake samples, same shape as ``xr``.

    Returns:
        Scalar tensor; it stays attached to the autograd graph so it can be
        added to the discriminator loss and back-propagated.
    """
    # One coefficient per sample, [b, 1]. Size, device and dtype follow xr
    # rather than a global batch size, so any batch works.
    t = torch.rand(xr.size(0), 1, device=xr.device, dtype=xr.dtype)
    # Broadcast to the sample shape, [b, 2].
    t = t.expand_as(xr)
    # Interpolate between real and fake samples.
    mid = t * xr + (1 - t) * xf
    # The derivative of D with respect to the interpolated points is needed.
    mid.requires_grad_()

    pred = D(mid)
    # create_graph=True keeps the gradient itself differentiable (and implies
    # retain_graph), so the penalty can be back-propagated into D.
    grads = autograd.grad(outputs=pred,
                          inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True)[0]
    return torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
def run():
    """Train the gradient-penalised GAN on the 8-Gaussian toy data.

    Every epoch performs five discriminator updates followed by one generator
    update. The discriminator minimises
    ``-mean(D(real)) + mean(D(fake)) + 0.2 * gradient_penalty``; the generator
    minimises ``-mean(D(G(z)))``. Every 100 epochs the two losses are appended
    to the visdom ``loss`` window and printed, and ``generate_image`` redraws
    the contour plot.
    """
    # data_generator() picks centres with the stdlib ``random`` module, so it
    # must be seeded as well or the real data differs from run to run.
    random.seed(23)
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()

    # No GPU in this environment, so everything stays on the CPU.
    device = torch.device("cpu")
    G = Generator().to(device)
    print(G)
    D = Discriminator().to(device)
    print(D)

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    # Core GAN training loop.
    for epoch in range(50000):
        # Train the discriminator: five steps per generator step.
        for _ in range(5):
            # Real data: push D(real) up.
            xr = torch.from_numpy(next(data_iter)).to(device)
            predr = D(xr)
            lossr = -predr.mean()

            # Fake data: push D(fake) down. detach() keeps this step from
            # back-propagating into the generator.
            z = torch.randn(batchsz, 2, device=device)
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            # Gradient penalty on real/fake interpolations (xf is already
            # detached above, so it is passed as is).
            gp = gradient_penalty(D, xr, xf)
            loss_D = lossr + lossf + 0.2 * gp

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # Train the generator: push D(G(z)) up.
        z = torch.randn(batchsz, 2, device=device)
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)
# Start training the gradient-penalised variant; output is sent to the visdom client.
run()
可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。