当前位置:   article > 正文

yolov5中的初始anchor理解笔记_yolo anchor

yolo anchor

一、yolov5中的初始Anchor设定

在YOLOV5算法之中,针对不同的数据集,一般会预先设置固定的Anchor;

首先,在网络训练中,网络在初始锚框的基础上输出预测框,进而和Ground Truth进行比对,计算两者差距,再反向更新,迭代网络参数

可以看出Anchor也是比较重要的一部分,比如Yolov5在Coco数据集上初始设定的锚框:

  1. anchors:
  2. - [10,13, 16,30, 33,23] # P3/81行是在最小的特征图上的锚框;
  3. - [30,61, 62,45, 59,119] # P4/16
  4. - [116,90, 156,198, 373,326] # P5/32

其中:

第1行是在最小的特征图上的锚框;

第2行是在中间的特征图上的锚框;

第3行是在最大的特征图上的锚框;

注:阅读其它人的博客发现,原来yolov5也可以不预设anchor,也可以直接写个3,此时yolov5就会自动按照训练集聚类anchor,如下: 

  1. # Parameters
  2. nc: 80 # number of classes
  3. depth_multiple: 1.0 # model depth multiple
  4. width_multiple: 1.0 # layer channel multiple
  5. anchors: 3 # AutoAnchor evolves 3 anchors per P output layer

为啥anchor一行是六个数呢,xywh个数也不对啊?
这里就要说一下anchor是怎么生成的了。

对于输出层(Prediction),经过前面的一系列特征提取和计算操作后,会生成三个特定大小的特征,大小分别为608/8=76,608/16=38,608/32=19,可能这也是输入图像大小要求是32的倍数的原因。

下面是v5代码中采用kmeans计算anchor的过程。

path代表数据yaml路径,n代表聚类数,img_size代表模型输入图片的大小,thr代表长宽比的阈值(将长宽比限定在一定的范围内,这个可以自己统计一下数据集),gen代表kmeans迭代次数。
 

  1. def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  2. """ Creates kmeans-evolved anchors from training dataset
  3. Arguments:
  4. path: path to dataset *.yaml, or a loaded dataset
  5. n: number of anchors
  6. img_size: image size used for training
  7. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  8. gen: generations to evolve anchors using genetic algorithm
  9. Return:
  10. k: kmeans evolved anchors
  11. Usage:
  12. from utils.general import *; _ = kmean_anchors()
  13. """
  14. thr = 1. / thr
  15. def metric(k, wh): # compute metrics
  16. r = wh[:, None] / k[None]
  17. x = torch.min(r, 1. / r).min(2)[0] # ratio metric
  18. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  19. return x, x.max(1)[0] # x, best_x
  20. def fitness(k): # mutation fitness
  21. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  22. return (best * (best > thr).float()).mean() # fitness
  23. def print_results(k):
  24. k = k[np.argsort(k.prod(1))] # sort small to large
  25. x, best = metric(k, wh0)
  26. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  27. print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
  28. print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
  29. (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
  30. for i, x in enumerate(k):
  31. print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
  32. return k
  33. if isinstance(path, str): # *.yaml file
  34. with open(path) as f:
  35. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  36. from utils.datasets import LoadImagesAndLabels
  37. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  38. else:
  39. dataset = path # dataset
  40. # Get label wh
  41. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  42. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  43. # Filter
  44. i = (wh0 < 3.0).any(1).sum()
  45. if i:
  46. print('WARNING: Extremely small objects found. '
  47. '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
  48. wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
  49. # Kmeans calculation
  50. print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
  51. s = wh.std(0) # sigmas for whitening
  52. k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
  53. k *= s
  54. wh = torch.tensor(wh, dtype=torch.float32) # filtered
  55. wh0 = torch.tensor(wh0, dtype=torch.float32) # unflitered
  56. k = print_results(k)
  57. # Plot
  58. # k, d = [None] * 20, [None] * 20
  59. # for i in tqdm(range(1, 21)):
  60. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  61. # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
  62. # ax = ax.ravel()
  63. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  64. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  65. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  66. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  67. # fig.tight_layout()
  68. # fig.savefig('wh.png', dpi=200)
  69. # Evolve
  70. npr = np.random
  71. f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  72. pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
  73. for _ in pbar:
  74. v = np.ones(sh)
  75. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  76. v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  77. kg = (k.copy() * v).clip(min=2.0)
  78. fg = fitness(kg)
  79. if fg > f:
  80. f, k = fg, kg.copy()
  81. pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
  82. if verbose:
  83. print_results(k)
  84. return print_results(k)

上面的计算过程相当于将我画的长宽比先转化到resize640大小的长宽比下,再进行聚类,得到9个聚类中心,每个聚类中心包含(x,y)坐标就是我们需要的anchor如下:

134,38,  172,35,  135,48,  175,43,  209,38,  174,62,  254,69,  314,82,  373,95

 将其放入list

  1. #anchors:
  2. #1. [10,13, 16,30, 33,23] # P3/8 608/8=76
  3. #2. [30,61, 62,45, 59,119] # P4/16 608/16=38
  4. #3. [116,90, 156,198, 373,326] # P5/32 608/32=19
  5. 1. [134,38,135,48,172,35] # P3/8 608/8=76
  6. 2. [174,62,175,43,209,38] # P4/16 608/16=38
  7. 3. [254,69,314,82,373,95] # P5/32 608/32=19

 这里的thr其实是和hyp.scratch.yaml文件中的anchor_t一样,代表了anchor放大的scale,我的标注框长宽比最大在8左右,因此设置为8。

检测模块

接下来就是anchor在模型中的应用了。这就涉及到了yolo系列目标框回归的过程了。

yolov5中的detect模块沿用了v3检测方式,这里就用这种方式来阐述了。

1.检测到的不是框,是偏移量。

  1. tx,ty指的是针对所在grid的左上角坐标的偏移量
  2. tw,th指的是相对于anchor的宽高的偏移量

通过如下图的计算方式,得到bx,by,bw,bh就是最终的检测结果。

2. 前面经过backbone,neck, head是panet的三个分支,可见特征图size不同,每个特征图分了13个网格,同一尺度的特征图对应了3个anchor,检测了[c,x,y,w,h]和num_class个的one-hot类别标签。3个尺度的特征图,总共就有9个anchor。
 

 

参考:

(173条消息) YOLOv5的Backbone详解_Marlowee的博客-CSDN博客_yolov5 backbone

Yolov5-模型配置文件(yolov5l.yaml)讲解 - 知乎 (zhihu.com)

YOLOv5-Lite 详解教程 | 嚼碎所有原理和思想、训练自己数据集、TensorRT部署落地应有尽有 - 知乎 (zhihu.com)

(11条消息) yolov5的anchor详解_anny_jra的博客-CSDN博客_yolov5的anchor

(3条消息) YOLOv5的anchor设定_Marlowee的博客-CSDN博客_yolov5 anchor设置

(13条消息) YOLO系列详解:YOLOv1、YOLOv2、YOLOv3、YOLOv4、YOLOv5_AI小白一枚的博客-CSDN博客_yolo

(14条消息) 【目标检测】yolo系列:从yolov1到yolov5之YOLOv3详解及复现_看星星的月儿的博客-CSDN博客_yolo 输出维度 (14条消息) YOLOv1——YOLOX系列及FCOS目标检测算法详解_神洛华的博客-CSDN博客_fcos和yolo

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/860991
推荐阅读
相关标签
  

闽ICP备14008679号