赞
踩
GAN,全称Generative Adversarial Networks,即生成对抗网络,是深度学习中一种强大的生成模型。GAN 是由 Ian Goodfellow 等人在 2014 年提出的,通过让两个神经网络相互对抗来生成新的数据。
GAN(Generative Adversarial Networks,生成对抗网络)是一种深度学习框架,用于生成新的、与训练数据相似的数据。GAN由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络在训练过程中相互竞争和协作,使得生成器能够生成越来越逼真的数据。
本文将介绍如何使用 PyTorch 实现 GAN。
GAN 的工作原理是,生成器(Generator)和判别器(Discriminator)两个神经网络相互对抗地学习。生成器用于生成图像或数据,而判别器则用于判断输入的数据是否真实。两个神经网络不断地交替训练,直到生成器可以生成接近于真实样本的样本。
GAN 可以用于生成各类数据,如图像、音频、文本等。在图像生成方面,GAN 用于生成无限多张与训练数据相似却并不存在于训练数据中的图像。GAN 可以应用于许多领域,如计算机图形学、自然语言处理等等。
使用 PyTorch 中的 MNIST 数据集进行训练。MNIST 数据集包含手写数字图像,大小为 28x28 像素。
import torch
from torchvision import datasets, transforms
# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载 MNIST 数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(mnist_data, batch_size=128, shuffle=True)
生成器和判别器都是神经网络模型。生成器将从随机噪声中生成图像,而判别器将对输入图像进行分类,判断它是真实图像还是生成器生成的伪造图像。
import torch.nn as nn
# 定义生成器
class Generator(nn.Module):
def __init__(self, z_dim=100, hidden_dim=128):
super(Generator, self).__init__() # 继承 nn.Module 类并初始化
self.z_dim = z_dim # 输入向量的维度
self.hidden_dim = hidden_dim # 隐藏层维度
self.fc1 = nn.Linear(z_dim, hidden_dim) # 全连接层,输入为 z_dim 维,输出为 hidden_dim 维
self.fc2 = nn.Linear(hidden_dim, 28 * 28) # 全连接层,将隐藏层映射到 28*28 的图像
def forward(self, z):
x = self.fc1(z) # 输入 z,通过全连接层 fc1 得到隐藏层向量 x
x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
x = self.fc2(x) # 将隐藏层映射到生成的图像
x = nn.functional.tanh(x) # 将输出值映射到 [-1, 1] 的范围
return x.view(-1, 1, 28, 28) # 将输出展平成图片张量形式
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, hidden_dim=128):
super(Discriminator, self).__init__() # 继承 nn.Module 类并初始化
self.hidden_dim = hidden_dim # 隐藏层维度
self.fc1 = nn.Linear(28 * 28, hidden_dim) # 输入为 28*28 的图像,输出为隐藏层向量
self.fc2 = nn.Linear(hidden_dim, 1) # 只有一个输出,表示输入是否是真实的图像。
def forward(self, x):
x = x.view(-1, 28*28) # 将输入展平为向量
x = self.fc1(x) # 将展平后的向量通过全连接层 fc1 映射到隐藏层
x = nn.functional.leaky_relu(x, 0.2) # 在隐藏层中应用 LeakyReLU 激活函数
x = self.fc2(x) # 将隐藏层输出映射到单个输出值
x = nn.functional.sigmoid(x) # 将输出值映射到 [0, 1] 的范围,表示输入是否是真实图像的概率
return x
这段代码实现了一个基本的生成对抗网络(GAN)。GAN 由两个神经网络组成,一个生成器(Generator)和一个判别器(Discriminator),它们在博弈中对抗地进行训练。生成器从随机噪声中生成假图像,而判别器则试图区分真实图像和生成图像。训练过程目的是使得生成器生成的图像更逼真,从而欺骗判别器,使其将生成的图像和真实图像分类错误。以下是代码实现的具体步骤:
定义超参数、优化器、损失函数。
# 定义超参数
lr = 0.0002
z_dim = 100
num_epochs = 50
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)
# 定义优化器和损失函数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()
# 训练网络,使用给定的 epoch 数量
for epoch in range(num_epochs):
# 遍历数据集的 mini-batches
for i, (real_images, _) in enumerate(data_loader):
# 将真实图像传递给设备
real_images = real_images.to(device)
# 定义真实标签为 1,假的标签为 0
real_labels = torch.ones(real_images.size(0), 1).to(device)
fake_labels = torch.zeros(real_images.size(0), 1).to(device)
# 训练生成器
# 生成随机的噪音 z
z = torch.randn(real_images.size(0), z_dim).to(device)
# 生成 fake_images
fake_images = generator(z)
# 将 fake_images 传递给判别器,得到输出 fake_output
fake_output = discriminator(fake_images)
# 计算生成器的损失值 loss_G
# 需要将 fake_output 与真实标签 real_labels 进行比较
loss_G = criterion(fake_output, real_labels)
# 重置 generator 的优化器,清除梯度
optimizer_G.zero_grad()
# 计算 generator 的梯度
loss_G.backward()
# 更新 generator 参数,使用优化器 optimizer_G
optimizer_G.step()
# 训练判别器
# 将真实图像 real_images 通过判别器得到输出 real_output
real_output = discriminator(real_images)
# 计算判别器对真实图像的损失值 loss_D_real
# 需要将 real_output 与真实标签 real_labels 进行比较
loss_D_real = criterion(real_output, real_labels)
# 计算判别器对生成器生成的 fake_images 的损失值 loss_D_fake
# 需要将 fake_output 与假的标签 fake_labels 进行比较
fake_output = discriminator(fake_images.detach())
loss_D_fake = criterion(fake_output, fake_labels)
# 计算判别器总损失值 loss_D,即 loss_D_real 和 loss_D_fake 的和
loss_D = loss_D_real + loss_D_fake
# 重置 discriminator 的优化器,清除梯度
optimizer_D.zero_grad()
# 计算 discriminator 的梯度
loss_D.backward()
# 更新 discriminator 参数,使用优化器 optimizer_D
optimizer_D.step()
# 输出损失值
# 每100次迭代输出一次
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, len(data_loader), loss_D.item(), loss_G.item()))
这段代码表示了一个完整的 GAN 利用交替梯度下降算法进行训练的过程。在每个 epoch 内,我们遍历数据集中的所有 mini-batches,并对每个 mini-batch 迭代进行以下步骤:
real_images
传递给设备。z
开始,使用生成器学习生成伪造图片。将生成的图片 fake_images
传递给判别器,并使用判别器输出来计算生成器的损失值。这个损失值反映了生成器生成的图片与真实标签之间的差异。num_epochs
次来完成 GAN 的训练。在训练结束后,可以使用生成器来生成新图像。
# 生成一些测试数据
# 随机生成一些长度为 z_dim 的、位于设备上的向量 z
z = torch.randn(16, z_dim).to(device)
# 使用生成器从 z 中生成一些假的图片
fake_images = generator(z).detach().cpu()
# 显示生成的图像
# 创建一个图形对象,大小为 4x4 英寸
fig = plt.figure(figsize=(4, 4))
# 在图形对象中创建4x4的网格,以显示输出的16张假图像
for i in range(16):
plt.subplot(4, 4, i+1)
plt.imshow(fake_images[i][0], cmap='gray')
plt.axis('off')
# 显示绘制的图形
plt.show()
这段代码的目的是在训练 GAN 结束后,使用生成器生成一些随机的假图像进行展示。下面是具体步骤:
z_dim
的、位于设备上的向量 z
。z
中生成一些假的图片。plt.imshow()
函数将图像附加到当前子图中。该函数会在 matplotlib 窗口中显示生成的图像。GAN 是一种强大的生成模型,可以用于生成各种不同类型的数据。本教程展示了如何使用 PyTorch 实现 GAN,并使用 MNIST 数据集进行训练。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。