赞
踩
不同于一般的GAN网络使用交叉熵损失函数,这里使用的是最小二乘损失函数,借此来避免梯度丢失的现象。
最小二乘损失函数
特征匹配损失函数是用来衡量真实样本和生成样本在鉴定器上提取出来的特征的差异程度。不同于上一个mel频谱图的特征衡量,这里是直接衡量鉴定器生成的中间特征的差异程度。
定义
参数说明
效果
注意
def feature_loss(fmap_r, fmap_g):
# 特征损失函数
# fmap_r是真实音频信号的特征图,fmap_g是生成音频信号的特征图
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
# 遍历每一层特征图,计算特征损失,做差,求绝对值,求均值
loss += torch.mean(torch.abs(rl - gl))
# 根据经验,特征损失函数的权重为10
return loss*2
def generator_loss(disc_outputs):
# 生成器的损失函数
# disc_outputs是鉴定器的输出
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1-dg)**2)
gen_losses.append(l)
loss += l
# loss是生成器的总损失,用于反向传播来更新生成器的参数
# gen_losses是生成器的损失列表,用于记录鉴定器中每一个元素对应的损失,可以用于调试设备
return loss, gen_losses
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
# 鉴定器的损失函数
# disc_real_outputs是真实音频信号的鉴定器的输出
# disc_generated_outputs是生成音频信号的鉴定器的输出
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
# 计算真实音频信号的损失
r_loss = torch.mean((1-dr)**2)
# 计算生成音频信号的损失
g_loss = torch.mean(dg**2)
# 将两个损失相加
loss += (r_loss + g_loss)
# 记录各个鉴定器的损失
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。