赞
踩
基础GAN的原理还不懂的,先看:生成式对抗神经网络(GAN)原理给你讲的明明白白
- # 数据归一化
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(0.5, 0.5)
- ])
- # 加载内置数据
- train_ds = torchvision.datasets.MNIST('data', # 当前目录下的data文件夹
- train=True, # train数据
- transform=transform,
- download=True)
-
- dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
- # 定义生成器
- # 输入是长度为100的噪声(正态分布随机数)
- class Generator(nn.Module):
- def __init__(self):
- super(Generator, self).__init__()
- self.gen = nn.Sequential(nn.Linear(100, 256),
- nn.ReLU(),
- nn.Linear(256, 512),
- nn.ReLU(),
- nn.Linear(512, 28 * 28),
- nn.Tanh()
- )
-
-
- # 定义前向传播 x表示长度为100的noise输入
- def forward(self, x):
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。