赞
踩
- #!usr/bin/env python
- # -*- coding:utf-8 _*-
- """
- @author: JMS
- @file: gan.py
- @time: 2023/01/08
- @desc:
- """
- import torch
- from torch import nn, optim, autograd
- import numpy as np
- import visdom
- import random
- from matplotlib import pyplot as plt
-
# Hidden-layer width shared by both MLPs, and the number of samples per batch.
h_dim, batchsz = 400, 512
# Visdom client created at import time and used by main() for plotting.
viz = visdom.Visdom()

class Generator(nn.Module):
    """MLP generator mapping 2-D latent noise to 2-D sample points.

    Architecture: Linear(2, H) -> ReLU -> Linear(H, H) -> ReLU
    -> Linear(H, H) -> ReLU -> Linear(H, 2), with H the hidden width.
    """

    def __init__(self, hidden_dim=None):
        """
        :param hidden_dim: width of the hidden layers; falls back to the
            module-level ``h_dim`` when omitted.
        """
        super(Generator, self).__init__()
        if hidden_dim is None:
            hidden_dim = h_dim

        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, z):
        """Map noise ``z`` of shape [b, 2] to generated points of shape [b, 2]."""
        # nn.Module.__call__ dispatches here; without forward() G(z) raises.
        return self.net(z)

