当前位置:   article > 正文

跑通基于YOLOv5的旋转框目标检测_yolov5旋转目标检测

yolov5旋转目标检测

首先介绍一下如何制作带有角度的YOLOv5数据集标注。

先去这个网站下标注软件代码GitHub - cgvict/roLabelImg: Label Rotated Rect On Images for training

随后为其创建独立的虚拟环境

conda create -n rolabel36 python=3.6

激活环境

conda activate rolabel36

安装对应的依赖包

  1. pip install pyqt5-tools
  2. pip install lxml
  3. pyrcc5 -o resources.py resources.qrc

在项目根目录下运行打开标注软件的程序

python roLabelImg.py

选中需要标注的文件夹即可对数据集进行标注,标注示例如下图所示。

然而软件标注的格式是xml格式,网络训练需要的标注格式是txt文件,因此需要把xml文件转换为txt格式。以下代码可以实现这个转换,只需要把XMLDIR和OUTDIR改为自己的路径即可。

  1. import os
  2. import os.path as osp
  3. import math
  4. BASEDIR = osp.dirname(osp.abspath(__file__)) #获取的当前执行脚本的完整路径
  5. XMLDIR = osp.join(BASEDIR, r'C:\Users\BJUT\Desktop\label')
  6. OUTDIR = osp.join(BASEDIR, r'C:\Users\BJUT\Desktop\labels')
  7. if not osp.exists(OUTDIR):
  8. os.makedirs(OUTDIR)
  9. # pi=3.1415926
  10. xmlnames = [i for i in os.listdir(XMLDIR) if i.endswith('.xml')]
  11. # print(names)
  12. pi = 3.1415926
  13. # 转换成四点坐标
  14. def convert(cx, cy, w, h, a):
  15. if a >= pi:
  16. a -= pi
  17. # 计算斜径半长
  18. l = math.sqrt(w ** 2 + h ** 2) / 2
  19. # 计算初始矩形角度
  20. a0 = math.atan(h / w)
  21. # 旋转,计算旋转角
  22. # 右上角点 ↗
  23. a1 = a0 + a
  24. x1 = cx + l * math.cos(a1)
  25. y1 = cy + l * math.sin(a1)
  26. # 右下角点 ↘
  27. a2 = -a0 + a
  28. x2 = cx + l * math.cos(a2)
  29. y2 = cy + l * math.sin(a2)
  30. # 左下角点 ↙
  31. a3 = a1 + pi
  32. x3 = cx + l * math.cos(a3)
  33. y3 = cy + l * math.sin(a3)
  34. # 左上角点 ↖
  35. a4 = a2 + pi
  36. x4 = cx + l * math.cos(a4)
  37. y4 = cy + l * math.sin(a4)
  38. return [x1, y1, x2, y2, x3, y3, x4, y4]
  39. # 点关于直线对称
  40. for xmlname in xmlnames:
  41. cx = []
  42. cy = []
  43. w = []
  44. h = []
  45. angle = []
  46. name = []
  47. txtname = xmlname.split('.')[-2] + '.txt' #从右往左数,右边第一个是-1
  48. with open(osp.join(OUTDIR, txtname), 'w') as fp:
  49. fp.write('') #向文件中写入东西
  50. with open(osp.join(XMLDIR, xmlname), 'rb') as fp:
  51. lines = fp.readlines() #依次读取每行
  52. for line in lines:
  53. # print(line, end='')
  54. if line.strip().startswith(b'<width>'): #strip()为去掉首尾空格
  55. img_width = eval(line.strip().strip(b'<width>').strip(b'</width>')) #eval函数的作用是获取返回值
  56. # print(img_width)
  57. if line.strip().startswith(b'<height>'):
  58. img_height = eval(line.strip().strip(b'<height>').strip(b'</height>'))
  59. # print(img_height)
  60. if line.strip().startswith(b'<depth>'):
  61. img_depth = eval(line.strip().strip(b'<depth>').strip(b'</depth>'))
  62. # print(img_depth)
  63. if line.strip().startswith(b'<cx>'):
  64. cx.append(eval(line.strip().strip(b'<cx>').strip(b'</cx>'))) #空列表里加上读取的值
  65. # print(cx)
  66. if line.strip().startswith(b'<cy>'):
  67. cy.append(eval(line.strip().strip(b'<cy>').strip(b'</cy>')))
  68. # print(cy)
  69. if line.strip().startswith(b'<w>'):
  70. w.append(eval(line.strip().strip(b'<w>').strip(b'</w>')))
  71. # print(w)
  72. if line.strip().startswith(b'<h>'):
  73. h.append(eval(line.strip().strip(b'<h>').strip(b'</h>')))
  74. # print(h)
  75. if line.strip().startswith(b'<angle>'):
  76. angle.append(eval(line.strip().strip(b'<angle>').strip(b'</angle>')))
  77. if line.strip().startswith(b'<name>'):
  78. name.append(line.strip().strip(b'<name>').strip(b'</name>').decode('utf-8'))
  79. # print(angle)
  80. #with open(osp.join(OUTDIR, txtname), 'a') as fp:
  81. #fp.write("imagesource:GoogleEarth")
  82. #fp.write('\n')
  83. #fp.write("gsd:0.146343590398")
  84. #fp.write('\n')
  85. for i in range(len(cx)):
  86. cls0 = 0.0
  87. cx_i = cx[i]
  88. cy_i = cy[i]
  89. w_i = w[i]
  90. h_i = h[i]
  91. a_i = angle[i]
  92. object = name[i]
  93. x0, y0, x1, y1, x2, y2, x3, y3 = convert(cx_i, cy_i, w_i, h_i, a_i)
  94. put_str = ' '.join(
  95. [str(x0), str(y0), str(x1), str(y1), str(x2), str(y2), str(x3), str(y3), str(object), str(0)])
  96. #put_str = ' '.join([str(cx_i), str(cy_i), str(w_i), str(h_i), str(a_i)])
  97. with open(osp.join(OUTDIR, txtname), 'a') as fp:
  98. fp.write(put_str)
  99. fp.write('\n')
  100. print(xmlname, 'to', txtname, 'done.')

