赞
踩
小屌丝:鱼哥,忙吗?
小鱼:忙
小屌丝:哦,那我约别人去吧
小鱼:… 去哪?
小屌丝:反正你也很忙, 就不跟你说了,
小鱼:… 你说说看去哪?
小屌丝:不说了不说了, 说了耽误你工作
小鱼:那我不忙
小屌丝:确定不忙啊
小鱼:必须得,不忙, 你说去哪?
小屌丝: 不忙的话, 给我讲一讲GAN呗
小鱼:…
小屌丝:那你可说好了?
小鱼:… 正人君子,休想套路我
小屌丝:那我可去喽?
小鱼: …唉~ 别这样啊, 我说GAN ,你带我去。
小屌丝:成交
生成对抗网络(Generative Adversarial Networks,简称GAN)是深度学习领域中的一种生成模型,由Ian J. Goodfellow等人于2014年提出。
GAN的核心思想是通过让两个神经网络:
进行对抗训练,以达到生成接近真实数据分布的人工样本的目的。
GAN的原理基于一种“零和游戏”的博弈思想。
具体来说,生成器接受一个随机噪声作为输入,通过一系列的非线性变换生成一个输出样本。
判别器则接收一个输入样本(可能是真实样本或生成样本),并输出一个概率值,表示该样本是真实样本的可能性。
在训练过程中,生成器和判别器通过反向传播算法更新各自的参数,以最大化自己的损失函数。
当训练达到收敛时,生成器能够生成与真实数据分布难以区分的样本,而判别器对于任何输入样本的输出概率都接近0.5(即无法区分真假)。
此时,可以认为GAN已经学会了真实数据的分布。
GAN的实现方式主要包括以下几个步骤:
GAN的核心是通过优化以下的min-max公式来训练生成器和判别器: [ min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] ] [ \min_G \max_D V(D, G) = \mathbb{E}{x\sim p{data}(x)}[\log D(x)] + \mathbb{E}{z\sim p{z}(z)}[\log(1 - D(G(z)))] ] [GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]]
其中, ( G ( z ) ) (G(z)) (G(z))表示生成器试图通过输入噪声 ( z ) (z) (z)生成的数据, ( D ( x ) ) (D(x)) (D(x))表示判别器对于给定输入 ( x ) (x) (x)的判断(即该数据是真实的概率)
# -*- coding:utf-8 -*- # @Time : 2024-01-21 # @Author : Carl_DJ import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt # 设定超参数 batch_size = 64 learning_rate = 0.0002 epochs = 50 latent_dim = 100 # 噪声向量的维度 # 数据预处理和加载 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 构建生成器 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1024), nn.LeakyReLU(0.2), nn.Linear(1024, 28*28), nn.Tanh() ) def forward(self, z): img = self.model(z) img = img.view(img.size(0), 1, 28, 28) return img # 构建判别器 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(28*28, 512), nn.LeakyReLU(0.2), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, img): img_flat = img.view(img.size(0), -1) validity = self.model(img_flat) return validity # 初始化生成器和判别器 generator = Generator() discriminator = Discriminator() # 优化器 optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate) optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate) # 损失函数 adversarial_loss = nn.BCELoss() # 训练 for epoch in range(epochs): for i, (imgs, _) in enumerate(train_loader): # 真实数据和假数据的标签 real = torch.ones(imgs.size(0), 1) fake = torch.zeros(imgs.size(0), 1) # 训练生成器 optimizer_G.zero_grad() z = torch.randn(imgs.size(0), latent_dim) generated_imgs = generator(z) g_loss = adversarial_loss(discriminator(generated_imgs), real) g_loss.backward() optimizer_G.step() # 训练判别器 optimizer_D.zero_grad() real_loss = adversarial_loss(discriminator(imgs), real) fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() optimizer_D.step() # 打印进度 print(f"Epoch [{epoch+1}/{epochs}] Batch {i+1}/{len(train_loader)} Loss D: {d_loss.item()}, loss G: {g_loss.item()}")
代码解析
首先设置了一些超参数,如批大小、学习率、训练周期和潜在空间维度。
然后,它定义了两个关键的神经网络:生成器和判别器。
在训练过程中,我们分别计算生成器和判别器的损失,并通过反向传播更新它们的权重。
随着训练的进行,生成器将变得越来越擅长生成逼真的图像,而判别器则会变得越来越擅长区分真假图像。
生成对抗网络(GAN)作为深度学习领域的一种强大生成模型,已经在图像生成、图像修复、图像转换、文本生成等多个领域取得了显著成果。
其基本原理是通过让生成器和判别器进行对抗训练,以达到生成接近真实数据分布的人工样本的目的。
GAN的实现过程涉及到深度神经网络、优化算法、损失函数等多个方面的知识。
我是小鱼:
关注小鱼,学习【机器学习】&【深度学习】领域的知识。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。