当前位置:   article > 正文

YOLOv5知识蒸馏实战篇

YOLOv5知识蒸馏实战篇

YOLOv5s学生网络训练

准备自己的数据集

首先将你的数据集按照标准的voc数据集格式进行如下图的目录树存放

其中Annotations存放的是标注文件,文件格式是xml,JPEGImages存放的是图片。

然后执行下面这个脚本就可以实现以下功能

  • 将xml文件格式转换为标准的yolo的txt格式
  • 将图片和标注文件划分训练集和验证集

注意将代码中的classes改为你自己的类别,顺序要与后面数据集配置文件中的names保持一致;并将convert_annotation函数中的文件路径改为你对应的路径。路径中不要包含中文,并且尽量使用绝对路径。

  1. import xml.etree.ElementTree as ET
  2. import pickle
  3. import os
  4. from os import listdir, getcwd
  5. from os.path import join
  6. import random
  7. from shutil import copyfile
  8. classes=["ball","messi"]
  9. #classes=["ball"]
  10. TRAIN_RATIO = 80
  11. def clear_hidden_files(path):
  12. dir_list = os.listdir(path)
  13. for i in dir_list:
  14. abspath = os.path.join(os.path.abspath(path), i)
  15. if os.path.isfile(abspath):
  16. if i.startswith("._"):
  17. os.remove(abspath)
  18. else:
  19. clear_hidden_files(abspath)
  20. def convert(size, box):
  21. dw = 1./size[0]
  22. dh = 1./size[1]
  23. x = (box[0] + box[1])/2.0
  24. y = (box[2] + box[3])/2.0
  25. w = box[1] - box[0]
  26. h = box[3] - box[2]
  27. x = x*dw
  28. w = w*dw
  29. y = y*dh
  30. h = h*dh
  31. return (x,y,w,h)
  32. def convert_annotation(image_id):
  33. in_file = open('VOCdevkit/VOC2007/Annotations/%s.xml' %image_id)
  34. out_file = open('VOCdevkit/VOC2007/YOLOLabels/%s.txt' %image_id, 'w')
  35. tree=ET.parse(in_file)
  36. root = tree.getroot()
  37. size = root.find('size')
  38. w = int(size.find('width').text)
  39. h = int(size.find('height').text)
  40. for obj in root.iter('object'):
  41. difficult = obj.find('difficult').text
  42. cls = obj.find('name').text
  43. if cls not in classes or int(difficult) == 1:
  44. continue
  45. cls_id = classes.index(cls)
  46. xmlbox = obj.find('bndbox')
  47. b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
  48. bb = convert((w,h), b)
  49. out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
  50. in_file.close()
  51. out_file.close()
  52. wd = os.getcwd()
  53. wd = os.getcwd()
  54. data_base_dir = os.path.join(wd, "VOCdevkit/")
  55. if not os.path.isdir(data_base_dir):
  56. os.mkdir(data_base_dir)
  57. work_sapce_dir = os.path.join(data_base_dir, "VOC2007/")
  58. if not os.path.isdir(work_sapce_dir):
  59. os.mkdir(work_sapce_dir)
  60. annotation_dir = os.path.join(work_sapce_dir, "Annotations/")
  61. if not os.path.isdir(annotation_dir):
  62. os.mkdir(annotation_dir)
  63. clear_hidden_files(annotation_dir)
  64. image_dir = os.path.join(work_sapce_dir, "JPEGImages/")
  65. if not os.path.isdir(image_dir):
  66. os.mkdir(image_dir)
  67. clear_hidden_files(image_dir)
  68. yolo_labels_dir = os.path.join(work_sapce_dir, "YOLOLabels/")
  69. if not os.path.isdir(yolo_labels_dir):
  70. os.mkdir(yolo_labels_dir)
  71. clear_hidden_files(yolo_labels_dir)
  72. yolov5_images_dir = os.path.join(data_base_dir, "images/")
  73. if not os.path.isdir(yolov5_images_dir):
  74. os.mkdir(yolov5_images_dir)
  75. clear_hidden_files(yolov5_images_dir)
  76. yolov5_labels_dir = os.path.join(data_base_dir, "labels/")
  77. if not os.path.isdir(yolov5_labels_dir):
  78. os.mkdir(yolov5_labels_dir)
  79. clear_hidden_files(yolov5_labels_dir)
  80. yolov5_images_train_dir = os.path.join(yolov5_images_dir, "train/")
  81. if not os.path.isdir(yolov5_images_train_dir):
  82. os.mkdir(yolov5_images_train_dir)
  83. clear_hidden_files(yolov5_images_train_dir)
  84. yolov5_images_test_dir = os.path.join(yolov5_images_dir, "val/")
  85. if not os.path.isdir(yolov5_images_test_dir):
  86. os.mkdir(yolov5_images_test_dir)
  87. clear_hidden_files(yolov5_images_test_dir)
  88. yolov5_labels_train_dir = os.path.join(yolov5_labels_dir, "train/")
  89. if not os.path.isdir(yolov5_labels_train_dir):
  90. os.mkdir(yolov5_labels_train_dir)
  91. clear_hidden_files(yolov5_labels_train_dir)
  92. yolov5_labels_test_dir = os.path.join(yolov5_labels_dir, "val/")
  93. if not os.path.isdir(yolov5_labels_test_dir):
  94. os.mkdir(yolov5_labels_test_dir)
  95. clear_hidden_files(yolov5_labels_test_dir)
  96. train_file = open(os.path.join(wd, "yolov5_train.txt"), 'w')
  97. test_file = open(os.path.join(wd, "yolov5_val.txt"), 'w')
  98. train_file.close()
  99. test_file.close()
  100. train_file = open(os.path.join(wd, "yolov5_train.txt"), 'a')
  101. test_file = open(os.path.join(wd, "yolov5_val.txt"), 'a')
  102. list_imgs = os.listdir(image_dir) # list image files
  103. prob = random.randint(1, 100)
  104. print("Probability: %d" % prob)
  105. for i in range(0,len(list_imgs)):
  106. path = os.path.join(image_dir,list_imgs[i])
  107. if os.path.isfile(path):
  108. image_path = image_dir + list_imgs[i]
  109. voc_path = list_imgs[i]
  110. (nameWithoutExtention, extention) = os.path.splitext(os.path.basename(image_path))
  111. (voc_nameWithoutExtention, voc_extention) = os.path.splitext(os.path.basename(voc_path))
  112. annotation_name = nameWithoutExtention + '.xml'
  113. annotation_path = os.path.join(annotation_dir, annotation_name)
  114. label_name = nameWithoutExtention + '.txt'
  115. label_path = os.path.join(yolo_labels_dir, label_name)
  116. prob = random.randint(1, 100)
  117. print("Probability: %d" % prob)
  118. if(prob < TRAIN_RATIO): # train dataset
  119. if os.path.exists(annotation_path):
  120. train_file.write(image_path + '\n')
  121. convert_annotation(nameWithoutExtention) # convert label
  122. copyfile(image_path, yolov5_images_train_dir + voc_path)
  123. copyfile(label_path, yolov5_labels_train_dir + label_name)
  124. else: # test dataset
  125. if os.path.exists(annotation_path):
  126. test_file.write(image_path + '\n')
  127. convert_annotation(nameWithoutExtention) # convert label
  128. copyfile(image_path, yolov5_images_test_dir + voc_path)
  129. copyfile(label_path, yolov5_labels_test_dir + label_name)
  130. train_file.close()
  131. test_file.close()

执行后如下图所示:

 执行完毕以后会在VOCdevkit下生成images和labels文件夹,文件夹下已经划分好训练集(train)和验证集(val)。同时在yolov5目录(即运行脚本的当前目录)下生成了yolov5_train.txt和yolov5_val.txt两个文件,分别给出训练图片和验证图片的列表,每行是一张图片的完整路径。

修改配置文件

创建data/voc_bm.yaml


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ./
train: # train images (relative to 'path')  16551 images
  - VOCdevkit/images/train/ 
val: # val images (relative to 'path')  4952 images
  - VOCdevkit/images/val/
test: # test images (optional)
# Classes
nc: 2  # number of classes
names: ['ball', 'messi']  # class names

 这里的路径要写对,有三种写路径的方法,主要推荐前面两种,配置文件中的示例就是第一种。

1) dir: path/to/imgs

2) file: path/to/imgs.txt 

val: data/nwpu vhr-10/val.txt                                  #例子
train: data/nwpu vhr-10/train.txt                             #例子
test: data/nwpu vhr-10/test.txt                                #例子

 修改nc类别

然后将models文件夹下面的yolov5s.yaml和yolov5m.yaml中的nc改为2,或者另外新建配置文件(例如下文命令中使用的yolov5s_bm.yaml和yolov5m_bm.yaml)并将其中的nc改为2。

yolov5s训练自己的数据集

python train.py --data data/voc_bm.yaml --cfg models/yolov5s_bm.yaml --weights weights/yolov5s.pt --batch-size 16 --epochs 100 --workers 4 --name yolov5s-base

 如果开始训练出现长期卡在这里的话

Downloading https://ultralytics.com/assets/Arial.ttf to /root/.config/Ultralytics/Arial.ttf...

 是因为项目中缺少这个东西,我直接把这个文件放这里,下载完成以后直接放在总目录下也可以识别

链接:https://pan.baidu.com/s/11AwkBdV0fsavRcRC2EXoOg?pwd=gkgk 
提取码:gkgk

 然后可以正常的开始训练了。

 学生网络训练结果如下:

yolov5m教师网络训练

同上,仅需要将网络模型改为yolov5m就可以,训练命令如下:

python train.py --data data/voc_bm.yaml --cfg models/yolov5m_bm.yaml --weights weights/yolov5m.pt --batch-size 16 --epochs 100 --workers 4 --name yolov5m-base

 教师网络训练结果如下:

知识蒸馏训练

学生网络和教师网络训练完毕进入重点,蒸馏训练

把runs/train/yolov5s-base/weights/best.pt和runs/train/yolov5m-base/weights/best.pt拷贝到weights文件夹下并改名为yolov5szl.pt和yolov5mzl.pt

python train_distillation.py --weights weights/yolov5szl.pt --cfg models/yolov5s_bm.yaml --data data/voc_bm.yaml --batch-size 8 --epochs 100 --workers 4 --t_weights weights/yolov5mzl.pt --hyp data/hyps/hyp.scratch-low-distillation.yaml --distill --dist_loss l2 --name yolov5s-distilled

 训练过程可视化:

结果如下:

对比结果

蒸馏对比结果
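| 参数 | yolov5s | yolov5m | 蒸馏L2 |
| --- | --- | --- | --- |
| P | 0.87 | 0.914 | 0.907 |
| R | 0.88 | 0.965 | 0.991 |
| MAP50 | 0.918 | 0.959 | 0.976 |
| MAP95 | 0.696 | 0.741 | 0.714 |
| GFLOPs | 15.8 | 47.9 | 15.8 |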

从表格对比可以看出,蒸馏L2相对于yolov5s在各项指标上都有提升:精度(P)、召回率(R)和平均精度(MAP)均明显提高,而计算量(GFLOPs)保持不变。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号