当前位置:   article > 正文

经典目标检测YOLOV1模型的训练及验证

经典目标检测YOLOV1模型的训练及验证

1、前期准备

准备好目录结构、数据集和关于YOLOv1的基础认知

1.1  创建目录结构

        自己创建项目目录结构,结构目录如下:

network                    CNN Backbone 存放位置
weights                    权重存放的位置
test_images             测试用的图片
utils                          辅助功能的代码存放位置 

models                    保存模型位置

data                         训练的数据集

1.2  数据集介绍与下载

1.2.1 数据集介绍

       首先了解数据集,对数据集了解后方便对数据进行相应处理。数据集详细介绍直通车:https://blog.csdn.net/qq_41946216/article/details/137683750?spm=1001.2014.3001.5501

1.2.1 数据集下载

       本次采用数据集: VOC2012数据集。

       数据集下载方式一http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

       数据集下载方式二:

     下载并构建VOC2012数据集,从:https://gitee.com/ppov-nuc/pascal-vocdataset_-for_-yolo.git, 下载get_data文件generate_csv.py文件到本地,放到创建的目录结构中,修改get_data中下载的内容和相应路径,然后运行批处理文件get_data,在get_dat中会自动执行generate_csv.py,如下图所示。


2. 数据集处理

       在utils目录下创建工具类 generate_txt_file.py,主要用于数据集的划分和解析 Annotations/xxxxx.xml 文件中的类别bbox信息,并将信息存入voctrain.txt和voctest.txt文件,如下图所示:

具体代码:

  1. # author: baiCai
  2. # 1. 导包
  3. from xml.etree import ElementTree as ET
  4. import os
  5. import random
  6. # 2. 定义一些基本的参数
  7. # 定义所有的类名
  8. VOC_CLASSES = (
  9. 'aeroplane', 'bicycle', 'bird', 'boat',
  10. 'bottle', 'bus', 'car', 'cat', 'chair',
  11. 'cow', 'diningtable', 'dog', 'horse',
  12. 'motorbike', 'person', 'pottedplant',
  13. 'sheep', 'sofa', 'train', 'tvmonitor')
  14. '''
  15. 读取所有 xml 文件,存入列表
  16. '''
  17. # 要读取的xml文件路径,记得自己修改路径
  18. Annotations = '../data/VOC2012/Annotations/'
  19. # 列出所有的xml文件
  20. xml_files = os.listdir(Annotations)
  21. # 打乱数据集
  22. random.shuffle(xml_files)
  23. '''
  24. 定义训练集和测试比例
  25. 划分Annotations中的训练集和测试集文件列表
  26. '''
  27. # 训练集数量
  28. train_num = int(len(xml_files) * 0.7)
  29. # 训练列表
  30. train_file_list = xml_files[:train_num]
  31. # 测测试列表
  32. test_file_list = xml_files[train_num:]
  33. '''
  34. 定义 xml 解析后的信息存储路径和写对象
  35. '''
  36. # 训练集和测试集文件名字
  37. train_set_path = './voctrain.txt'
  38. test_set_path = './voctest.txt'
  39. # 3. 定义解析xml文件的函数
  40. '''
  41. 主要解析 xml 获取 类别名字和bbox,如
  42. {'name': 'person','bbox': [174, 101, 349, 351]}
  43. '''
  44. def parse_rec(filename):
  45. # 参数:输入xml文件名
  46. # 创建xml对象
  47. tree = ET.parse(filename)
  48. objects = []
  49. # 迭代读取xml文件中的object节点,即物体信息
  50. for obj in tree.findall('object'):
  51. obj_struct = {}
  52. # difficult属性,即这里不需要那些难判断的对象
  53. difficult = int(obj.find('difficult').text)
  54. if difficult == 1: # 若为1则跳过本次循环
  55. continue
  56. # 开始收集信息
  57. obj_struct['name'] = obj.find('name').text
  58. bbox = obj.find('bndbox')
  59. obj_struct['bbox'] =\
  60. [int(float(bbox.find('xmin').text)),
  61. int(float(bbox.find('ymin').text)),
  62. int(float(bbox.find('xmax').text)),
  63. int(float(bbox.find('ymax').text))]
  64. objects.append(obj_struct)
  65. return objects
  66. # 4. 把信息保存入文件中
  67. def write_txt(file_list,set_path):
  68. # # 生成训练集txt
  69. count = 0
  70. with open(set_path, 'w') as wt:
  71. for xml_file in file_list:
  72. count += 1
  73. # 获取图片名字
  74. image_name = xml_file.split('.')[0] + '.jpg' # 图片文件名
  75. # 对xml_file进行解析
  76. results = parse_rec(Annotations + xml_file)
  77. # 如果返回的对象为空,表示张图片难以检测,因此直接跳过
  78. if len(results) == 0:
  79. print(xml_file)
  80. continue
  81. # 否则,则写入文件中
  82. # 先写入图片名字
  83. wt.write(image_name)
  84. # 接着指定下面写入的格式
  85. for result in results:
  86. class_name = result['name']
  87. bbox = result['bbox']
  88. class_name = VOC_CLASSES.index(class_name) # 名字在类别中是下标位置
  89. wt.write(' ' + str(bbox[0]) +
  90. ' ' + str(bbox[1]) +
  91. ' ' + str(bbox[2]) +
  92. ' ' + str(bbox[3]) +
  93. ' ' + str(class_name))
  94. wt.write('\n')
  95. wt.close()
  96. # 5. 运行
  97. if __name__ == '__main__':
  98. write_txt(train_file_list,train_set_path)
  99. write_txt(test_file_list,test_set_path)

3. 构建数据加载器 

3.1定义初始化方法

       读取xxxx.xml解析后的文件
       对每行数据(每个图片信息)的所有中心点信息以【x,y,w,h】和标签分别存入box列表和label列表。
       当前图片的边界框和标签信息即box列表和label列表,转换为LongTensor格式添加到对应的boxex列表和labels列表。

3.2 定义增强图片方法

增加方法名称定义的函数
随机翻转图片和边界框random_flip(img, boxes)
随机缩放图片和边界框randomScale(img, boxes)
随机模糊图片randomBlur(img)
随机调整图片亮度RandomBrightness(img)
随机调整图片色调RandomHue(img)
随机调整图片饱和度RandomSaturation(img)
随机移动图片和边界框randomShift(img, boxes, labels)        
随机裁剪图片和边界框randomCrop(img, boxes, labels)
用于从图像中减去均值subMean(self, bgr, mean)
将BGR图像转换为RGB图像BGR2RGB(self, img)
将BGR图像转换为HSV图像BGR2HSV(self, img)
将HSV图像转换为BGR图像HSV2BGR(self, img)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/446269
推荐阅读
相关标签
  

闽ICP备14008679号