
SOLO in Practice: Training an Instance Segmentation Model on Your Own Dataset

Paper: https://arxiv.org/abs/1912.04488

Code: https://github.com/WXinlong/SOLO

1. Set up the training environment

I set up the training environment with conda; the following commands are all you need:

    conda create -n solo python=3.7 -y
    conda activate solo
    conda install -c pytorch pytorch torchvision -y
    conda install cython -y
    git clone https://github.com/WXinlong/SOLO.git
    cd SOLO
    pip install -r requirements/build.txt
    pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
    pip install -v -e .

Of course, if you prefer to set up the environment some other way, follow the instructions in the official repository.

2. Prepare the dataset

I recommend converting your dataset to the COCO format; that way the code changes are minimal.

For details on the COCO format, see my earlier post: https://blog.csdn.net/Guo_Python/article/details/105839280
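To give a sense of what your conversion script should produce, here is a minimal COCO-style instance-segmentation annotation file built in plain Python. It is only a sketch: the file name, image size, polygon, and box values are made-up placeholders, and the categories match the example classes used later in this post.

```python
import json

# A minimal COCO-style annotation file for instance segmentation.
# All file names and coordinate values below are illustrative placeholders.
coco = {
    "images": [
        {"id": 1, "file_name": "0001.jpg", "width": 640, "height": 480}
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            # Polygon as a flat list of x, y pairs.
            "segmentation": [[100, 100, 200, 100, 200, 200, 100, 200]],
            "bbox": [100, 100, 100, 100],  # x, y, width, height
            "area": 10000,
            "iscrowd": 0,
        }
    ],
    "categories": [
        {"id": 1, "name": "people"},
        {"id": 2, "name": "dog"},
        {"id": 3, "name": "cat"},
    ],
}

with open("train.json", "w") as f:
    json.dump(coco, f)
```

The three top-level keys (`images`, `annotations`, `categories`) are what `CocoDataset` expects; each annotation links an image to a category via `image_id` and `category_id`.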

3. Modify the code

1. Register your dataset

Create a file Your_dataset.py under mmdet/datasets/ with the following content (it subclasses CocoDataset):

    from .coco import CocoDataset
    from .registry import DATASETS

    # add new dataset
    @DATASETS.register_module
    class Your_Dataset(CocoDataset):
        CLASSES = ['people', 'dog', 'cat']

Then add the new class to mmdet/datasets/__init__.py; the modified file looks like this:

    from .builder import build_dataset
    from .cityscapes import CityscapesDataset
    from .coco import CocoDataset
    from .custom import CustomDataset
    from .dataset_wrappers import ConcatDataset, RepeatDataset
    from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
    from .registry import DATASETS
    from .voc import VOCDataset
    from .wider_face import WIDERFaceDataset
    from .xml_style import XMLDataset
    from .my_dataset import MyDataset
    from .Your_dataset import Your_Dataset

    __all__ = [
        'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
        'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
        'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'WIDERFaceDataset',
        'DATASETS', 'build_dataset', 'MyDataset', 'Your_Dataset'
    ]
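Under the hood, `DATASETS.register_module` simply records the decorated class in a registry, which is why the string `'Your_Dataset'` in the config file can later be resolved back to your class. The following is a toy sketch of that mechanism, not mmdet's actual implementation:

```python
# A toy registry illustrating conceptually what @DATASETS.register_module does.
class Registry:
    def __init__(self):
        self._modules = {}

    def register_module(self, cls):
        # Used as a decorator: store the class under its own name,
        # then return it unchanged.
        self._modules[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._modules[name]

DATASETS = Registry()

@DATASETS.register_module
class Your_Dataset:
    CLASSES = ['people', 'dog', 'cat']

# The dataset_type string in the config resolves back to the class:
cls = DATASETS.get('Your_Dataset')
print(cls.CLASSES)  # ['people', 'dog', 'cat']
```

This is also why the registration file must be imported in `__init__.py`: the decorator only runs, and the class only lands in the registry, when the module is imported.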

2. Modify the config file

Edit configs/solo/solo_r50_fpn_8gpu_3x.py. The modified file is shown below; the main changes are the data paths, the training scales, and the number of classes.

    # model settings
    model = dict(
        type='SOLO',
        pretrained='torchvision://resnet50',
        backbone=dict(
            type='ResNet',
            depth=50,
            num_stages=4,
            out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
            frozen_stages=1,
            style='pytorch'),
        neck=dict(
            type='FPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            start_level=0,
            num_outs=5),
        bbox_head=dict(
            type='SOLOHead',
            num_classes=4,  # change this: number of classes + background
            in_channels=256,
            stacked_convs=4,
            seg_feat_channels=256,
            strides=[8, 8, 16, 32, 32],
            scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
            sigma=0.2,
            num_grids=[40, 36, 24, 16, 12],
            cate_down_pos=0,
            with_deform=False,
            loss_ins=dict(
                type='DiceLoss',
                use_sigmoid=True,
                loss_weight=3.0),
            loss_cate=dict(
                type='FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=1.0),
            ))
    # training and testing settings
    train_cfg = dict()
    test_cfg = dict(
        nms_pre=500,
        score_thr=0.1,
        mask_thr=0.5,
        update_thr=0.05,
        kernel='gaussian',  # gaussian/linear
        sigma=2.0,
        max_per_img=100)
    # dataset settings
    dataset_type = 'Your_Dataset'  # change this: your registered dataset type
    data_root = '/home/gp/dukto/Data/'  # change this: your data root
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
        dict(type='Resize',
             # change this: your image scales
             img_scale=[(640, 480), (640, 420), (640, 400),
                        (640, 360), (640, 320), (640, 300)],
             multiscale_mode='value',
             keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(640, 480),  # change this: your image scale
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='Pad', size_divisor=32),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data = dict(
        imgs_per_gpu=4,
        workers_per_gpu=2,
        train=dict(
            type=dataset_type,
            ann_file=data_root + 'train.json',
            img_prefix=data_root + 'images/',
            pipeline=train_pipeline),
        val=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/test.json',
            img_prefix=data_root + 'images/',
            pipeline=test_pipeline),
        test=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/test.json',
            img_prefix=data_root + 'images/',
            pipeline=test_pipeline))
    # optimizer
    optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
    # learning policy
    lr_config = dict(
        policy='step',
        warmup='linear',
        warmup_iters=500,
        warmup_ratio=1.0 / 3,
        step=[27, 33])
    checkpoint_config = dict(interval=1)
    # yapf:disable
    log_config = dict(
        interval=50,
        hooks=[
            dict(type='TextLoggerHook'),
            # dict(type='TensorboardLoggerHook')
        ])
    # yapf:enable
    # runtime settings
    total_epochs = 36
    device_ids = range(8)
    dist_params = dict(backend='nccl')
    log_level = 'INFO'
    work_dir = './work_dirs/solo_release_r50_fpn_3x'  # where checkpoints and logs are saved
    load_from = None
    resume_from = None
    workflow = [('train', 1)]
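One thing to keep consistent between the two edits above: `num_classes` in the head must equal the length of the `CLASSES` list you registered, plus one for background. A quick check, using the example class list from earlier:

```python
# CLASSES as registered in Your_dataset.py
CLASSES = ['people', 'dog', 'cat']

# In this version of mmdet, the SOLO head counts background as a class,
# so num_classes = number of your categories + 1.
num_classes = len(CLASSES) + 1
print(num_classes)  # 4, matching num_classes=4 in the config above
```

If the two numbers disagree, training will typically fail with a shape or label-index error in the classification branch.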

4. Train the model

Training command:

    python tools/train.py configs/solo/solo_r50_fpn_8gpu_3x.py

If everything is in order, training should start and the log will look like this:

    2020-06-06 08:59:44,359 - mmdet - INFO - Start running, host: gp@gp-System-Product-Name, work_dir: /home/gp/work/project/SOLO/work_dirs/solo_release_r50_fpn_3x
    2020-06-06 08:59:44,360 - mmdet - INFO - workflow: [('train', 1)], max: 36 epochs
    2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077] lr: 0.00399, eta: 6:58:04, time: 0.648, data_time: 0.012, memory: 4378, loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386
    2020-06-06 09:00:49,382 - mmdet - INFO - Epoch [1][100/1077] lr: 0.00465, eta: 6:59:03, time: 0.653, data_time: 0.009, memory: 4404, loss_ins: 2.9341, loss_cate: 0.7440, loss: 3.6781
    2020-06-06 09:01:23,307 - mmdet - INFO - Epoch [1][150/1077] lr: 0.00532, eta: 7:04:35, time: 0.678, data_time: 0.009, memory: 4404, loss_ins: 2.9176, loss_cate: 0.6833, loss: 3.6009
    2020-06-06 09:01:57,125 - mmdet - INFO - Epoch [1][200/1077] lr: 0.00599, eta: 7:06:44, time: 0.676, data_time: 0.008, memory: 4404, loss_ins: 2.8258, loss_cate: 0.6749, loss: 3.5007
    2020-06-06 09:02:31,695 - mmdet - INFO - Epoch [1][250/1077] lr: 0.00665, eta: 7:09:43, time: 0.691, data_time: 0.009, memory: 4404, loss_ins: 2.6120, loss_cate: 0.6709, loss: 3.2829
    2020-06-06 09:03:05,609 - mmdet - INFO - Epoch [1][300/1077] lr: 0.00732, eta: 7:10:07, time: 0.678, data_time: 0.008, memory: 4404, loss_ins: 2.3819, loss_cate: 0.6161, loss: 2.9980
    2020-06-06 09:03:41,319 - mmdet - INFO - Epoch [1][350/1077] lr: 0.00799, eta: 7:13:31, time: 0.714, data_time: 0.009, memory: 4404, loss_ins: 2.2753, loss_cate: 0.5814, loss: 2.8567
    2020-06-06 09:04:16,615 - mmdet - INFO - Epoch [1][400/1077] lr: 0.00865, eta: 7:15:16, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 2.1902, loss_cate: 0.5767, loss: 2.7669
    2020-06-06 09:04:50,585 - mmdet - INFO - Epoch [1][450/1077] lr: 0.00932, eta: 7:14:37, time: 0.679, data_time: 0.009, memory: 4404, loss_ins: 2.0392, loss_cate: 0.5654, loss: 2.6046
    2020-06-06 09:05:23,447 - mmdet - INFO - Epoch [1][500/1077] lr: 0.00999, eta: 7:12:34, time: 0.657, data_time: 0.008, memory: 4404, loss_ins: 1.9716, loss_cate: 0.5497, loss: 2.5212
    2020-06-06 09:05:58,198 - mmdet - INFO - Epoch [1][550/1077] lr: 0.01000, eta: 7:12:59, time: 0.695, data_time: 0.008, memory: 4404, loss_ins: 1.8707, loss_cate: 0.5178, loss: 2.3885
    2020-06-06 09:06:31,807 - mmdet - INFO - Epoch [1][600/1077] lr: 0.01000, eta: 7:12:01, time: 0.672, data_time: 0.008, memory: 4404, loss_ins: 1.9311, loss_cate: 0.5885, loss: 2.5196
    2020-06-06 09:07:05,118 - mmdet - INFO - Epoch [1][650/1077] lr: 0.01000, eta: 7:10:49, time: 0.666, data_time: 0.009, memory: 4404, loss_ins: 1.8655, loss_cate: 0.5750, loss: 2.4404
    2020-06-06 09:07:39,825 - mmdet - INFO - Epoch [1][700/1077] lr: 0.01000, eta: 7:10:59, time: 0.694, data_time: 0.009, memory: 4404, loss_ins: 1.7929, loss_cate: 0.5022, loss: 2.2951
    2020-06-06 09:08:15,143 - mmdet - INFO - Epoch [1][750/1077] lr: 0.01000, eta: 7:11:34, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 1.8265, loss_cate: 0.5218, loss: 2.3482
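If you want to plot the loss curve yourself, you can pull `loss_ins`, `loss_cate`, and `loss` out of log lines like the ones above. A minimal stdlib-only sketch of the parsing step (the sample line is copied from the log above):

```python
import re

# Extract the per-iteration losses from an mmdet text log line.
LOSS_RE = re.compile(r"loss_ins: ([\d.]+), loss_cate: ([\d.]+), loss: ([\d.]+)")

def parse_losses(line):
    """Return (loss_ins, loss_cate, loss) as floats, or None if absent."""
    m = LOSS_RE.search(line)
    if m is None:
        return None
    loss_ins, loss_cate, loss = map(float, m.groups())
    return loss_ins, loss_cate, loss

sample = ("2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077] "
          "lr: 0.00399, eta: 6:58:04, time: 0.648, data_time: 0.012, "
          "memory: 4378, loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386")
print(parse_losses(sample))  # (2.9387, 0.8999, 3.8386)
```

Run `parse_losses` over every line of the log file, collect the tuples, and feed them to whatever plotting library you prefer.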

The loss curve looks like this:

(Figure: training loss curve)

5. Test the model

Follow the test code in the GitHub repository; my test results are shown below:

(Figure: example segmentation results)

Done! If you have any questions, leave a comment!

SOLOv2 was also open-sourced a couple of days ago; its training and testing procedures are exactly the same as SOLOv1, and I have already tried them.
