Paper: https://arxiv.org/abs/1912.04488
Code: https://github.com/WXinlong/SOLO
I set up the training environment with conda; the following commands are all you need:
```shell
conda create -n solo python=3.7 -y
conda activate solo

conda install -c pytorch pytorch torchvision -y
conda install cython -y
git clone https://github.com/WXinlong/SOLO.git
cd SOLO
pip install -r requirements/build.txt
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
pip install -v -e .
```
Of course, if you prefer to set up the environment some other way, follow the official repo's instructions.
I recommend preparing your dataset in COCO format, which keeps the required code changes to a minimum.
For details on the COCO annotation format, see my blog post: https://blog.csdn.net/Guo_Python/article/details/105839280
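For orientation, a minimal COCO-style instance-segmentation annotation file has roughly the structure below. This is only a sketch with illustrative values (file names, coordinates, and category ids are made up); see the official COCO documentation for the full schema.

```python
import json

# Minimal COCO-style instance annotation file (illustrative values only).
coco = {
    "images": [
        {"id": 1, "file_name": "000001.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            "bbox": [100, 120, 80, 60],  # [x, y, width, height]
            "segmentation": [[100, 120, 180, 120, 180, 180, 100, 180]],  # polygon
            "area": 4800,
            "iscrowd": 0,
        },
    ],
    "categories": [
        {"id": 1, "name": "people"},
        {"id": 2, "name": "dog"},
        {"id": 3, "name": "cat"},
    ],
}

with open("train.json", "w") as f:
    json.dump(coco, f)
```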
1. Register your dataset
Create a file Your_dataset.py under mmdet/datasets/ with the following content (it subclasses CocoDataset):
```python
from .coco import CocoDataset
from .registry import DATASETS

# add new dataset
@DATASETS.register_module
class Your_Dataset(CocoDataset):
    CLASSES = ['people', 'dog', 'cat']
```
Then add the new dataset to mmdet/datasets/__init__.py; the modified __init__.py looks like this:
```python
from .builder import build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
from .registry import DATASETS
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .Your_dataset import Your_Dataset

__all__ = [
    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
    'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
    'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'WIDERFaceDataset',
    'DATASETS', 'build_dataset', 'Your_Dataset'
]
```
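If you are curious how `@DATASETS.register_module` wires your class up, the registry pattern boils down to something like the self-contained toy below. This is not mmdet's actual code, just an illustration of the idea: the decorator stores the class in a dict keyed by its name, so the `dataset_type` string in the config can later be resolved back to the class.

```python
class Registry:
    """Toy version of a dataset registry: maps class names to classes."""
    def __init__(self):
        self.module_dict = {}

    def register_module(self, cls):
        # Used as a decorator: stores the class under its own name.
        self.module_dict[cls.__name__] = cls
        return cls


DATASETS = Registry()


@DATASETS.register_module
class Your_Dataset:
    CLASSES = ['people', 'dog', 'cat']


# The config's dataset_type string can now be looked up to get the class.
cls = DATASETS.module_dict['Your_Dataset']
print(cls.CLASSES)  # ['people', 'dog', 'cat']
```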

2. Modify the config file
Edit configs/solo/solo_r50_fpn_8gpu_3x.py. The modified file is shown below; the main changes are the data paths, the training image scales, and the number of classes.
```python
# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',
        num_classes=4,  # change this: number of classes + 1 (background)
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        with_deform=False,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'Your_Dataset'  # change this: the dataset type registered above
data_root = '/home/gp/dukto/Data/'  # change this: your data root
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         # change this: training image scales
         img_scale=[(640, 480), (640, 420), (640, 400),
                    (640, 360), (640, 320), (640, 300)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(640, 480),  # change this: test image scale
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'train.json',
        img_prefix=data_root + 'images/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/test.json',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/test.json',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solo_release_r50_fpn_3x'  # where checkpoints and training logs are saved
load_from = None
resume_from = None
workflow = [('train', 1)]
```

Training command:
python tools/train.py configs/solo/solo_r50_fpn_8gpu_3x.py
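One note from my own experience (not stated in the official docs for this exact config, so treat it as a rule of thumb to verify): the config's lr=0.01 is tuned for 8 GPUs × 4 images per GPU (total batch 32). If you run the single-GPU command above, the linear scaling rule commonly used with mmdetection suggests shrinking the learning rate in proportion to your total batch size:

```python
def scaled_lr(base_lr=0.01, base_batch=32, gpus=1, imgs_per_gpu=4):
    """Linear scaling rule: lr proportional to total batch size."""
    return base_lr * (gpus * imgs_per_gpu) / base_batch


print(scaled_lr())        # 0.00125 for 1 GPU x 4 images
print(scaled_lr(gpus=8))  # 0.01, the config default
```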
If all goes well, your model starts training and the log looks like this:
```
2020-06-06 08:59:44,359 - mmdet - INFO - Start running, host: gp@gp-System-Product-Name, work_dir: /home/gp/work/project/SOLO/work_dirs/solo_release_r50_fpn_3x
2020-06-06 08:59:44,360 - mmdet - INFO - workflow: [('train', 1)], max: 36 epochs
2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077] lr: 0.00399, eta: 6:58:04, time: 0.648, data_time: 0.012, memory: 4378, loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386
2020-06-06 09:00:49,382 - mmdet - INFO - Epoch [1][100/1077] lr: 0.00465, eta: 6:59:03, time: 0.653, data_time: 0.009, memory: 4404, loss_ins: 2.9341, loss_cate: 0.7440, loss: 3.6781
2020-06-06 09:01:23,307 - mmdet - INFO - Epoch [1][150/1077] lr: 0.00532, eta: 7:04:35, time: 0.678, data_time: 0.009, memory: 4404, loss_ins: 2.9176, loss_cate: 0.6833, loss: 3.6009
2020-06-06 09:01:57,125 - mmdet - INFO - Epoch [1][200/1077] lr: 0.00599, eta: 7:06:44, time: 0.676, data_time: 0.008, memory: 4404, loss_ins: 2.8258, loss_cate: 0.6749, loss: 3.5007
2020-06-06 09:02:31,695 - mmdet - INFO - Epoch [1][250/1077] lr: 0.00665, eta: 7:09:43, time: 0.691, data_time: 0.009, memory: 4404, loss_ins: 2.6120, loss_cate: 0.6709, loss: 3.2829
2020-06-06 09:03:05,609 - mmdet - INFO - Epoch [1][300/1077] lr: 0.00732, eta: 7:10:07, time: 0.678, data_time: 0.008, memory: 4404, loss_ins: 2.3819, loss_cate: 0.6161, loss: 2.9980
2020-06-06 09:03:41,319 - mmdet - INFO - Epoch [1][350/1077] lr: 0.00799, eta: 7:13:31, time: 0.714, data_time: 0.009, memory: 4404, loss_ins: 2.2753, loss_cate: 0.5814, loss: 2.8567
2020-06-06 09:04:16,615 - mmdet - INFO - Epoch [1][400/1077] lr: 0.00865, eta: 7:15:16, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 2.1902, loss_cate: 0.5767, loss: 2.7669
2020-06-06 09:04:50,585 - mmdet - INFO - Epoch [1][450/1077] lr: 0.00932, eta: 7:14:37, time: 0.679, data_time: 0.009, memory: 4404, loss_ins: 2.0392, loss_cate: 0.5654, loss: 2.6046
2020-06-06 09:05:23,447 - mmdet - INFO - Epoch [1][500/1077] lr: 0.00999, eta: 7:12:34, time: 0.657, data_time: 0.008, memory: 4404, loss_ins: 1.9716, loss_cate: 0.5497, loss: 2.5212
2020-06-06 09:05:58,198 - mmdet - INFO - Epoch [1][550/1077] lr: 0.01000, eta: 7:12:59, time: 0.695, data_time: 0.008, memory: 4404, loss_ins: 1.8707, loss_cate: 0.5178, loss: 2.3885
2020-06-06 09:06:31,807 - mmdet - INFO - Epoch [1][600/1077] lr: 0.01000, eta: 7:12:01, time: 0.672, data_time: 0.008, memory: 4404, loss_ins: 1.9311, loss_cate: 0.5885, loss: 2.5196
2020-06-06 09:07:05,118 - mmdet - INFO - Epoch [1][650/1077] lr: 0.01000, eta: 7:10:49, time: 0.666, data_time: 0.009, memory: 4404, loss_ins: 1.8655, loss_cate: 0.5750, loss: 2.4404
2020-06-06 09:07:39,825 - mmdet - INFO - Epoch [1][700/1077] lr: 0.01000, eta: 7:10:59, time: 0.694, data_time: 0.009, memory: 4404, loss_ins: 1.7929, loss_cate: 0.5022, loss: 2.2951
2020-06-06 09:08:15,143 - mmdet - INFO - Epoch [1][750/1077] lr: 0.01000, eta: 7:11:34, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 1.8265, loss_cate: 0.5218, loss: 2.3482
```
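If you want to plot the loss curve yourself, a small script like the one below can pull the values out of the text log. It is a sketch that assumes the exact log format shown above; the regex and file handling are mine, not part of mmdet.

```python
import re

# Matches the loss fields in the mmdet text-log lines shown above.
PATTERN = re.compile(r"loss_ins: ([\d.]+), loss_cate: ([\d.]+), loss: ([\d.]+)")


def parse_losses(lines):
    """Return (loss_ins, loss_cate, loss) tuples from mmdet log lines."""
    out = []
    for line in lines:
        m = PATTERN.search(line)
        if m:
            out.append(tuple(float(x) for x in m.groups()))
    return out


log = ["2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077] "
       "lr: 0.00399, eta: 6:58:04, time: 0.648, data_time: 0.012, "
       "memory: 4378, loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386"]
print(parse_losses(log))  # [(2.9387, 0.8999, 3.8386)]
```

Feed the resulting lists into matplotlib (or any plotting tool) to reproduce the curve.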

The loss curve over training is shown below:
For testing, just follow the code on GitHub; my test results are shown below:
Done! If you have any questions, leave a comment!
SOLOv2 was also open-sourced a couple of days ago; its training and testing procedure is exactly the same as SOLOv1, and I have already tried it.