赞
踩
之前在做自定义的dataset的时候其实也遇到过,但是没太在意,现在又遇到了所以记录一下,网上讲解有挺多,说一下我遇到的这种。除了上面这个断言还会有如下的报错提示。
可以看出出错的地方基本上就是在损失函数那个地方。本人较为愚钝,翻了半天也没搞懂标签哪错了(PS:有人的这种情况是全连接层输出和他用的训练集的输出不一样导致),后来才知道原因:数据集的标签要从0开始,因为交叉熵损失的y_hat(真实标签)的取值范围是[0,类别数-1],从0开始算的到类别数-1就是你的类别数。如果不是的话就会报上面标题的错误Assertion t >= 0 && t < n_classes
failed.注意这里t是在**[0,num_class)**之间的。因此解决这个问题的方法就是在你自定义的dataset中,把你重写的
这个函数中返回的真实标签值设置为[0,num_class),如果你有10类,那么你的label的取值就是[0,1,2,3,4,5,6,7,8,9],这样就解决了。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。