接着,将数据集按如下方式进行分类(数据集的名字可以换,但是对应的图像和标注的名字必须是images和labelTxt)。

接下来我们就可以训练模型了。

训练模型的第一步当然是下载代码。

https://github.com/hukaixuan19970627/yolov5_obb

为其创建独特的虚拟环境。

conda create -n YOLOv5 python=3.8

进入所创建的虚拟环境。

conda activate YOLOv5

在安装依赖包之前,需要安装与pytorch相对应的cuda版本,此项目使用的pytorch为1.7.0。去pytorch官网查看对应的CUDA版本(pytorch官网地址:PyTorch)。如下图所示,项目需要安装CUDA10.2。

通过命令行输入nvidia-smi查看自己的显卡驱动版本以及支持的最大CUDA版本,下图第一行就显示了这些信息,可以看到,最大支持CUDA12.2,因此可以放心地安装cuda10.2。

进入下面这个网页下载CUDA10.2

CUDA Toolkit 10.2 Download | NVIDIA Developer

可以通过命名行下载,也可以自行下载。下载后运行固定命令安装,安装命令如下:

sudo sh cuda_10.2.89_440.33.01_linux.run

安装好cuda以后,再安装pytorch,进入pytorch官网(Previous PyTorch Versions | PyTorch),找到对应的pytorch版本,运行相应的指令。例如我们需要1.7.0版本的pytorch,需要运行在创建的虚拟环境下运行这条指令:

conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.2 -c pytorch

等待安装完成即可,如果安装过程很慢,考虑换源,还有一点,一定要关掉VPN。

安装完pytorch和CUDA以后呢,我们可以再检查一下pytorch和cuda的版本是否匹配!(要在刚刚创建的YOLOv5虚拟环境内)首先在终端输入python,接着一一输入以下三条命令,如果返回的是True说明我们的pytorch和cuda的版本匹配,如果返回的是False,说明两者不匹配。

  1. import torch
  2. print(torch.__version__)
  3. print(torch.cuda.is_available())

接下来就可以安装依赖包了(必须在YOLOv5环境内)

pip install -r requirements.txt

等待安装完成即可。

接下来按着yolov5_obb作者的步骤走。

  1. cd utils/nms_rotated
  2. python setup.py develop
  1. cd yolov5_obb/DOTA_devkit
  2. sudo apt-get install swig
  3. swig -c++ -python polyiou.i
  4. python setup.py build_ext --inplace

中途我也遇到过问题,但一一百度,都能解决。

ok,下一步是训练模型!这部分和官方的YOLOv5步骤差不多。

1、在/data/scripts目录下创建自己的data,复制粘贴coco.yaml并重新命名为my_data.yaml。里面数据集的地址改成自己数据集的地址。我的数据集地址是这样:

2、编辑train.py :workers最好设置为0 ,weights设为空(如果嫌浪费时间,可以加入预训练模型),cfg是关于模型参数的一些设置(在models/hub下可以看到一些已经给的参数设置,复制yolov5s.yaml路径到cfg中),data为指定数据集(也就是刚刚创建的my_data.yaml),对应的default填其路径,在这里应为data/my_data.yaml。epochs代表训练的轮数,自己设置即可。

3、运行train.py,等待训练完成即可。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/850447
推荐阅读
相关标签
  

闽ICP备14008679号