赞
踩
一. 概念
分类任务中常见的四种指标:准确率、精确率、召回率和F值。不过那什么又是Top-K准确率呢?简单一句话概括:Top-K准确率就是用来计算预测结果中概率最大的前K个结果包含正确标签的占比。换句话说,平常我们所说的准确率其实就是Top-1准确率。下面我们通过一个例子来进行说明:
假如现在有一个用于手写体识别的分类器(10分类),现将一张正确标签为3的图片输入到分类器中且得到了如下所示的一个概率分布:
p=[0.1,0.05,0.1,0.2,0.35,0.01,0.03,0.05,0.01,0.1]
显然,根据预测的结果来看,其最大概率0.35所对应的标签为4,这也就代表着如果按照以往的标准(Top-1准确率)来看,分类器对于这张图片的预测结果就是错误的。但如果我们以Top-2的标准来看的话,分类器对于这个图片的预测结果就是正确的,因为p中概率值最大的前两个中包含有真实的标签。也就是说,虽然0.35对应的标签是错的,但是排名第二的概率值0.2所对应的标签是正确的,所以我们在计算Top-2准确率的时候也将上述结果当作是预测正确的。
因此我们可以看出,Top-K准确率考虑的是预测结果中最有可能的K个结果是否包含有真实标签,如果包含则算预测正确,如果不包含则算预测错误。所以在这里我们能够知道,K值取得越大计算得到的Top-K准确率就会越高,极端情况下如果取K值为分类数,那么得到的准确率就肯定是1。但通常情况下我们只会看模型的Top-1、Top-3和Top-5准确率。
二. 实现
下面我们来看实现:
函数的输入:
output:模型的输出,即模型对不同类别的评分。shape: [batch_size, num_classes]
target:真实的类别标签。shape: [batch_size, ]
topk:需要计算top_k准确率中的k值,元组类型。默认为(1, 5),即函数返回top1和top5的分类准确率
下面我们还是先举一个例子:
- import torch
- output=torch.Tensor([[0.1,0.05,0.1,0.2,0.35,0.01,0.03,0.05,0.01,0.1],
- [0.2,0.05,0.1,0.35,0.2,0.01,0.02,0.04,0.01,0.0],
- [0.1,0.05,0.1,0.15,0.05,0.01,0.03,0.4,0.01,0.1],
- [0.1,0.05,0.1,0.15,0.05,0.01,0.08,0.1,0.01,0.35]])# 模型预测的概率分布
- target=torch.Tensor([[4],[3],[7],[3]]) # 实际的类别索引
- # output预测的值:每一行最大值对应的索引,为torch.Tensor([[4],[3],[7],[9]])
- topk=(1,3)# 这里预测top-1和top-3
- maxk = max(topk) # 按topk最大值构建张量
- batch_size = target.size(0) # 这里批量数等于样本数4
- _, pred = output.topk(maxk, 1, True, True) # topk返回两个张量:values和indices,分别对应前k大值的数值和索引
- print(_)
- print(pred) # size:batch_size*maxk=4*3
topk的输出:
- tensor([[0.3500, 0.2000, 0.1000],
- [0.3500, 0.2000, 0.2000],
- [0.4000, 0.1500, 0.1000],
- [0.3500, 0.1500, 0.1000]])
- tensor([[4, 3, 0],
- [3, 0, 4],
- [7, 3, 0],
- [9, 3, 0]])
pred存储了每个样本预测概率前三位的索引值。下面把target的维度改变一下进行比较:
- pred = pred.t() # 转置,size:maxk*batch_size=3*4
- correct = pred.eq(target.view(1, -1).expand_as(pred))
- # eq输出元素相等的布尔值
- # expand_as将张量扩展为pred的大小
- # view()的作用相当于numpy中的reshape,重新定义矩阵的形状
- print(pred) # size:maxk*batch_size=3*4
- print(target.view(1, -1).expand_as(pred)) # 扩展维度和pred一样
- print(correct)
correct的输出:size:max(topk)*batch_size,行数代表第几大概率,所以correct前n行就代表了前n大概率的预测情况。再看列:列中的前n行有True就表示topn预测正确。比如第4列代表的第4个样本,真实的标签是第二行,但是模型预测的标签最大是第一行,次大是第二行。那么top-1正确率就是False,预测失败,不计入top-1准确率。但是top-3正确率就是True,预测成功,计入top-3准确率。
- # pred转置
- tensor([[4, 3, 7, 9],
- [3, 0, 3, 3],
- [0, 4, 0, 0]])
- # target转置并改变维度
- tensor([[4., 3., 7., 3.],
- [4., 3., 7., 3.],
- [4., 3., 7., 3.]])
- # correct输出:比较pred和target
- tensor([[ True, True, True, False],
- [False, False, False, True],
- [False, False, False, False]])
然后我们输出topk准确率:
- res = []
- for k in topk:
- correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
- res.append(correct_k.mul_(100.0 / batch_size)) # 以百分比形式输出
- print(res)
输出:
[tensor([75.]), tensor([100.])]
最后打包成函数,以后使用直接复制即可:
- def accuracy(output, target, topk=(1,5)):
- """Computes the accuracy over the k top predictions for the specified values of k"""
- # 根据指定值k,计算top-k准确度
- with torch.no_grad():
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- # topk取一个tensor的topk元素(降序后的前k个大小的元素值及索引)
- # 返回两个张量:values和indices,分别对应前k大值的数值和索引
- pred = pred.t() # 转置
- correct = pred.eq(target.view(1, -1).expand_as(pred))
- # eq输出元素相等的布尔值,expand_as将张量扩展为参数tensor的大小,view()的作用相当于numpy中的reshape,重新定义矩阵的形状。
-
- res = []
- for k in topk:
- correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。