Training SegFormer on Your Own Dataset with the mmsegmentation Framework

This post documents my hands-on learning of the mmsegmentation framework. I'm sharing what I learned and the pitfalls I ran into, in the hope that it serves as a reference for others studying the same thing.

My goal was to train mmsegmentation on my own dataset. I first followed an online tutorial using PSPNet, but, probably because my dataset is small, the results were underwhelming. So I decided to try the newer, stronger SegFormer, which reports substantial accuracy gains over traditional networks across a range of benchmarks, and that was enough to win me over.

[Figure: SegFormer's results on the ADE20K dataset]

So let's get started (this assumes you already have mmsegmentation set up):

First, prepare your dataset. I'm most comfortable with VOC-style datasets, so this post uses the VOC directory layout (a quick consistency check follows the listing):

ImageSets/
    Segmentation/
        train.txt      # file names of the training images
        trainval.txt   # file names of the train+val images
        val.txt        # file names of the validation images
JPEGImages/            # all images used for training and testing
SegmentationClass/     # the segmentation label maps
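Before moving on, it's worth checking that images and masks pair up. A minimal sketch, assuming the layout above sits under a root such as the test_yinzhang folder used later, with .jpg images and .png masks:

import os

root = 'test_yinzhang'  # substitute your own dataset root
imgs = {os.path.splitext(f)[0]
        for f in os.listdir(os.path.join(root, 'JPEGImages')) if f.endswith('.jpg')}
masks = {os.path.splitext(f)[0]
         for f in os.listdir(os.path.join(root, 'SegmentationClass')) if f.endswith('.png')}
print('images missing masks:', sorted(imgs - masks))
print('masks missing images:', sorted(masks - imgs))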

Next comes setting up our own config files. mmsegmentation's SegFormer does not ship a config for VOC datasets, so we need to adapt one to the VOC format ourselves.

1. First, modify mmseg\.mim\configs\_base_\datasets\pascal_voc12.py (I recommend copying the mmseg folder into your own project folder so it's easier to edit)

dataset_type = 'PascalVOCDataset'
data_root = 'data/VOCdevkit/VOC2012'  # change to your dataset's path; an absolute path is recommended

and, in the test pipeline:

dict(
    type='MultiScaleFlipAug',
    # img_scale=(2048, 512),
    img_scale=(640, 640),  # set the image size to match your dataset
    # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
    flip=False,
    transforms=[
        dict(type='Resize', keep_ratio=True),
        dict(type='RandomFlip'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img']),
    ])

2. Then modify mmseg\datasets\voc.py

The main change is replacing the class list with your own dataset's classes and choosing the display color for each segmented class. A quick way to verify your masks against these indices follows the snippet.

CLASSES = ('sky', 'tree', 'road', 'grass', 'background')  # just write your actual class names, with 'background' added at the end
PALETTE = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34],
           [0, 11, 123]]  # one entry per class
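Masks in this layout must store class indices (0 to len(CLASSES)-1) as pixel values; PALETTE only controls how they are displayed. A quick check on any mask file (the file name here is illustrative):

import numpy as np
from PIL import Image

mask = np.array(Image.open('test_yinzhang/SegmentationClass/0001.png'))
print(mask.shape, np.unique(mask))  # expect a 2-D array with values in 0..num_classes-1 (255 = ignore)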

3. Next, modify _base_\models\segformer.py (create it if it doesn't exist)

# model settings
norm_cfg = dict(type='BN', requires_grad=True)  # use BN for single-GPU training
find_unused_parameters = True
model = dict(
    type='EncoderDecoder',
    pretrained=True,
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,  # must match the number of classes in your dataset
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0)),  # focal loss is a common choice
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
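One caveat worth flagging: embed_dims=32 with num_layers=[2, 2, 2, 2] are the MiT-B0 sizes. If you plan to load the MiT-B5 weights from step 5 below, the widths should match B5; going by the official mmsegmentation SegFormer configs, that means roughly:

# MiT-B5 sizes, per the official mmsegmentation SegFormer configs
# (all other fields stay as in the block above):
backbone = dict(embed_dims=64, num_layers=[3, 6, 40, 3])
decode_head = dict(in_channels=[64, 128, 320, 512])

With mismatched sizes, the checkpoint's weights are skipped (with warnings in the log) and the backbone effectively trains from scratch.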

4. Then create the top-level config file

I copied this file, along with the _base_ folder, into my project folder to simplify path handling. Create the top-level config segformer_mit-b5.py and change the inherited dataset config:

_base_ = [
    './_base_/models/segformer.py',
    './_base_/datasets/pascal_voc12_aug.py',
    './_base_/default_runtime.py',
    './_base_/schedules/schedule_160k.py'
]
# model settings
norm_cfg = dict(type='BN', requires_grad=True)  # BN for single-GPU training
find_unused_parameters = True
model = dict(
    type='EncoderDecoder',
    pretrained='mit_b5.pth',  # set the path to the pretrained weights
    backbone=dict(
        type='MixVisionTransformer',
        in_channels=3,
        embed_dims=32,
        num_stages=4,
        num_layers=[2, 2, 2, 2],
        num_heads=[1, 2, 5, 8],
        patch_sizes=[7, 3, 3, 3],
        sr_ratios=[8, 4, 2, 1],
        out_indices=(0, 1, 2, 3),
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1),
    decode_head=dict(
        type='SegformerHead',
        in_channels=[32, 64, 160, 256],
        in_index=[0, 1, 2, 3],
        channels=256,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
                 paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.),
                                                 'head': dict(lr_mult=10.)
                                                 }))
lr_config = dict(_delete_=True, policy='poly',
                 warmup='linear',
                 warmup_iters=1500,
                 warmup_ratio=1e-6,
                 power=1.0, min_lr=0.0, by_epoch=False)
evaluation = dict(interval=16000, metric='mIoU')

5. Download the matching pretrained weights

Here's the link:

Link: https://pan.baidu.com/s/1c-d5ghbVyLWqDvylJ24VSw?pwd=2023
Extraction code: 2023

Since the official SegFormer release was not pretrained on VOC, we can only start from the ADE20K pretrained weights.
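A quick way to confirm the download is a loadable checkpoint (a sketch; the exact key layout depends on how the weights were exported):

import torch

state = torch.load('mit_b5.pth', map_location='cpu')
state = state.get('state_dict', state)  # some checkpoints wrap the weights
print(len(state), list(state)[:3])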

6. Building the training script

Here's the code; I've verified it runs:

# Let's take a look at the dataset
import mmcv
import matplotlib.pyplot as plt
import os.path as osp
import numpy as np
from PIL import Image
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmcv import Config
from mmseg.apis import set_random_seed

# convert dataset annotations to semantic segmentation maps
data_root = 'test_yinzhang'
img_dir = 'JPEGImages'
ann_dir = 'SegmentationClass'
# define classes and palette for better visualization
classes = ('background', 'yinzhang')
palette = [[0, 0, 0], [255, 0, 0]]
# this conversion loop comes from the Stanford Background tutorial; if your
# masks are already .png files there are no .regions.txt files and it is a no-op
for file in mmcv.scandir(osp.join(data_root, ann_dir), suffix='.regions.txt'):
    seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    seg_img.save(osp.join(data_root, ann_dir, file.replace('.regions.txt',
                                                           '.png')))
# Let's take a look at the segmentation map we got
import matplotlib.patches as mpatches

# split train/val set randomly
split_dir = 'ImageSets/Segmentation'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
    # select the first 4/5 as the train set
    train_length = int(len(filename_list) * 4 / 5)
    f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
    # select the last 1/5 as the val set
    f.writelines(line + '\n' for line in filename_list[train_length:])

@DATASETS.register_module()
class MyDataset(CustomDataset):
    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png',
                         split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None

cfg = Config.fromfile('segformer_mit-b5.py')
# add CLASSES and PALETTE to the checkpoint
cfg.checkpoint_config.meta = dict(CLASSES=classes, PALETTE=palette)
# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
# cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
# cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify the number of classes of the model in the decode/auxiliary head
cfg.model.decode_head.num_classes = len(classes)  # 2 here; the tutorial's value of 8 did not match this dataset
# cfg.model.auxiliary_head.num_classes = 8
# Modify dataset type and path
cfg.dataset_type = 'PascalVOCDataset'
cfg.data_root = data_root
cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 2
cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (512, 512)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(640, 640), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        # img_scale=(2048, 512),
        img_scale=(640, 640),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'ImageSets/Segmentation/train.txt'
cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'ImageSets/Segmentation/val.txt'
cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'ImageSets/Segmentation/val.txt'
# initialize from the pretrained MiT-B5 weights downloaded in step 5
cfg.load_from = 'mit_b5.pth'
# Set up the working dir to save files and logs.
cfg.work_dir = './work_dirs/new/tutorial'
cfg.runner.max_iters = 3000
cfg.log_config.interval = 100
cfg.evaluation.interval = 1000
cfg.checkpoint_config.interval = 1000
# Set the seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# Build the dataset
datasets = [build_dataset(cfg.data.train)]
# Build the segmentor
model = build_segmentor(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# Create the work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
if __name__ == "__main__":
    train_segmentor(model, datasets, cfg, distributed=False, validate=True,
                    meta=dict())
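One note on long runs: if training gets interrupted, it can be resumed from the most recent checkpoint by setting, before the train_segmentor call:

cfg.resume_from = osp.join(cfg.work_dir, 'latest.pth')

Unlike load_from, resume_from restores the optimizer state and iteration count in addition to the weights.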

7. Prediction code:

Feel free to try this out yourself:

import mmcv
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmcv import Config
from mmseg.apis import set_random_seed
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor, inference_segmentor, init_segmentor, show_result_pyplot
import os
import cv2
import warnings

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
data_root = 'test_yinzhang'
img_dir = 'JPEGImages'
ann_dir = 'SegmentationClass'
# define classes and palette for better visualization
classes = ('background', 'yinzhang')
palette = [[0, 0, 0], [255, 0, 0]]

@DATASETS.register_module()
class StanfordBackgroundDataset(CustomDataset):
    CLASSES = classes
    PALETTE = palette

    def show_result(self,
                    img,
                    result,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
        img = mmcv.imread(img)
        img = img.copy()
        seg = result[0]
        if palette is None:
            if self.PALETTE is None:
                palette = np.random.randint(
                    0, 255, size=(len(self.CLASSES), 3))
            else:
                palette = self.PALETTE
        palette = np.array(palette)
        assert palette.shape[0] == len(self.CLASSES)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        assert 0 < opacity <= 1.0
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]
        img = img * (1 - opacity) + color_seg * opacity
        img = img.astype(np.uint8)
        # if out_file is specified, do not show the image in a window
        if out_file is not None:
            show = False
        if show:
            mmcv.imshow(img, win_name, wait_time)
        if out_file is not None:
            mmcv.imwrite(img, out_file)  # adjust the output handling here if needed
        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
        return img

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png',
                         split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None

############################################################################################
cfg = Config.fromfile('segformer_mit-b5.py')
# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
# cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
# cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify the number of classes in the decode/auxiliary head; it must match the trained checkpoint
cfg.model.decode_head.num_classes = len(classes)
# cfg.model.auxiliary_head.num_classes = 8
# Modify dataset type and path
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root
cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 0
cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
# the train pipeline below is not used at inference time; it is kept from the training setup
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'
cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'
cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'
# load_from is not used at inference; init_segmentor below loads the actual checkpoint
cfg.load_from = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
# Set up the working dir to save files and logs.
cfg.work_dir = './work_dirs'
cfg.runner.max_iters = 200
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 200
# Set the seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# Let's have a look at the final config
print(f'Config:\n{cfg.pretty_text}')
config_file = cfg
checkpoints_file = './work_dirs/tutorial/latest.pth'  # point this at the work_dir used for training
model = init_segmentor(config_file, checkpoints_file, device='cuda:0')
img = './input/test2.jpg'
print(img)
'''
show_result(self,
            img,
            result,
            palette=None,
            win_name='',
            show=False,
            wait_time=0,
            out_file=None,
            opacity=0.5)
'''
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)
model.show_result(img, result, show=True)
# result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
# result[:, :] = (255,255,255)
# result = img+result
cv2.imwrite('./output/re4.jpg', result[0] * 255)
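Note that result[0] returned by inference_segmentor is an H×W array of class indices, so the final result[0]*255 only produces a sensible black/white mask in the two-class case here. A small sketch for saving a palette-colored mask instead, reusing the palette defined above:

seg = result[0].astype(np.uint8)
color = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, rgb in enumerate(palette):
    color[seg == label] = rgb
cv2.imwrite('./output/re4_color.png', color[..., ::-1])  # palette is RGB; OpenCV expects BGR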

8. Training and prediction results

Since I haven't done any large-scale training on AutoDL yet, here are the prediction results from a quick run on my own machine~

[Image: original input]

[Image: prediction result]

I'll keep updating this after running large-scale training...

Pitfall 1: 'libpng warning: iCCP: known incorrect sRGB profile'

Newer versions of libpng check ICC profiles more strictly and emit this warning. It can safely be ignored; I did nothing about it here. Reading the images with another library (e.g., skimage) also avoids it. If you'd rather fix it at the source, see the sketch below.
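One way to silence the warning for good (a sketch, assuming your masks sit under test_yinzhang/SegmentationClass as above) is to re-save the PNGs without their ICC profile:

import glob
import os
from PIL import Image

for path in glob.glob(os.path.join('test_yinzhang', 'SegmentationClass', '*.png')):
    img = Image.open(path)
    img.save(path, icc_profile=None)  # drop the iCCP chunk that libpng warns about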

Pitfall 2: 'ValueError: expected 4D input (got 3D input)'

This one bothered me for a long time.

Some searching revealed the cause: an incorrect BatchNorm setting. The quick fix is simply not to add a 'norm_cfg' override to the model's backbone, as the snippet below illustrates.
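Concretely, the MixVisionTransformer backbone applies LayerNorm to 3-D token tensors internally, while a BatchNorm override expects 4-D feature maps, which is where the shape error comes from. Only the decode head should get the BN override:

cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.decode_head.norm_cfg = cfg.norm_cfg  # fine: the head operates on 4-D feature maps
# cfg.model.backbone.norm_cfg = cfg.norm_cfg   # don't: the transformer backbone expects LayerNorm on 3-D tokens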

To be continued...
