当前位置:   article > 正文

[深度学习][pytorch][原创]crnn在高版本pytorch上训练loss为nan解决办法_crnn中train loss: nan

crnn中train loss: nan

最近研究了下CRNN各种pytorch版本,发现里面一大半都是训练有问题,典型问题就是Loss训练几个epoch就变成nan,这样项目在github上有很多,我使用的是pytorch==1.7.0版本,之后发现一个很好解决方法。像网上说什么改学习率,梯度裁剪等等一堆都试了全部没用,偶然成功了一个项目发现为啥他就是对的,原来是CTCLoss设置问题,在高版本pytorch里面,需要在初始CTCLoss时候加个参数即可。

from torch.nn import CTCLoss

ctc_loss=CTCLoss(zero_infinity=True)

这样就是不会出现loss为nan问题,而且测试发现模型预测也正常,看起来这种方法是可行的。如果您遇到这种问题可以试试,如果觉得有用可以在下方留言。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号