赞
踩
式中predict_ labels与labels是两个大小相同的tensor,而torch.eq()函数就是用来比较对应位置数字,相同则为1,否则为0,输出与那两个tensor大小相同,并且其中只有1和0。
predict_ labels = [0 1 2 3 4]
labels = [4 3 2 1 4]
torch.eq()得[0 0 1 0 1]
torch.eq().sum()就是将所有值相加,但得到的仍是tensor.
torch.eq().sum()得到的结果就是[2]。
torch.eq().sum().item()得到值2。
用这个来计算训练集、验证集准确率时,记得一个epoch后要除的分母是训练集、验证集的数据量大小!!!
(自己有次搞错,分母用的是一个epoch中的step,迷茫了好一会儿)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。