当前位置:   article > 正文

MMRotate从零开始训练自己的数据集_mmrotate训练自己数据集

mmrotate训练自己数据集

1.虚拟环境安装

step1:下载并安装Anaconda,Anaconda的国内镜像:

Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror

这里建议选择较新的 Anaconda 版本

上面的是32位系统,下面的是64位系统(一般选第二个就可以)

step2:更新国内源

下面的指令都在 Anaconda Prompt 中操作

如果不更新国内源可能会导致安装某些包的时候出错

pypi | 镜像站使用帮助 | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror

step3:Anaconda下创建虚拟环境

  1. conda create --name mmrotate python=3.8
  2. conda activate mmrotate

这里mmrotate是虚拟环境的名称,可以修改为你想要的,这里指定的是 python3.8  版本。

step4:下载torch和torchvision(本地安装稳定些)

https://download.pytorch.org/whl/torch_stable.html

这里我选择的版本是torch==1.8.1  torchvision==0.9.1(这里要注意python版本的对应,比如这里选择cp=38。还有我的环境是cuda10.1

(还有一点要注意的是30系列以上的显卡要下载cuda11以上的版本,否则会出错)

下载好whl文件后,从虚拟环境中进入到下载目录,然后pip install依次安装torch和torchvision ,如图所示:

​ 

step5:安装mmcv_full、mmdetection和mmrotate

安装完成后,下面先进行mmcv_fullmmdetection的安装,因为mmrotate是基于以上两个模型库的。

mmcv_full:https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html

Installation — mmcv 1.6.0 documentation

根据自己的版本进行下载,这里我下载的是:

下载之后还是用pip install 命令进行安装

mmdetection:

pip install mmdet

最后是安装mmrotate : 

pip install mmrotate

这里我下载官方的代码版本为:

cmd界面下cd进入到mmrotate目录下,再执行

pip install -r requirements.txt

至此,环境搭建部分就结束了。 

2.测试mmrotate是否安装成功

修改image_demo.py

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Inference on single image.
  3. Example:
  4. ```
  5. wget -P checkpoint https://download.openmmlab.com/mmrotate/v0.1.0/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90/oriented_rcnn_r50_fpn_1x_dota_le90-6d2b2ce0.pth # noqa: E501, E261.
  6. python demo/image_demo.py \
  7. demo/demo.jpg \
  8. configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py \
  9. work_dirs/oriented_rcnn_r50_fpn_1x_dota_v3/epoch_12.pth
  10. ```
  11. """ # nowq
  12. from argparse import ArgumentParser
  13. from mmdet.apis import inference_detector, init_detector, show_result_pyplot
  14. import mmrotate # noqa: F401
  15. import os
  16. ROOT = os.getcwd()
  17. def parse_args():
  18. parser = ArgumentParser()
  19. parser.add_argument('--img', default=os.path.join(ROOT, 'demo.jpg'), help='Image file')
  20. parser.add_argument('--config', default=os.path.join(ROOT, '../configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_le90.py'), help='Config file')
  21. parser.add_argument('--checkpoint', default=os.path.join(ROOT, '../pre-models/oriented_rcnn_r50_fpn_1x_dota_le90-6d2b2ce0.pth'), help='Checkpoint file')
  22. parser.add_argument(
  23. '--device', default='cuda:0', help='Device used for inference')
  24. parser.add_argument(
  25. '--palette',
  26. default='dota',
  27. choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'],
  28. help='Color palette used for visualization')
  29. parser.add_argument(
  30. '--score-thr', type=float, default=0.3, help='bbox score threshold')
  31. args = parser.parse_args()
  32. return args
  33. def main(args):
  34. # build the model from a config file and a checkpoint file
  35. model = init_detector(args.config, args.checkpoint, device=args.device)
  36. # test a single image
  37. result = inference_detector(model, args.img)
  38. # show the results
  39. show_result_pyplot(
  40. model,
  41. args.img,
  42. result,
  43. palette=args.palette,
  44. score_thr=args.score_thr)
  45. if __name__ == '__main__':
  46. args = parse_args()
  47. main(args)

其中,需要自己下载预训练权重,网站在代码上方。下载慢的话可以复制链接到迅雷下载。

3.训练自己的数据集

训练自己的数据集,自定义数据集制作这部分其实是最麻烦的。MMrotate所使用的数据集格式是dota类型的,图片为.png格式且尺寸是 n×n 的(方形),不过不用担心,官方项目中有相应的工具包可自动转换。

不给我发现高宽不相等的数据集也可以进行训练。

具体参考:是否支持其他尺寸的图片输入而不用转化为DOTA类型1024*1024尺寸的图片? · Issue #237 · open-mmlab/mmrotate · GitHub

part1:训练数据集准备

这一部分可以参考我之前的博客:

记录使用yolov5进行旋转目标的检测_江小白jlj的博客-CSDN博客_yolov5旋转目标检测

这里给出rolabelimg生成的xml文件转dota数据格式的代码

  1. '''
  2. rolabelimg xml data to dota 8 points data
  3. '''
  4. import os
  5. import xml.etree.ElementTree as ET
  6. import math
  7. import cv2
  8. import numpy as np
  9. def edit_xml(xml_file):
  10. if ".xml" not in xml_file:
  11. return
  12. tree = ET.parse(xml_file)
  13. objs = tree.findall('object')
  14. txt=xml_file.replace(".xml",".txt")
  15. png=xml_file.replace(".xml",".png")
  16. src=cv2.imread(png,1)
  17. with open(txt,'w') as wf:
  18. wf.write("imagesource:Google\n")
  19. # wf.write("gsd:0.115726939386\n")
  20. for ix, obj in enumerate(objs):
  21. x0text = ""
  22. y0text =""
  23. x1text = ""
  24. y1text =""
  25. x2text = ""
  26. y2text = ""
  27. x3text = ""
  28. y3text = ""
  29. difficulttext=""
  30. className=""
  31. obj_type = obj.find('type')
  32. type = obj_type.text
  33. obj_name = obj.find('name')
  34. className = obj_name.text
  35. obj_difficult= obj.find('difficult')
  36. difficulttext = obj_difficult.text
  37. if type == 'bndbox':
  38. obj_bnd = obj.find('bndbox')
  39. obj_xmin = obj_bnd.find('xmin')
  40. obj_ymin = obj_bnd.find('ymin')
  41. obj_xmax = obj_bnd.find('xmax')
  42. obj_ymax = obj_bnd.find('ymax')
  43. xmin = float(obj_xmin.text)
  44. ymin = float(obj_ymin.text)
  45. xmax = float(obj_xmax.text)
  46. ymax = float(obj_ymax.text)
  47. x0text = str(xmin)
  48. y0text = str(ymin)
  49. x1text = str(xmax)
  50. y1text = str(ymin)
  51. x2text = str(xmin)
  52. y2text = str(ymax)
  53. x3text = str(xmax)
  54. y3text = str(ymax)
  55. points=np.array([[int(x0text),int(y0text)],[int(x1text),int(y1text)],[int(x2text),int(y2text)],[int(x3text),int(y3text)]],np.int32)
  56. cv2.polylines(src,[points],True,(255,0,0)) #画任意多边
  57. elif type == 'robndbox':
  58. obj_bnd = obj.find('robndbox')
  59. obj_bnd.tag = 'bndbox' # 修改节点名
  60. obj_cx = obj_bnd.find('cx')
  61. obj_cy = obj_bnd.find('cy')
  62. obj_w = obj_bnd.find('w')
  63. obj_h = obj_bnd.find('h')
  64. obj_angle = obj_bnd.find('angle')
  65. cx = float(obj_cx.text)
  66. cy = float(obj_cy.text)
  67. w = float(obj_w.text)
  68. h = float(obj_h.text)
  69. angle = float(obj_angle.text)
  70. x0text, y0text = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle)
  71. x1text, y1text = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle)
  72. x2text, y2text = rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle)
  73. x3text, y3text = rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle)
  74. points=np.array([[int(x0text),int(y0text)],[int(x1text),int(y1text)],[int(x2text),int(y2text)],[int(x3text),int(y3text)]],np.int32)
  75. cv2.polylines(src,[points],True,(255,0,0)) #画任意多边形
  76. # print(x0text,y0text,x1text,y1text,x2text,y2text,x3text,y3text,className,difficulttext)
  77. wf.write("{} {} {} {} {} {} {} {} {} {}\n".format(x0text,y0text,x1text,y1text,x2text,y2text,x3text,y3text,className,difficulttext))
  78. # cv2.imshow("ddd",src)
  79. # cv2.waitKey()
  80. # 转换成四点坐标
  81. def rotatePoint(xc, yc, xp, yp, theta):
  82. xoff = xp - xc;
  83. yoff = yp - yc;
  84. cosTheta = math.cos(theta)
  85. sinTheta = math.sin(theta)
  86. pResx = cosTheta * xoff + sinTheta * yoff
  87. pResy = - sinTheta * xoff + cosTheta * yoff
  88. return str(int(xc + pResx)), str(int(yc + pResy))
  89. if __name__ == '__main__':
  90. dir = r"H:\duocicaiji\biaozhu_all"
  91. filelist = os.listdir(dir)
  92. for file in filelist:
  93. edit_xml(os.path.join(dir, file))

part2:数据集划分与预处理

这一步主要是将 整个数据集划分为训练集、验证集与测试集。

其文件结构如下所示:(我是将其划分80%, 10%, 10%)

datasets

        --train

                --images

                --labels

        --val

                --images

                --labels

        --test

                --images

下一步是将对数据进行裁剪 ,要将其裁剪为n x n大小的,主要利用的是官方项目中提供的裁剪代码。./mmrotate-0.3.0/tools/data/dota/split/img_split.py (裁剪脚本),该脚本通过读取

./mmrotate-0.3.0/tools/data/dota/split/split_configs 文件夹下的各个json文件中的参数设置来进行图像裁剪。我们需要修改其中的参数,让其加载上述的train、test、val中的图像及标签,并进行裁剪。

具体操作如下:(以train为例,val和test的操作相同)(其中ss_表示单一尺度裁剪,ms_表示多尺度裁剪)

修改split_configs文件夹下的ss_train.json文件

 修改好以上的参数之后,再修改img_split.py 中的base_json参数

然后直接运行 img_split.py就行。

之后对val、test的裁剪也是同理。

至此完成对图像的裁剪预处理。

part3:模型训练与测试

以训练Rotated FasterRCNN为例:

训练:

首先,下载模型的预训练权重

mmrotate/README_zh-CN.md at main · open-mmlab/mmrotate · GitHub

从这里找到相应的链接进行权重文件下载

其次,修改 ./configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py

主要就是修改其中的num_classes参数,根据你自己的数据集修改类别个数。

在该文件下还要设置预训练权重的地址,修改为你下载的预训练权重地址。

 

同时,修改 ./mmrotate-0.3.0/mmrotate/datasets/dota.py 中的类别名称

还需要修改的是, ./configs/_base_/datasets/dotav1.py 文件

  1. # dataset settings
  2. dataset_type = 'DOTADataset'
  3. # 修改为你裁剪后数据集存放的路径
  4. data_root = 'H:/jlj/mmrotate-0.3.0/datasets/split_TL_896/'
  5. img_norm_cfg = dict(
  6. mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  7. train_pipeline = [
  8. dict(type='LoadImageFromFile'),
  9. dict(type='LoadAnnotations', with_bbox=True),
  10. dict(type='RResize', img_scale=(1024, 1024)),
  11. dict(type='RRandomFlip', flip_ratio=0.5),
  12. dict(type='Normalize', **img_norm_cfg),
  13. dict(type='Pad', size_divisor=32),
  14. dict(type='DefaultFormatBundle'),
  15. dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
  16. ]
  17. test_pipeline = [
  18. dict(type='LoadImageFromFile'),
  19. dict(
  20. type='MultiScaleFlipAug',
  21. img_scale=(1024, 1024),
  22. flip=False,
  23. transforms=[
  24. dict(type='RResize'),
  25. dict(type='Normalize', **img_norm_cfg),
  26. dict(type='Pad', size_divisor=32),
  27. dict(type='DefaultFormatBundle'),
  28. dict(type='Collect', keys=['img'])
  29. ])
  30. ]
  31. data = dict(
  32. # 设置的batch_size
  33. samples_per_gpu=2,
  34. # 设置的num_worker
  35. workers_per_gpu=2,
  36. train=dict(
  37. type=dataset_type,
  38. ann_file=data_root + 'train/annfiles/',
  39. img_prefix=data_root + 'train/images/',
  40. pipeline=train_pipeline),
  41. val=dict(
  42. type=dataset_type,
  43. ann_file=data_root + 'val/annfiles/',
  44. img_prefix=data_root + 'val/images/',
  45. pipeline=test_pipeline),
  46. test=dict(
  47. type=dataset_type,
  48. ann_file=data_root + 'test/images/',
  49. img_prefix=data_root + 'test/images/',
  50. pipeline=test_pipeline))

还有  ./configs/_base_/schedules/schedule_1x.py 中

  1. # evaluation
  2. evaluation = dict(interval=5, metric='mAP') # 训练多少轮评估一次
  3. # optimizer
  4. optimizer = dict(type='SGD', lr=0.0025, momentum=0.9, weight_decay=0.0001)
  5. optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
  6. # learning policy
  7. lr_config = dict(
  8. policy='step',
  9. warmup='linear',
  10. warmup_iters=500,
  11. warmup_ratio=1.0 / 3,
  12. step=[8, 11])
  13. runner = dict(type='EpochBasedRunner', max_epochs=100) # 训练的总次数
  14. checkpoint_config = dict(interval=10) # 训练多少次后保存模型

还有 ./configs/_base_/default_runtime.py

  1. # yapf:disable
  2. log_config = dict(
  3. interval=50, # 训练多少iter后打印输出训练日志
  4. hooks=[
  5. dict(type='TextLoggerHook'),
  6. # dict(type='TensorboardLoggerHook')
  7. ])
  8. # yapf:enable
  9. dist_params = dict(backend='nccl')
  10. log_level = 'INFO'
  11. load_from = None
  12. resume_from = None
  13. workflow = [('train', 1)]
  14. # disable opencv multithreading to avoid system being overloaded
  15. opencv_num_threads = 0
  16. # set multi-process start method as `fork` to speed up the training
  17. mp_start_method = 'fork'

最后,修改 train.py

主要有两个参数: - -config: 使用的模型文件 (我使用的是 faster rcnn) ; - -work-dir:训练得到的模型及配置信息保存的路径。

 一切都配置完毕后,运行 train.py 即可。

预测:

预测的话,修改 test.py 中的路径参数即可。

主要有三个参数: - -config: 使用的模型文件 ; - -checkpoint:训练得到的模型权重文件; --show-dir: 预测结果存放的路径。

测试效果: 

 参考博文:

基于MMRotate训练自定义数据集 做旋转目标检测 2022-3-30_YD-阿三的博客-CSDN博客_旋转目标检测数据集

 【扫盲】MMRotate旋转目标检测训练_哔哩哔哩_bilibili

https://github.com/open-mmlab/mmrotate

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/722318
推荐阅读
相关标签
  

闽ICP备14008679号