当前位置:   article > 正文

Pytorch训练CRNN输出为NaN_使用tcn,突然输出全部变为nan

使用tcn,突然输出全部变为nan

记录一个非常隐蔽pytorch训练audio相关的Bug。

问题描述

spectrogram作为input feature,训练几轮之后输出和loss中就有NaN,导致神经网络不收敛。经排查每步的输出,feature中没有NaN,但是网络会输出NaN,loss为NaN。

问题复现

网络输入是torchaudio.transforms.Spectrogram(…,power=1)
model.backward()

torchaudio.transforms.Spectrogram默认的power=2,用默认参数不会有这个问题。

问题分析

在pytorch audio的issue中发现类似问题,torchaudio.transforms.Spectrogram在power=1,也就是用magnitude spectrogram作为输入时,反向传递会出现这个问题。

Spectrogram的计算如下:

stft = torch.stft(wav, n_fft, hop_length, win_length, ...)
norm_sq = stft.pow(2.).sum(-1)
result = norm_sq.sqrt()
  • 1
  • 2
  • 3

如果wav为0,则norm_sq为0,取根号后的result也为0,但是result在计算梯度的时候就是NaN, 因此会在反向传递时出问题。

解决方案

eps = 1e-14  # Add eps to ensure .sqrt is not 0

stft = torch.stft(wav, n_fft, hop_length, win_length, ...)
norm_sq = stft.pow(2.).sum(-1)
result = (norm_sq + eps).sqrt()
  • 1
  • 2
  • 3
  • 4
  • 5

给要取根号的值加个很小的数,使得取根号后的result不为0.

解决!

参考:https://github.com/pytorch/audio/issues/993

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/362609
推荐阅读
相关标签
  

闽ICP备14008679号