当前位置:   article > 正文

torch.eq(predict_ labels, labels).sum().item()注意事项_python sum().item()

python sum().item()

式中predict_ labels与labels是两个大小相同的tensor,而torch.eq()函数就是用来比较对应位置数字,相同则为1,否则为0,输出与那两个tensor大小相同,并且其中只有1和0。

predict_ labels = [0 1 2 3 4]
labels = [4 3 2 1 4]
torch.eq()得[0 0 1 0 1]

torch.eq().sum()就是将所有值相加,但得到的仍是tensor.
torch.eq().sum()得到的结果就是[2]。
torch.eq().sum().item()得到值2。

用这个来计算训练集、验证集准确率时,记得一个epoch后要除的分母是训练集、验证集的数据量大小!!!
(自己有次搞错,分母用的是一个epoch中的step,迷茫了好一会儿)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/372515
推荐阅读
相关标签
  

闽ICP备14008679号