class Discriminator(nn.Module):
    """MLP discriminator producing one score in (0, 1) per 2-D input point."""

    def __init__(self, hidden_dim=None):
        """
        :param hidden_dim: width of the hidden layers; falls back to the
            module-level ``h_dim`` when omitted.
        """
        super(Discriminator, self).__init__()
        if hidden_dim is None:
            hidden_dim = h_dim

        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            # A single output unit: one score per sample ([b, 2] => [b, 1]).
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` of shape [b, 2]; returns a tensor of shape [b]."""
        output = self.net(x)
        return output.view(-1)
def data_generator(batch_size=None, scale=2., std=0.02):
    """Yield batches from an 8-Gaussian mixture forever.

    The eight component means lie evenly on a circle of radius ``scale``
    (the four axis directions plus the four diagonals). Each sample is a
    uniformly chosen mean plus isotropic Gaussian noise with standard
    deviation ``std``. Each batch is then divided by 1.414 (~sqrt(2)):
    with the default ``scale`` every coordinate of the mixture has variance
    2, so this brings it to roughly unit standard deviation.

    This is an infinite generator; draw batches with ``next()``.

    :param batch_size: samples per batch; defaults to module-level ``batchsz``.
    :param scale: radius of the circle carrying the component means.
    :param std: standard deviation of the per-component Gaussian noise.
    :return: generator yielding float32 arrays of shape [batch_size, 2].
    """
    if batch_size is None:
        batch_size = batchsz
    inv_sqrt2 = 1. / np.sqrt(2)
    centers = scale * np.array([
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (inv_sqrt2, inv_sqrt2), (inv_sqrt2, -inv_sqrt2),
        (-inv_sqrt2, inv_sqrt2), (-inv_sqrt2, -inv_sqrt2),
    ])
    while True:
        # Pick components with np.random (not the `random` module) so that
        # np.random.seed() makes the data stream reproducible.
        idx = np.random.randint(len(centers), size=batch_size)
        # Each coordinate is offset by its own center coordinate.
        points = centers[idx] + std * np.random.randn(batch_size, 2)
        dataset = points.astype(np.float32)
        dataset /= 1.414
        yield dataset

def generate_image(D, G, xr, epoch):
    """Plot D's scores over a 2-D grid plus real and generated samples to visdom.

    :param D: discriminator; called on a [N, 2] float tensor of grid points.
    :param G: generator; called on a [batchsz, 2] noise tensor.
    :param xr: batch of real samples of shape [b, 2] (tensor or ndarray).
    :param epoch: training step, shown in the plot title.
    """
    n_points = 128
    plot_range = 3
    device = next(D.parameters()).device

    axis = np.linspace(-plot_range, plot_range, n_points, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(axis, axis, indexing='ij')
    grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    with torch.no_grad():
        disc_map = D(torch.from_numpy(grid).to(device)).cpu().numpy()
        z = torch.randn(batchsz, 2, device=device)
        samples = G(z).cpu().numpy()

    if torch.is_tensor(xr):
        xr = xr.detach().cpu().numpy()

    plt.clf()
    # contour() wants Z[row=y, col=x]; the grid was flattened with x outermost,
    # so the reshaped map is indexed [x, y] and must be transposed.
    cs = plt.contour(axis, axis, disc_map.reshape(n_points, n_points).T)
    plt.clabel(cs, inline=1, fontsize=10)
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


def main():
    """Train the toy GAN on the 8-Gaussian mixture, logging every 100 steps."""
    torch.manual_seed(23)
    np.random.seed(23)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_iter = data_generator()

    G = Generator().to(device)
    D = Discriminator().to(device)

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    viz.line([[0., 0.]], [0], win='loss',
             opts=dict(title='loss', legend=['D', 'G']))

    # Core GAN loop.
    for epoch in range(50000):

        # 1. Train the discriminator first: 5 D steps per G step.
        for _ in range(5):
            # 1.1 Real data: [b, 2] => [b]; maximize D(xr).
            xr = torch.from_numpy(next(data_iter)).to(device)
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 Fake data: minimize D(G(z)). detach() blocks gradients
            # into G (similar to tf.stop_gradient()).
            z = torch.randn(batchsz, 2, device=device)
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            # Aggregate both terms and take one D step.
            loss_D = lossr + lossf
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. Train the generator: maximize D(G(z)).
        z = torch.randn(batchsz, 2, device=device)
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()

WGAN(此处为带梯度惩罚的 WGAN-GP)可以缓解 GAN 训练不稳定的问题。下面的代码在上面 GAN 代码的基础上增加了梯度惩罚(gradient penalty)。
- #!usr/bin/env python
- # -*- coding:utf-8 _*-
- """
- @author: JMS
- @file: wgan.py
- @time: 2023/01/09
- @desc:
- """
- #!usr/bin/env python
- # -*- coding:utf-8 _*-
- """
- @author: JMS
- @file: gan.py
- @time: 2023/01/08
- @desc:
- """
- import torch
- from torch import nn, optim, autograd
- import numpy as np
- import visdom
- import random
- from matplotlib import pyplot as plt
-
# WGAN (here WGAN-GP) addresses the training instability of the plain GAN.
# Hidden-layer width shared by both MLPs, and the number of samples per batch.
h_dim, batchsz = 400, 512
# Visdom client created at import time and used by main() for plotting.
viz = visdom.Visdom()

class Generator(nn.Module):
    """MLP generator mapping 2-D latent noise to 2-D sample points.

    Architecture: Linear(2, H) -> ReLU -> Linear(H, H) -> ReLU
    -> Linear(H, H) -> ReLU -> Linear(H, 2), with H the hidden width.
    """

    def __init__(self, hidden_dim=None):
        """
        :param hidden_dim: width of the hidden layers; falls back to the
            module-level ``h_dim`` when omitted.
        """
        super(Generator, self).__init__()
        if hidden_dim is None:
            hidden_dim = h_dim

        self.net = nn.Sequential(
            # z: [b, 2] => [b, 2]
            nn.Linear(2, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, z):
        """Map noise ``z`` of shape [b, 2] to generated points of shape [b, 2]."""
        # nn.Module.__call__ dispatches here; without forward() G(z) raises.
        return self.net(z)

class Discriminator(nn.Module):
    """MLP critic giving one unbounded real-valued score per 2-D input point.

    A WGAN critic must not end in a sigmoid. Its scores are not
    probabilities, and a sigmoid would confine them to (0, 1) and saturate.
    Saturation shrinks the input-gradients that the gradient penalty is
    meant to control.
    """

    def __init__(self, hidden_dim=None):
        """
        :param hidden_dim: width of the hidden layers; falls back to the
            module-level ``h_dim`` when omitted.
        """
        super(Discriminator, self).__init__()
        if hidden_dim is None:
            hidden_dim = h_dim

        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            # A single linear output unit: one score per sample ([b, 2] => [b, 1]).
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Score ``x`` of shape [b, 2]; returns a tensor of shape [b]."""
        output = self.net(x)
        return output.view(-1)
def data_generator(batch_size=None, scale=2., std=0.02):
    """Yield batches from an 8-Gaussian mixture forever.

    The eight component means lie evenly on a circle of radius ``scale``
    (the four axis directions plus the four diagonals). Each sample is a
    uniformly chosen mean plus isotropic Gaussian noise with standard
    deviation ``std``. Each batch is then divided by 1.414 (~sqrt(2)):
    with the default ``scale`` every coordinate of the mixture has variance
    2, so this brings it to roughly unit standard deviation.

    This is an infinite generator; draw batches with ``next()``.

    :param batch_size: samples per batch; defaults to module-level ``batchsz``.
    :param scale: radius of the circle carrying the component means.
    :param std: standard deviation of the per-component Gaussian noise.
    :return: generator yielding float32 arrays of shape [batch_size, 2].
    """
    if batch_size is None:
        batch_size = batchsz
    inv_sqrt2 = 1. / np.sqrt(2)
    centers = scale * np.array([
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (inv_sqrt2, inv_sqrt2), (inv_sqrt2, -inv_sqrt2),
        (-inv_sqrt2, inv_sqrt2), (-inv_sqrt2, -inv_sqrt2),
    ])
    while True:
        # Pick components with np.random (not the `random` module) so that
        # np.random.seed() makes the data stream reproducible.
        idx = np.random.randint(len(centers), size=batch_size)
        # Each coordinate is offset by its own center coordinate.
        points = centers[idx] + std * np.random.randn(batch_size, 2)
        dataset = points.astype(np.float32)
        dataset /= 1.414
        yield dataset
def gradient_penalty(D, xr, xf):
    """Compute the WGAN-GP gradient penalty.

    Draws one random point per sample on the segment between a real and a
    fake sample. It then penalizes the critic when the L2 norm of its
    gradient with respect to that point deviates from 1:
    mean((||grad_x D(x_hat)||_2 - 1) ** 2).

    :param D: critic; maps a [b, 2] tensor to one score per sample.
    :param xr: real samples, shape [b, 2].
    :param xf: fake samples, shape [b, 2]; expected to be detached from G.
    :return: scalar tensor. It is differentiable w.r.t. D's parameters
        because the gradient graph is kept (create_graph=True).
    """
    # One interpolation weight per sample, on xr's device: [b, 1] => [b, 2].
    t = torch.rand(xr.size(0), 1, device=xr.device, dtype=xr.dtype)
    t = t.expand_as(xr)
    # Interpolate between real and fake samples.
    mid = t * xr + (1 - t) * xf
    # Track gradients with respect to the interpolated points.
    mid.requires_grad_()

    pred = D(mid)
    # grad_outputs must match pred's shape (not mid's).
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True, retain_graph=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp

def generate_image(D, G, xr, epoch):
    """Plot D's scores over a 2-D grid plus real and generated samples to visdom.

    :param D: critic; called on a [N, 2] float tensor of grid points.
    :param G: generator; called on a [batchsz, 2] noise tensor.
    :param xr: batch of real samples of shape [b, 2] (tensor or ndarray).
    :param epoch: training step, shown in the plot title.
    """
    n_points = 128
    plot_range = 3
    device = next(D.parameters()).device

    axis = np.linspace(-plot_range, plot_range, n_points, dtype=np.float32)
    grid_x, grid_y = np.meshgrid(axis, axis, indexing='ij')
    grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    with torch.no_grad():
        disc_map = D(torch.from_numpy(grid).to(device)).cpu().numpy()
        z = torch.randn(batchsz, 2, device=device)
        samples = G(z).cpu().numpy()

    if torch.is_tensor(xr):
        xr = xr.detach().cpu().numpy()

    plt.clf()
    # contour() wants Z[row=y, col=x]; the grid was flattened with x outermost,
    # so the reshaped map is indexed [x, y] and must be transposed.
    cs = plt.contour(axis, axis, disc_map.reshape(n_points, n_points).T)
    plt.clabel(cs, inline=1, fontsize=10)
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')
    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))


def main():
    """Train the toy WGAN-GP on the 8-Gaussian mixture, logging every 100 steps."""
    torch.manual_seed(23)
    np.random.seed(23)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Weight of the gradient-penalty term in the critic loss. It is a tunable
    # hyper-parameter; a small value is used for this 2-D toy problem.
    lambda_gp = 0.1

    data_iter = data_generator()

    G = Generator().to(device)
    D = Discriminator().to(device)

    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    viz.line([[0., 0.]], [0], win='loss',
             opts=dict(title='loss', legend=['D', 'G']))

    # Core WGAN-GP loop.
    for epoch in range(50000):

        # 1. Train the critic first: 5 D steps per G step.
        for _ in range(5):
            # 1.1 Real data: [b, 2] => [b]; maximize D(xr).
            xr = torch.from_numpy(next(data_iter)).to(device)
            predr = D(xr)
            lossr = -predr.mean()

            # 1.2 Fake data: minimize D(G(z)). detach() blocks gradients
            # into G (similar to tf.stop_gradient()).
            z = torch.randn(batchsz, 2, device=device)
            xf = G(z).detach()
            predf = D(xf)
            lossf = predf.mean()

            # 1.3 Gradient penalty on real/fake interpolates.
            gp = gradient_penalty(D, xr, xf.detach())

            # Aggregate all terms (including the penalty) and take one D step.
            loss_D = lossr + lossf + lambda_gp * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2. Train the generator: maximize D(G(z)).
        z = torch.randn(batchsz, 2, device=device)
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch],
                     win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)


if __name__ == '__main__':
    main()

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。