当前位置:   article > 正文

yolov5训练记录(使用矩池云服务器)_矩池云训练yolo

矩池云训练yolo
1、数据集准备

我用的是开源的自动驾驶数据集BDD100K,数据集中包含train,val和test的image和label,其中label是json格式的。

yolo使用的label是txt格式的,所以需要先将json格式转换成txt格式。有多种方式可以转换:第一次训练的时候我用的是先将json转换成xml,再将xml转换成txt,这种方式我没找到怎么进行数据清洗的方法;第二次训练的时候我用的是直接将json转换成txt,这种方式可以在里面添加一些函数用于过滤掉黑夜和过于小的目标的图片,进行数据清洗。

第一次训练的时候分了10个类:

第二次训练的时候分了3个类:

这是json转xml的代码:

  1. # jsonToXml
  2. import os
  3. import json
  4. import sys
  5. from xml.etree import ElementTree
  6. from xml.etree.ElementTree import Element, SubElement
  7. from lxml import etree
  8. from xml.dom.minidom import parseString
  9. # 数据集个类别
  10. categorys = ['car', 'bus', 'person', 'bike', 'truck', 'motor', 'train', 'rider', 'traffic sign', 'traffic light']
  11. def parseJson(jsonFile):
  12. '''
  13. params:
  14. jsonFile -- BDD00K数据集的一个json标签文件
  15. return:
  16. 返回一个列表的列表,存储了一个json文件里面的方框坐标及其所属的类,
  17. 形如:[[325, 342, 376, 384, 'car'], [245, 333, 336, 389, 'car']]
  18. '''
  19. objs = []
  20. obj = []
  21. f = open(jsonFile)
  22. info = json.load(f)
  23. objects = info['frames'][0]['objects']
  24. for i in objects:
  25. if (i['category'] in categorys):
  26. obj.append(int(i['box2d']['x1']))
  27. obj.append(int(i['box2d']['y1']))
  28. obj.append(int(i['box2d']['x2']))
  29. obj.append(int(i['box2d']['y2']))
  30. obj.append(i['category'])
  31. objs.append(obj)
  32. obj = []
  33. # print("objs",objs)
  34. return objs
  35. class PascalVocWriter:
  36. def __init__(self, foldername, filename, imgSize, databaseSrc='Unknown', localImgPath=None):
  37. '''
  38. params:
  39. foldername -- 要存储的xml文件的父目录
  40. filename -- xml文件的文件名
  41. imgSize -- 图片的尺寸
  42. databaseSrc -- 数据库名,这里不需要,默认为Unknown
  43. localImaPath -- xml文件里面的<path></path>标签的内容
  44. '''
  45. self.foldername = foldername
  46. self.filename = filename
  47. self.databaseSrc = databaseSrc
  48. self.imgSize = imgSize
  49. self.boxlist = []
  50. self.localImgPath = localImgPath
  51. def prettify(self, elem):
  52. """
  53. params:
  54. elem -- xml的根标签,以<annotation>开始
  55. return:
  56. 返回一个美观输出的xml(用到minidom),本质是一个str
  57. """
  58. xml = ElementTree.tostring(elem)
  59. dom = parseString(xml)
  60. # print(dom.toprettyxml(' '))
  61. prettifyResult = dom.toprettyxml(' ')
  62. return prettifyResult
  63. def genXML(self):
  64. """
  65. return:
  66. 生成一个VOC格式的xml,返回一个xml的根标签,以<annotation>开始
  67. """
  68. # Check conditions
  69. if self.filename is None or \
  70. self.foldername is None or \
  71. self.imgSize is None or \
  72. len(self.boxlist) <= 0:
  73. return None
  74. top = Element('annotation') # 创建一个根标签<annotation>
  75. folder = SubElement(top, 'folder') # 在根标签<annotation>下创建一个子标签<folder>
  76. folder.text = self.foldername # 用self.foldername的数据填充子标签<folder>
  77. filename = SubElement(top, 'filename') # 在根标签<annotation>下创建一个子标签<filename>
  78. filename.text = self.filename # 用self.filename的数据填充子标签<filename>
  79. localImgPath = SubElement(top, 'path') # 在根标签<annotation>下创建一个子标签<path>
  80. localImgPath.text = self.localImgPath # 用self.localImgPath的数据填充子标签<path>
  81. source = SubElement(top, 'source') # 在根标签<annotation>下创建一个子标签<source>
  82. database = SubElement(source, 'database') # 在根标签<source>下创建一个子标签<database>
  83. database.text = self.databaseSrc # 用self.databaseSrc的数据填充子标签<database>
  84. size_part = SubElement(top, 'size') # 在根标签<annotation>下创建一个子标签<size>
  85. width = SubElement(size_part, 'width') # 在根标签<size>下创建一个子标签<width>
  86. height = SubElement(size_part, 'height') # 在根标签<size>下创建一个子标签<height>
  87. depth = SubElement(size_part, 'depth') # 在根标签<size>下创建一个子标签<depth>
  88. width.text = str(self.imgSize[1]) # 用self.imgSize[1]的数据填充子标签<width>
  89. height.text = str(self.imgSize[0]) # 用self.imgSize[0]的数据填充子标签<height>
  90. if len(self.imgSize) == 3: # 如果图片深度为3,则用self.imgSize[2]的数据填充子标签<height>,否则用1填充
  91. depth.text = str(self.imgSize[2])
  92. else:
  93. depth.text = '1'
  94. segmented = SubElement(top, 'segmented')
  95. segmented.text = '0'
  96. return top
  97. def addBndBox(self, xmin, ymin, xmax, ymax, name):
  98. '''
  99. 将检测对象框坐标及其对象类别作为一个字典加入到self.boxlist中
  100. params:
  101. xmin -- 检测框的左上角的x坐标
  102. ymin -- 检测框的左上角的y坐标
  103. xmax -- 检测框的右下角的x坐标
  104. ymax -- 检测框的右下角的y坐标
  105. name -- 检测框内的对象类别名
  106. '''
  107. bndbox = {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}
  108. bndbox['name'] = name
  109. self.boxlist.append(bndbox)
  110. def appendObjects(self, top):
  111. '''
  112. 在xml文件中加入检测框的坐标及其对象类别名
  113. params:
  114. top -- xml的根标签,以<annotation>开始
  115. '''
  116. for each_object in self.boxlist:
  117. object_item = SubElement(top, 'object')
  118. name = SubElement(object_item, 'name')
  119. name.text = str(each_object['name'])
  120. pose = SubElement(object_item, 'pose')
  121. pose.text = "Unspecified"
  122. truncated = SubElement(object_item, 'truncated')
  123. truncated.text = "0"
  124. difficult = SubElement(object_item, 'Difficult')
  125. difficult.text = "0"
  126. bndbox = SubElement(object_item, 'bndbox')
  127. xmin = SubElement(bndbox, 'xmin')
  128. xmin.text = str(each_object['xmin'])
  129. ymin = SubElement(bndbox, 'ymin')
  130. ymin.text = str(each_object['ymin'])
  131. xmax = SubElement(bndbox, 'xmax')
  132. xmax.text = str(each_object['xmax'])
  133. ymax = SubElement(bndbox, 'ymax')
  134. ymax.text = str(each_object['ymax'])
  135. def save(self, targetFile=None):
  136. '''
  137. 以美观输出的xml格式来保存xml文件
  138. params:
  139. targetFile -- 存储的xml文件名,不包括.xml部分
  140. '''
  141. root = self.genXML()
  142. self.appendObjects(root)
  143. out_file = None
  144. subdir = self.foldername.split('/')[-1]
  145. if not os.path.isdir(subdir):
  146. os.mkdir(subdir)
  147. if targetFile is None:
  148. with open(self.foldername + '/' + self.filename + '.xml', 'w') as out_file:
  149. prettifyResult = self.prettify(root)
  150. out_file.write(prettifyResult)
  151. out_file.close()
  152. else:
  153. with open(targetFile, 'w') as out_file:
  154. prettifyResult = self.prettify(root)
  155. out_file.write(prettifyResult)
  156. out_file.close()
  157. class PascalVocReader:
  158. def __init__(self, filepath):
  159. # shapes type:
  160. # [labbel, [(x1,y1), (x2,y2), (x3,y3), (x4,y4)], color, color]
  161. self.shapes = []
  162. self.filepath = filepath
  163. self.parseXML()
  164. def getShapes(self):
  165. return self.shapes
  166. def addShape(self, label, bndbox):
  167. xmin = int(bndbox.find('xmin').text)
  168. ymin = int(bndbox.find('ymin').text)
  169. xmax = int(bndbox.find('xmax').text)
  170. ymax = int(bndbox.find('ymax').text)
  171. points = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
  172. self.shapes.append((label, points, None, None))
  173. def parseXML(self):
  174. assert self.filepath.endswith('.xml'), "Unsupport file format"
  175. parser = etree.XMLParser(encoding='utf-8')
  176. xmltree = ElementTree.parse(self.filepath, parser=parser).getroot()
  177. filename = xmltree.find('filename').text
  178. for object_iter in xmltree.findall('object'):
  179. bndbox = object_iter.find("bndbox")
  180. label = object_iter.find('name').text
  181. self.addShape(label, bndbox)
  182. return True
  183. def main(srcDir, dstDir):
  184. i = 1
  185. # os.walk()
  186. # dirName是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)
  187. # root所指的是当前正在遍历的这个文件夹的本身的地址
  188. # dirs是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
  189. # files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
  190. for dirpath, dirnames, filenames in os.walk(srcDir):
  191. # print(dirpath, dirnames, filenames)
  192. for filepath in filenames:
  193. fileName = os.path.join(dirpath, filepath)
  194. print(fileName)
  195. print("processing: {}, {}".format(i, fileName))
  196. i = i + 1
  197. xmlFileName = filepath[:-5] # remove ".json" 5 character
  198. # 解析该json文件,返回一个列表的列表,存储了一个json文件里面的所有方框坐标及其所属的类
  199. objs = parseJson(str(fileName))
  200. # 如果存在检测对象,创建一个与该json文件具有相同名的VOC格式的xml文件
  201. if len(objs):
  202. tmp = PascalVocWriter(dstDir, xmlFileName, (720, 1280, 3), fileName)
  203. for obj in objs:
  204. tmp.addBndBox(obj[0], obj[1], obj[2], obj[3], obj[4])
  205. tmp.save()
  206. else:
  207. print(fileName)
  208. if __name__ == '__main__':
  209. # 这里写自己的json标签路径
  210. srcDir = r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\val' # 原json存放路径
  211. dstDir = r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\val_xml' # 转换后xml存放路径
  212. # srcDir = r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\train'
  213. # dstDir = r"D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\train_xml"
  214. main(srcDir, dstDir)

 这是xml转txt的代码:

  1. # xmlToTxt
  2. import os
  3. import glob
  4. import xml.etree.ElementTree as ET
  5. xml_file=r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\val_xml'
  6. l=['car', 'bus', 'person', 'bike', 'truck', 'motor', 'train', 'rider', 'traffic sign', 'traffic light']
  7. def convert(box,dw,dh):
  8. x=(box[0]+box[2])/2.0
  9. y=(box[1]+box[3])/2.0
  10. w=box[2]-box[0]
  11. h=box[3]-box[1]
  12. x=x/dw
  13. y=y/dh
  14. w=w/dw
  15. h=h/dh
  16. return x,y,w,h
  17. def f(name_id):
  18. xml_o=open(r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\val_xml\%s.xml'%name_id)
  19. txt_o=open(r'D:\postgraduate\competition\dataset\bdd100k_labels\bdd100k\labels\100k\val_txt\%s.txt'%name_id,'w')
  20. pares=ET.parse(xml_o)
  21. root=pares.getroot()
  22. objects=root.findall('object')
  23. size=root.find('size')
  24. dw=int(size.find('width').text)
  25. dh=int(size.find('height').text)
  26. for obj in objects :
  27. c=l.index(obj.find('name').text)
  28. bnd=obj.find('bndbox')
  29. b=(float(bnd.find('xmin').text),float(bnd.find('ymin').text),
  30. float(bnd.find('xmax').text),float(bnd.find('ymax').text))
  31. x,y,w,h=convert(b,dw,dh)
  32. write_t="{} {:.5f} {:.5f} {:.5f} {:.5f}\n".format(c,x,y,w,h)
  33. txt_o.write(write_t)
  34. xml_o.close()
  35. txt_o.close()
  36. name=glob.glob(os.path.join(xml_file,"*.xml"))
  37. for i in name :
  38. name_id=os.path.basename(i)[:-4]
  39. f(name_id)

这是json转txt的代码:

  1. # jsonToTxt
  2. import re
  3. import os
  4. import json
  5. def search_file(data_dir, pattern=r'\.jpg$'):
  6. # 返回一个目录的绝对路径
  7. root_dir = os.path.abspath(data_dir)
  8. # print(root_dir)
  9. for root, dirs, files in os.walk(root_dir):
  10. for f in files:
  11. # print(f)
  12. # print("****************")
  13. if re.search(pattern, f, re.I):
  14. abs_path = os.path.join(root, f)
  15. print(abs_path)
  16. # print('new file %s' % absfn)
  17. yield abs_path
  18. class Bdd2yolov5:
  19. def __init__(self):
  20. self.bdd100k_width = 1280
  21. self.bdd100k_height = 720
  22. self.select_categorys = ["person", "car", "bus", "truck", 'traffic light']
  23. self.cat2id = {
  24. "person": 0,
  25. "car": 1,
  26. "bus": 1,
  27. "truck": 1,
  28. 'traffic light': 2
  29. }
  30. @property
  31. def all_categorys(self):
  32. return ["person", "rider", "car", "bus", "truck", "bike",
  33. "motor", "traffic light", "traffic sign", "train"]
  34. def _filter_by_attr(self, attr=None):
  35. if attr is None:
  36. return False
  37. # 过滤掉晚上的图片
  38. if attr['timeofday'] == 'night':
  39. return True
  40. return False
  41. def _filter_by_box(self, w, h):
  42. # size ratio
  43. # 过滤掉过于小的小目标
  44. threshold = 0.001
  45. if float(w * h) / (self.bdd100k_width * self.bdd100k_height) < threshold:
  46. return True
  47. return False
  48. def bdd2yolov5(self, path):
  49. lines = ""
  50. with open(path) as fp:
  51. j = json.load(fp)
  52. if self._filter_by_attr(j['attributes']):
  53. return
  54. for fr in j["frames"]:
  55. dw = 1.0 / self.bdd100k_width
  56. dh = 1.0 / self.bdd100k_height
  57. for obj in fr["objects"]:
  58. if obj["category"] in self.select_categorys:
  59. idx = self.cat2id[obj["category"]]
  60. cx = (obj["box2d"]["x1"] + obj["box2d"]["x2"]) / 2.0
  61. cy = (obj["box2d"]["y1"] + obj["box2d"]["y2"]) / 2.0
  62. w = obj["box2d"]["x2"] - obj["box2d"]["x1"]
  63. h = obj["box2d"]["y2"] - obj["box2d"]["y1"]
  64. if w <= 0 or h <= 0:
  65. continue
  66. if self._filter_by_box(w, h):
  67. continue
  68. # 根据图片尺寸进行归一化
  69. cx, cy, w, h = cx * dw, cy * dh, w * dw, h * dh
  70. line = f"{idx} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n"
  71. lines += line
  72. if len(lines) != 0:
  73. # 转换后的以*.txt结尾的标注文件我就直接和*.json放一具目录了
  74. # yolov5中用到的时候稍微挪一下就行了
  75. yolo_txt = path.replace(".json", ".txt")
  76. with open(yolo_txt, 'w') as fp2:
  77. fp2.writelines(lines)
  78. # print("%s has been dealt!" % path)
  79. if __name__ == "__main__":
  80. bdd_label_dir = "./val"
  81. cvt = Bdd2yolov5()
  82. for path in search_file(bdd_label_dir, r"\.json$"):
  83. cvt.bdd2yolov5(path)

使用第二种直接将json转成txt的方法时,会出现不知道哪些image被过滤掉的情况,因为在转换的时候没有用到image,所以需要用到一个找到文件名称相同但后缀不同的代码:

  1. # 找同名但后缀不同的文件并输出
  2. import os
  3. import glob
  4. from PIL import Image
  5. # 指定找到文件后,另存为的文件夹路径
  6. outDir = os.path.abspath(r'D:\postgraduate\competition\dataset\data\val\img')
  7. # 指定第一个文件夹的位置
  8. imageDir1 = os.path.abspath(r'D:\BaiduNetdiskDownload\BDD100K\images\bdd100k\images\100k\val')
  9. # 定义要处理的第一个文件夹变量
  10. image1 = [] # image1指文件夹里的文件,包括文件后缀格式;
  11. imgname1 = [] # imgname1指里面的文件名称,不包括文件后缀格式
  12. # 通过glob.glob来获取第一个文件夹下,所有'.jpg'文件
  13. imageList1 = glob.glob(os.path.join(imageDir1, '*.jpg'))
  14. # 遍历所有文件,获取文件名称(包括后缀)
  15. for item in imageList1:
  16. image1.append(os.path.basename(item))
  17. # 遍历文件名称,去除后缀,只保留名称
  18. for item in image1:
  19. (temp1, temp2) = os.path.splitext(item)
  20. imgname1.append(temp1)
  21. # 对于第二个文件夹路径,做同样的操作
  22. imageDir2 = os.path.abspath(r'D:\postgraduate\competition\dataset\data\val\label')
  23. image2 = []
  24. imgname2 = []
  25. imageList2 = glob.glob(os.path.join(imageDir2, '*.txt'))
  26. for item in imageList2:
  27. image2.append(os.path.basename(item))
  28. for item in image2:
  29. (temp1, temp2) = os.path.splitext(item)
  30. imgname2.append(temp1)
  31. # 通过遍历,获取第一个文件夹下,文件名称(不包括后缀)与第二个文件夹相同的文件,并另存在outDir文件夹下。文件名称与第一个文件夹里的文件相同,后缀格式亦保持不变。
  32. for item1 in imgname1:
  33. for item2 in imgname2:
  34. if item1 == item2:
  35. dir = imageList1[imgname1.index(item1)]
  36. img = Image.open(dir)
  37. name = os.path.basename(dir)
  38. img.save(os.path.join(outDir, name))

这样datasets就准备好了,目录结构是这样的:

 2、yolov5训练参数调整

修改train.py文件:

主要是修改parse_opt函数里面的参数:

① 修改weights:我用的是yolov5l.pt。

yolov5共有4中预训练权重,区别如下:总的来说就是模型越小,检测速度越快,检测准确度越低。

 ② 修改cfg:这里需要对原有的yolov5l.yaml文件做一下修改,只需要修改class的数量,我这里一共检测3个类别,所以就把nc改成3.

 ③ 修改data:这里需要在原有的data文件夹下新建一个yaml文件,里面要写的内容如下:

  1. # YOLOv5
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/337936
    推荐阅读
    相关标签