当前位置:   article > 正文

YOLO_v5将数据集按比例分为训练集和测试集

YOLO_v5将数据集按比例分为训练集和测试集

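YOLO_v5将数据集按比例分为训练集和测试集(脚本中测试集以 val 命名)。运行前请将图片放入 ./VOCdevkit/VOC2007/JPEGImages,将对应的 VOC 格式 XML 标注放入 ./VOCdevkit/VOC2007/Annotations,并在 VOCdevkit 所在的目录下运行脚本: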

  1. import xml.etree.ElementTree as ET
  2. import pickle
  3. import os
  4. from os import listdir, getcwd
  5. from os.path import join
  6. import random
  7. from shutil import copyfile
  # Annotation class names as they appear in the VOC <name> tags. An object's
  # position in this list is the class id written to its YOLO label line;
  # objects whose name is not listed here are skipped.
  classes = ["car", "van", "bus", "truck", "other"]
  # Percentage (out of 100) of annotated images placed in the training split;
  # the remaining images are placed in the val split.
  TRAIN_RATIO = 80
  def clear_hidden_files(path):
      """Recursively delete files whose name starts with '._' under *path*.

      These are typically macOS AppleDouble metadata files. Regular files
      with the '._' prefix are removed, and sub-directories are descended
      into. Any other entry (e.g. a dangling symlink or a socket) is left
      untouched. The previous version passed such entries to os.listdir,
      which raised.

      Args:
          path: directory to clean (relative or absolute).
      """
      base = os.path.abspath(path)
      for name in os.listdir(base):
          entry = os.path.join(base, name)
          if os.path.isfile(entry):
              if name.startswith("._"):
                  os.remove(entry)
          elif os.path.isdir(entry):
              clear_hidden_files(entry)
  def convert(size, box):
      """Map a pixel bounding box to YOLO's normalized centre/size format.

      Args:
          size: ``(image_width, image_height)`` in pixels.
          box: ``(xmin, xmax, ymin, ymax)`` in pixels. Note the x, x, y, y
              ordering.

      Returns:
          ``(x_center, y_center, width, height)``. x values are multiplied by
          1/image_width and y values by 1/image_height. Results are not
          clipped to [0, 1].

      Raises:
          ZeroDivisionError: if either image dimension is 0.
      """
      inv_w = 1. / size[0]
      inv_h = 1. / size[1]
      xmin, xmax = box[0], box[1]
      ymin, ymax = box[2], box[3]
      centre_x = (xmin + xmax) / 2.0 * inv_w
      centre_y = (ymin + ymax) / 2.0 * inv_h
      box_w = (xmax - xmin) * inv_w
      box_h = (ymax - ymin) * inv_h
      return (centre_x, centre_y, box_w, box_h)
  def convert_annotation(image_id,
                         annotation_dir='VOCdevkit/VOC2007/Annotations',
                         label_dir='VOCdevkit/VOC2007/YOLOLabels'):
      """Convert one VOC XML annotation into a YOLO label file.

      Reads ``<annotation_dir>/<image_id>.xml`` and writes
      ``<label_dir>/<image_id>.txt``. The label file has one line per kept
      object, in the form ``class_id x_center y_center width height``; the
      box values come from ``convert``. An object is skipped when its
      <name> is not in ``classes`` or when it is marked
      ``<difficult>1</difficult>``.

      Args:
          image_id: file stem shared by the image, its XML and its label.
          annotation_dir: directory holding the VOC XML files. The default
              is relative to the current working directory, as before.
          label_dir: directory the YOLO ``.txt`` label is written to. The
              default is also relative to the current working directory.
      """
      xml_path = os.path.join(annotation_dir, '%s.xml' % image_id)
      txt_path = os.path.join(label_dir, '%s.txt' % image_id)
      # Parse from the path (read as bytes) so the XML declaration, not the
      # OS locale, decides the encoding, e.g. UTF-8 class names on Windows.
      root = ET.parse(xml_path).getroot()
      size = root.find('size')
      w = int(size.find('width').text)
      h = int(size.find('height').text)
      # Parse before opening the output, so an unparsable XML file does not
      # leave an empty label behind. `with` closes the file even on errors.
      with open(txt_path, 'w') as out_file:
          for obj in root.iter('object'):
              cls = obj.find('name').text
              if cls not in classes:
                  continue
              # Some labelling tools omit <difficult> or leave it empty;
              # treat that as "not difficult" instead of crashing.
              difficult = obj.find('difficult')
              if difficult is not None and difficult.text and int(difficult.text) == 1:
                  continue
              cls_id = classes.index(cls)
              xmlbox = obj.find('bndbox')
              # convert() expects (xmin, xmax, ymin, ymax).
              b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                   float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
              bb = convert((w, h), b)
              out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
  def _prepare_dir(path):
      """Create *path* (and any parents) if missing, then remove '._*' files in it."""
      os.makedirs(path, exist_ok=True)
      clear_hidden_files(path)


  # All inputs (VOC layout) and outputs (YOLOv5 layout) live under
  # ./VOCdevkit relative to the directory the script is run from.
  wd = os.getcwd()
  data_base_dir = os.path.join(wd, "VOCdevkit/")
  work_space_dir = os.path.join(data_base_dir, "VOC2007/")
  # The two parent directories are only created, never cleaned.
  os.makedirs(work_space_dir, exist_ok=True)

  annotation_dir = os.path.join(work_space_dir, "Annotations/")
  image_dir = os.path.join(work_space_dir, "JPEGImages/")
  yolo_labels_dir = os.path.join(work_space_dir, "YOLOLabels/")
  yolov5_images_dir = os.path.join(data_base_dir, "images/")
  yolov5_labels_dir = os.path.join(data_base_dir, "labels/")
  yolov5_images_train_dir = os.path.join(yolov5_images_dir, "train/")
  yolov5_images_test_dir = os.path.join(yolov5_images_dir, "val/")
  yolov5_labels_train_dir = os.path.join(yolov5_labels_dir, "train/")
  yolov5_labels_test_dir = os.path.join(yolov5_labels_dir, "val/")
  for _dir in (annotation_dir, image_dir, yolo_labels_dir,
               yolov5_images_dir, yolov5_labels_dir,
               yolov5_images_train_dir, yolov5_images_test_dir,
               yolov5_labels_train_dir, yolov5_labels_test_dir):
      _prepare_dir(_dir)

  # NOTE: files copied by earlier runs are NOT removed from images/{train,val}
  # or labels/{train,val}. Because the split is random, re-running can leave
  # the same image in both splits, so empty those folders before a re-run.
  with open(os.path.join(wd, "yolov5_train.txt"), 'w') as train_file, \
          open(os.path.join(wd, "yolov5_val.txt"), 'w') as test_file:
      for i, image_name in enumerate(os.listdir(image_dir)):
          image_path = os.path.join(image_dir, image_name)
          # Skip sub-directories and similar entries. Previously the loop
          # fell through and reused names left over from the prior entry.
          if not os.path.isfile(image_path):
              continue
          name_without_ext = os.path.splitext(image_name)[0]
          annotation_path = os.path.join(annotation_dir, name_without_ext + '.xml')
          label_name = name_without_ext + '.txt'
          label_path = os.path.join(yolo_labels_dir, label_name)

          prob = random.randint(1, 100)
          print("Probability: %d" % prob, i)
          if not os.path.exists(annotation_path):
              continue  # no annotation: image goes into neither split
          # randint(1, 100) is uniform over 100 values, so `<=` sends exactly
          # TRAIN_RATIO percent to train (`<` gave TRAIN_RATIO - 1 percent).
          if prob <= TRAIN_RATIO:
              list_file = train_file
              images_out, labels_out = yolov5_images_train_dir, yolov5_labels_train_dir
          else:
              list_file = test_file
              images_out, labels_out = yolov5_images_test_dir, yolov5_labels_test_dir

          convert_annotation(name_without_ext)
          copied_image = os.path.join(images_out, image_name)
          copyfile(image_path, copied_image)
          copyfile(label_path, os.path.join(labels_out, label_name))
          # List the copied image: YOLOv5 finds each label by replacing
          # /images/ with /labels/ in the image path, which fails for the
          # original JPEGImages path.
          list_file.write(copied_image + '\n')
  130. test_file.close()

需要将数据标注的类别改为自己数据集的类别(类别在列表中的下标即为 YOLO 标注中的类别编号):

classes=["car","van","bus","truck","other"]

需要按需修改训练集占数据集的百分比(其余图片划入验证集 val):

TRAIN_RATIO=80

生成的图像文件夹(其下包含 train 和 val 两个子文件夹):

./VOCdevkit/images

生成的标注文件夹(同样包含 train 和 val 子文件夹,每张图片对应一个 YOLO 格式的 .txt 标注文件):

./VOCdevkit/labels
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/77627
推荐阅读
相关标签
  

闽ICP备14008679号