
Training YOLOv5 on Your Own Dataset (A Detailed Walkthrough)

1. Introduction to YOLOv5

The YOLOv5 v6.0 network comes in five variants by depth and width: n, s, m, l, and x. In most cases, to keep the model lightweight while still maintaining detection accuracy, YOLOv5s is chosen as the base model for further improvement.

YOLOv5 consists of four main parts: the input stage (Input), the backbone network (Backbone), the neck network (Neck), and the detection head (Head). These parts work together to enable efficient object detection.

The backbone is the core of the model and extracts feature information from the image. The neck fuses the features extracted by the backbone, supplying richer information to the head. The head then localizes and classifies targets based on these features.

By choosing a suitable variant and improving on the base model, YOLOv5 can deliver accurate and fast object detection.

Source code: https://github.com/ultralytics/yolov5

The pretrained weights can be downloaded from the official releases.

This project uses yolov5s.pt.

2. Dataset Introduction

WiderPerson is a benchmark dataset for pedestrian detection in crowded scenes; its images are no longer limited to traffic scenes but are carefully selected from a wide variety of scenarios. The dataset contains 13,382 images annotated with roughly 400,000 instances carrying occlusion tags. For fairness and validity, 8,000, 1,000, and 4,382 images are randomly selected as the training, validation, and test sets respectively. As with the CityPersons and WIDER FACE datasets, the annotations for the test images are not released, to prevent potential cheating.

You can download the dataset from its project page: WiderPerson: A Diverse Dataset for Dense Pedestrian Detection in the Wild

After downloading, the folder structure looks like this:

The txt files under the Annotations folder look like the following. The number on the first line is the number of annotated instances (it is not used later); the first number on each subsequent line is the class label. The dataset contains five classes:

  1 : pedestrians
  2 : riders
  3 : partially-visible persons
  4 : ignore regions
  5 : crowd
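Each annotation line therefore has the form `<class> <xmin> <ymin> <xmax> <ymax>`. As a minimal sketch (the sample line below is made up for illustration, not taken from the dataset), parsing one line looks like:

```python
def parse_annotation_line(line):
    """Parse one WiderPerson annotation line: '<class> <xmin> <ymin> <xmax> <ymax>'."""
    parts = line.split()
    cls = int(parts[0])
    xmin, ymin, xmax, ymax = map(int, parts[1:5])
    return cls, (xmin, ymin, xmax, ymax)

# Hypothetical sample line for illustration
cls, box = parse_annotation_line("1 100 50 200 300")
print(cls, box)  # 1 (100, 50, 200, 300)
```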

To convert the dataset to VOC format, these txt files need to be converted to xml files. The code is as follows:

```python
import os
import shutil
import cv2
from lxml.etree import Element, SubElement, tostring
from xml.dom.minidom import parseString


def make_voc_dir():
    # Create the VOC2007 directory tree if it does not exist yet
    if not os.path.exists('../VOC2007/Annotations'):
        os.makedirs('../VOC2007/Annotations')
    if not os.path.exists('../VOC2007/ImageSets'):
        os.makedirs('../VOC2007/ImageSets')
        os.makedirs('../VOC2007/ImageSets/Main')
    if not os.path.exists('../VOC2007/JPEGImages'):
        os.makedirs('../VOC2007/JPEGImages')


if __name__ == '__main__':
    classes = {'1': 'pedestrians',
               '2': 'riders',
               '3': 'partially',
               '4': 'ignore',
               '5': 'crowd'}
    VOCRoot = '../VOC2007'
    widerDir = './WiderPerson'          # root of the dataset
    wider_path = './WiderPerson/val.txt'
    make_voc_dir()
    with open(wider_path, 'r') as f:
        imgIds = [x for x in f.read().splitlines()]

    for imgId in imgIds:
        objCount = 0  # flag: does this image contain annotations we need?
        filename = imgId + '.jpg'
        img_path = './WiderPerson/Images/' + filename
        print('Img :%s' % img_path)
        img = cv2.imread(img_path)
        width = img.shape[1]   # image width
        height = img.shape[0]  # image height

        node_root = Element('annotation')
        node_folder = SubElement(node_root, 'folder')
        node_folder.text = 'JPEGImages'
        node_filename = SubElement(node_root, 'filename')
        node_filename.text = 'VOC2007/JPEGImages/%s' % filename
        node_size = SubElement(node_root, 'size')
        node_width = SubElement(node_size, 'width')
        node_width.text = '%s' % width
        node_height = SubElement(node_size, 'height')
        node_height.text = '%s' % height
        node_depth = SubElement(node_size, 'depth')
        node_depth.text = '3'

        label_path = img_path.replace('Images', 'Annotations') + '.txt'
        with open(label_path) as file:
            line = file.readline()
            count = int(line.split('\n')[0])  # number of instances in this image
            line = file.readline()
            while line:
                cls_id = line.split(' ')[0]
                xmin = int(line.split(' ')[1]) + 1
                ymin = int(line.split(' ')[2]) + 1
                xmax = int(line.split(' ')[3]) + 1
                ymax = int(line.split(' ')[4].split('\n')[0]) + 1
                line = file.readline()

                cls_name = classes[cls_id]
                obj_width = xmax - xmin
                obj_height = ymax - ymin
                difficult = 0
                if obj_height <= 6 or obj_width <= 6:
                    difficult = 1

                node_object = SubElement(node_root, 'object')
                node_name = SubElement(node_object, 'name')
                node_name.text = cls_name
                node_difficult = SubElement(node_object, 'difficult')
                node_difficult.text = '%s' % difficult
                node_bndbox = SubElement(node_object, 'bndbox')
                node_xmin = SubElement(node_bndbox, 'xmin')
                node_xmin.text = '%s' % xmin
                node_ymin = SubElement(node_bndbox, 'ymin')
                node_ymin.text = '%s' % ymin
                node_xmax = SubElement(node_bndbox, 'xmax')
                node_xmax.text = '%s' % xmax
                node_ymax = SubElement(node_bndbox, 'ymax')
                node_ymax.text = '%s' % ymax
                node_name = SubElement(node_object, 'pose')
                node_name.text = 'Unspecified'
                node_name = SubElement(node_object, 'truncated')
                node_name.text = '0'

        xml = tostring(node_root, pretty_print=True)  # serialize <annotation>
        dom = parseString(xml)
        xml_name = filename.replace('.jpg', '.xml')
        xml_path = VOCRoot + '/Annotations/' + xml_name
        with open(xml_path, 'wb') as f:
            f.write(xml)
        # copy the image into the VOC directory
        shutil.copy(img_path, '../VOC2007/JPEGImages/' + filename)
```
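It can be worth sanity-checking that the generated XML parses back correctly before moving on. A minimal self-contained sketch (the annotation below is built in memory purely for illustration, not read from disk):

```python
import xml.etree.ElementTree as ET

# A VOC-style annotation of the same shape the script above writes,
# with made-up values for illustration.
sample_xml = """<annotation>
  <filename>VOC2007/JPEGImages/000001.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>pedestrians</name>
    <difficult>0</difficult>
    <bndbox><xmin>100</xmin><ymin>50</ymin><xmax>200</xmax><ymax>300</ymax></bndbox>
  </object>
</annotation>"""

root = ET.fromstring(sample_xml)
obj = root.find('object')
box = obj.find('bndbox')
print(obj.find('name').text)                                    # pedestrians
print(int(box.find('xmax').text) - int(box.find('xmin').text))  # 100
```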

You can visualize the dataset with the following code:

```python
# -*- coding: utf-8 -*-
import os
import cv2

if __name__ == '__main__':
    path = './WiderPerson/train.txt'
    with open(path, 'r') as f:
        img_ids = [x for x in f.read().splitlines()]

    for img_id in img_ids:  # e.g. '000040'
        img_path = './WiderPerson/JPEGImages/' + img_id + '.jpg'
        print(img_path)
        img = cv2.imread(img_path)
        im_h = img.shape[0]
        im_w = img.shape[1]
        # label_path = img_path.replace('Images', 'Annotations') + '.txt'
        label_path = img_path.replace('JPEGImages', 'Annotations') + '.txt'
        print(label_path)
        with open(label_path) as file:
            line = file.readline()
            count = int(line.split('\n')[0])  # number of instances in this image
            line = file.readline()
            while line:
                cls = int(line.split(' ')[0])
                print(cls)
                # < class_label = 1: pedestrians >
                # < class_label = 2: riders >
                # < class_label = 3: partially-visible persons > partially occluded people
                # < class_label = 4: ignore regions > fake people, e.g. people in posters
                # < class_label = 5: crowd > one large box covering a dense crowd
                if cls == 1 or cls == 3:
                    xmin = float(line.split(' ')[1])
                    ymin = float(line.split(' ')[2])
                    xmax = float(line.split(' ')[3])
                    ymax = float(line.split(' ')[4].split('\n')[0])
                    img = cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 0), 2)
                line = file.readline()
        cv2.imshow('result', img)
        cv2.waitKey(0)
```

3. Dataset Preparation

Running the code above produces the following folder structure:

Next, split the data into training and validation sets using split_train_val.py:

```python
# coding:utf-8
import os
import random
import argparse

parser = argparse.ArgumentParser()
# Path to the xml files; adjust for your data (xml files usually live under Annotations)
parser.add_argument('--xml_path', default='./VOC2007/Annotations', type=str, help='input xml label path')
# Where the split lists go: ImageSets/Main under your dataset
parser.add_argument('--txt_path', default='./VOC2007/ImageSets/Main', type=str, help='output txt label path')
opt = parser.parse_args()

trainval_percent = 1
train_percent = 0.9
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
print(xmlfilepath)
total_xml = os.listdir(xmlfilepath)
if not os.path.exists(txtsavepath):
    os.makedirs(txtsavepath)

num = len(total_xml)
list_index = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list_index, tv)
train = random.sample(trainval, tr)

file_trainval = open(txtsavepath + '/trainval.txt', 'w')
file_test = open(txtsavepath + '/test.txt', 'w')
file_train = open(txtsavepath + '/train.txt', 'w')
file_val = open(txtsavepath + '/val.txt', 'w')

for i in list_index:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        file_trainval.write(name)
        if i in train:
            file_train.write(name)
        else:
            file_val.write(name)
    else:
        file_test.write(name)

file_trainval.close()
file_train.close()
file_val.close()
file_test.close()
```

The generated txt files look like this:

Next, run voc_labels.py; change classes to your own class names:

```python
# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
import os
from os import getcwd

sets = ['train', 'val', 'test']
# Change these to your own classes. They must match the <name> fields written
# into the xml files by the conversion script above.
classes = ["pedestrians", "riders", "partially", "ignore", "crowd"]
abs_path = os.getcwd()
print(abs_path)


def convert(size, box):
    # Convert a VOC box (xmin, xmax, ymin, ymax) to normalized YOLO (x, y, w, h)
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h


def convert_annotation(image_id):
    in_file = open('D:/V5/VOC2007/Annotations/%s.xml' % (image_id), encoding='UTF-8')
    out_file = open('D:/V5/VOC2007/labels/%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        b1, b2, b3, b4 = b
        # Clamp boxes that run past the image border
        if b2 > w:
            b2 = w
        if b4 > h:
            b4 = h
        b = (b1, b2, b3, b4)
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


wd = getcwd()
for image_set in sets:
    if not os.path.exists('D:/V5/VOC2007/labels/'):
        os.makedirs('D:/V5/VOC2007/labels/')
    image_ids = open('D:/V5/VOC2007/ImageSets/Main/%s.txt' % (image_set)).read().strip().split()
    list_file = open('D:/V5/VOC2007/%s.txt' % (image_set), 'w')
    for image_id in image_ids:
        list_file.write('D:/V5/VOC2007/JPEGImages/%s.jpg\n' % (image_id))
        convert_annotation(image_id)
    list_file.close()
```
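The convert function above maps a VOC box (xmin, xmax, ymin, ymax) in pixel coordinates to the normalized (x_center, y_center, width, height) format that YOLO labels use. A small standalone sketch of the same math, with made-up numbers:

```python
def voc_to_yolo(size, box):
    """size = (img_w, img_h); box = (xmin, xmax, ymin, ymax) in pixels.
    Returns normalized (x_center, y_center, w, h), mirroring the convert()
    function above, including its VOC-style '-1' center offset."""
    dw, dh = 1.0 / size[0], 1.0 / size[1]
    x = ((box[0] + box[1]) / 2.0 - 1) * dw
    y = ((box[2] + box[3]) / 2.0 - 1) * dh
    w = (box[1] - box[0]) * dw
    h = (box[3] - box[2]) * dh
    return x, y, w, h

# A 100x250-pixel box in a 640x480 image (illustrative values)
x, y, w, h = voc_to_yolo((640, 480), (100, 200, 50, 300))
print(x, y, w, h)  # 0.2328125 0.3625 0.15625 0.5208333...
```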

4. Training

In the data folder, find xView.yaml, make a copy, and rename it data.yaml; put your own classes in it.

Before modification:

After modification:
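The modified file should look roughly like the sketch below. The paths are hypothetical and depend on where you placed VOC2007; the class names must match the order used in voc_labels.py.

```yaml
# data.yaml -- hypothetical paths, adjust to your layout
train: D:/V5/VOC2007/train.txt
val: D:/V5/VOC2007/val.txt

nc: 5  # number of classes
names: ['pedestrians', 'riders', 'partially', 'ignore', 'crowd']
```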

Find yolov5s.yaml, copy it as yolov5s_s.yaml, and modify its nc parameter.
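Only the class count at the top of the model yaml needs to change; a sketch of the relevant lines (the rest of the file stays exactly as shipped with YOLOv5s):

```yaml
# yolov5s_s.yaml (copy of yolov5s.yaml) -- only nc changes
nc: 5  # number of classes, was 80
depth_multiple: 0.33
width_multiple: 0.50
```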

Modify the parameters in train.py: set weights to the downloaded pretrained weights, cfg to yolov5s_s.yaml, and data to data.yaml, then pick a reasonable epochs and batch_size for your GPU.
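Equivalently, the same settings can be passed on the command line. A sketch (the epoch count and batch size below are placeholders; tune them for your GPU):

```shell
python train.py --weights yolov5s.pt \
                --cfg models/yolov5s_s.yaml \
                --data data/data.yaml \
                --epochs 100 \
                --batch-size 16 \
                --img 640
```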

If running train.py raises errors, they are usually environment-specific; search the error message online and fix accordingly.

5. Results

After training finishes, run detect.py with adjusted parameters: weights should point to the trained weights under the runs directory, and changing source lets you run detection on images, a webcam, or video. The other parameters can usually stay at their defaults.
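A sketch of typical invocations (the exp folder name depends on your training run):

```shell
# detect on a folder of images
python detect.py --weights runs/train/exp/weights/best.pt --source data/images/
# detect from the default webcam
python detect.py --weights runs/train/exp/weights/best.pt --source 0
# detect on a video file
python detect.py --weights runs/train/exp/weights/best.pt --source test.mp4
```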

The results of a run look like the following.

Questions and discussion are welcome in the comments section.

 
