赞
踩
- one=torch.tensor(1,dtype=torch.float32)|
- mone=one*-1
- moneg=one*-1*gan_loss_percent
- D = model_discrination()
- #为什么在辨别器反向传播时
- d_real = D(real_pair)
- d_real = d_real.mean()
- d_real.backward(mone)
- d_fake = D(fake_pair)
- d_fake = d_fake.mean()
- d_fake.backward(one)
- #生成器反向传播时:
- segloss.backward(retain_graph=True)
- gd_fake = D(fake_pair)
- gd_fake = gd_fake.mean()
- gd_fake.backward(moneg)
在生成对抗网络(GAN)的训练过程中,生成器(G)和鉴别器(D)通过反向传播更新权重以达到最优化。训练分为两个主要步骤:训练鉴别器识别真实图像和生成图像,以及训练生成器欺骗鉴别器。代码片段中的反向传播策略是实现这一训练过程的关键部分。我们来详细解释每一步:
- - **真实图像的梯度更新**:
- d_real = D(real_pair)
- d_real = d_real.mean()
- d_real.backward(mone)
这里,`d_real.mean()`计算一批真实图像的平均得分,意图是最大化这个分数(即让鉴别器更有信心这批图像是真实的)。使用`mone`(即-1)作为梯度方向,意味着要增加这个平均分数(因为在梯度下降中,我们是向梯度的“负方向”更新权重)。
- - **生成图像的梯度更新**:
- d_fake = D(fake_pair)
- d_fake = d_fake.mean()
- d_fake.backward(one)
对于生成图像,目标是最小化鉴别器给出的分数(即让鉴别器认为这批图像不是真实的)。使用`one`(即1)作为梯度方向,意味着要减少这个平均分数。
- - **生成器的梯度更新**(通过鉴别器反馈):
- gd_fake = D(fake_pair)
- gd_fake = gd_fake.mean()
- gd_fake.backward(moneg)
- - **生成器的梯度更新**(直接关于生成任务的损失):
- segloss.backward(retain_graph=True)
这一步涉及到生成器的直接优化目标(例如,图像分割任务的损失),而与鉴别器无关。`retain_graph=True`参数允许保留计算图,以便后续可以继续进行梯度计算,这在同一批数据上进行多次反向传播时是必要的。
综上所述,这种训练策略通过在不同的步骤中适当选择梯度方向(正或负),实现了鉴别器和生成器的对抗训练,从而使GAN能够生成高质量的图像。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。