赞
踩
最近研究了下CRNN各种pytorch版本,发现里面一大半都是训练有问题,典型问题就是Loss训练几个epoch就变成nan,这样项目在github上有很多,我使用的是pytorch==1.7.0版本,之后发现一个很好解决方法。像网上说什么改学习率,梯度裁剪等等一堆都试了全部没用,偶然成功了一个项目发现为啥他就是对的,原来是CTCLoss设置问题,在高版本pytorch里面,需要在初始CTCLoss时候加个参数即可。
from torch.nn import CTCLoss
ctc_loss=CTCLoss(zero_infinity=True)
这样就是不会出现loss为nan问题,而且测试发现模型预测也正常,看起来这种方法是可行的。如果您遇到这种问题可以试试,如果觉得有用可以在下方留言。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。