赞
踩
对于目标检测来说, 非极大值抑制的含义就是对于重叠度较高的一部分同类候选框来说, 去掉那些置信度较低的框, 只保留置信度最大的那一个进行后面的流程, 这里的重叠度高低与否是通过 iou 阈值来判断的.
算法逻辑:
输入: n 行 5 列的候选框数组,每行依次为:左上角横坐标,左上角纵坐标,右下角横坐标,右下角纵坐标,置信度
输出: m 行 5 列的候选框数组, 每行含义同输入
算法流程:
计算 n 个候选框的面积大小
对置信度进行排序, 获取排序后的下标序号, 即采用argsort
将当前置信度最大的框加入返回值列表中
获取当前置信度最大的候选框与其他任意候选框的相交面积
利用相交的面积和两个框自身的面积计算框的交并比, 将交并比大于阈值的框删除.
对剩余的框重复以上过程
代码实现:
1、面试简单版本(不包含iou实现,iou实现可以参考本笔记本其他笔记,注意排序的实现方法)
- def nms(bboxs, threshold):
- """
- 计算nms
- 不包含根据得分预先过滤的功能
- :param bboxs: 列表[[xt1, yt1, xb1, yb1, score1], .....]
- :param threshold: iou过滤阈值
- :return:
- """
- if len(bboxs) == 0:
- return bboxs
-
- # 按得分排序
- def take_score(c):
- return c[4]
- bboxs.sort(key=take_score, revere=True)
-
- result = []
- for i in range(len(bboxs) - 1):
- if bboxs[i][4] > 0:
- result.append(bboxs[i])
- for j in range(i, len(bboxs)):
- if iou(bboxs[i][:4], bboxs[j][:4]) > threshold:
- bboxs[j][4] = 0
- return result
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
2、numpy实现(高效且包含iou计算)
- mport numpy as np
-
-
- def nms(bboxs, threshold):
- """
- 计算nms
- 不包含根据得分预先过滤的功能
- :param bboxs: numpy 矩阵[[xt1, yt1, xb1, yb1, score1], .....]
- :param threshold: iou过滤阈值
- :return:
- """
- if len(bboxs) == 0:
- return bboxs
-
- xt = bboxs[:, 0]
- yt = bboxs[:, 1]
- xb = bboxs[:, 2]
- yb = bboxs[:, 3]
- score = bboxs[:, 4]
- areas = (xb - xt + 1) * (yb - yt + 1)
-
- result = []
- # 对置信度进行排序, 获取排序后的下标序号, argsort 默认从小到大排序
- index_sorted = np.argsort(score)
- while index_sorted.size > 0:
- index_current = index_sorted[-1]
- # 将当前置信度最大的框加入返回值列表中
- result.append(bboxs[index_current])
-
- # 获取当前置信度最大的候选框与其他任意候选框的相交面积
- x_min = np.maximum(xt[index_sorted[:-1]], xt[index_current])
- x_max = np.minimum(xb[index_sorted[:-1]], xb[index_current])
- y_min = np.maximum(yt[index_sorted[:-1]], yt[index_current])
- y_max = np.minimum(yb[index_sorted[:-1]], yb[index_current])
- w = np.maximum(0, x_max - x_min + 1)
- h = np.maximum(0, y_max - y_min + 1)
- area_i = w * h
-
- iou = area_i / (areas[index_sorted[:-1]] + areas[index_current] - area_i)
- index_sorted = index_sorted(np.where(iou < threshold))
-
- return result
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
参考连接:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。