当前位置:   article > 正文

MMSegmention系列之四(自定义数据集与自定义数据增强管道)_samples_per_gpu

samples_per_gpu

1、自定义数据集

1、数据配置

data在 config 文件中是数据配置的变量,用于定义数据集和数据加载器中使用的参数。
下面是一个数据配置的例子:

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

1、train, val和test:通过使用构建和注册机制来构建数据集实例config,用于模型训练、验证和测试。build and registry
2、samples_per_gpu:模型训练时每个批次和每个gpu加载多少样本,训练的batch_size等于samples_per_gpu乘以gpu数量,例如使用8个gpu进行分布式数据并行训练,samples_per_gpu=2时,batch_size为8*2=16。如果您想定义batch_size用于测试和验证,请使用test_dataloaser和val_dataloader,并使用mmseg >=0.24.1。
3、workers_per_gpu:每个gpu用于数据加载的子进程数。0表示数据将在主进程中加载。
注意:samples_per_gpu仅用于模型训练,当模型测试和验证时,samples_per_gpu的默认设置为1 mmseg(暂不支持批推理)。

2、config.md _

Config类用于操作配置和配置文件。它支持从多种文件格式加载配置,包括python、json和yaml。它提供了类似 dict 的 API 来获取和设置值。

这是配置文件的示例test.py。
加载和使用配置

>>> cfg = Config.fromfile('test.py')
>>> print(cfg)
>>> dict(a=1,
...      b=dict(b1=[0, 1, 2], b2=None),
...      c=(1, 2),
...      d='string')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

对于所有格式配置,都支持一些预定义的变量。它会将变量转换{{ var }}为其实际值。

目前,它支持四个预定义变量:

{{ fileDirname }}- 当前打开文件的目录名,例如 /home/your-username/your-project/folder

{{ fileBasename }}- 当前打开文件的基本名称,例如 file.ext

{{ fileBasenameNoExtension }}- 当前打开的文件的基本名称,没有文件扩展名,例如 file

{{ fileExtname }}- 当前打开文件的扩展名,例如 .ext

这些变量名称来自VS Code。

这是一个带有预定义变量的配置示例。

config_a.py

a = 1
b = './work_dir/{{ fileBasenameNoExtension }}'
c = '{{ fileExtname }}'
>>> cfg = Config.fromfile('./config_a.py')
>>> print(cfg)
>>> dict(a=1,
...      b='./work_dir/config_a',
...      c='.py')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

对于所有格式配置,都支持继承。要重用其他配置文件中的字段,请指定_base_=‘./config_a.py’configs 列表_base_=[’./config_a.py’, ‘./config_b.py’]。以下是配置继承的 4 个示例

a = 1
b = dict(b1=[0, 1, 2], b2=None)
  • 1
  • 2

1、从基本配置继承,没有重叠的键

config_b.py

_base_ = './config_a.py'
c = (1, 2)
d = 'string'
>>> cfg = Config.fromfile('./config_b.py')
>>> print(cfg)
>>> dict(a=1,
...      b=dict(b1=[0, 1, 2], b2=None),
...      c=(1, 2),
...      d='string')
中的新字段config_b.py与中的旧字段相结合config_a.py
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

2、从具有重叠键的基本配置继承

config_c.py

_base_ = './config_a.py'
b = dict(b2=1)
c = (1, 2)
>>> cfg = Config.fromfile('./config_c.py')
>>> print(cfg)
>>> dict(a=1,
...      b=dict(b1=[0, 1, 2], b2=1),
...      c=(1, 2))
b.b2=Noneinconfig_a替换为b.b2=1in config_c.py。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

3、从具有忽略字段的基本配置继承

config_d.py

_base_ = './config_a.py'
b = dict(_delete_=True, b2=None, b3=0.1)
c = (1, 2)
>>> cfg = Config.fromfile('./config_d.py')
>>> print(cfg)
>>> dict(a=1,
...      b=dict(b2=None, b3=0.1),
...      c=(1, 2))
您也可以设置_delete_=True忽略基本配置中的某些字段。所有旧钥匙b1, b2, b3都b被新钥匙取代b2, b3
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

4、从多个基本配置继承(基本配置不应包含相同的键)

config_e.py

c = (1, 2)
d = 'string'
config_f.py

_base_ = ['./config_a.py', './config_e.py']
>>> cfg = Config.fromfile('./config_f.py')
>>> print(cfg)
>>> dict(a=1,
...      b=dict(b1=[0, 1, 2], b2=None),
...      c=(1, 2),
...      d='string')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

5、从基础引用变量

您可以使用以下语法引用 base 中定义的变量。

base.py

item1 = 'a'
item2 = dict(item3 = 'b')
config_g.py

_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

6、在配置中添加弃用信息

可以在配置文件中添加弃用信息,这将UserWarning在加载此配置文件时触发。

deprecated_cfg.py

_base_ = 'expected_cfg.py'

_deprecation_ = dict(
    expected = 'expected_cfg.py',  # optional to show expected config path in the warning information
    reference = 'url to related PR'  # optional to show reference link in the warning information
)
>>> cfg = Config.fromfile('./deprecated_cfg.py')

UserWarning: The config file deprecated_cfg.py will be deprecated in the future. Please use expected_cfg.py instead. More information can be
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

3、build and registry

create hooks, runners, models, and datasets, through configs
要通过Registry管理代码库中的模块,有以下三个步骤。

  1. 创建一个构建方法(可选的,在大多数情况下你可以使用默认方法)。
  2. 创建一个registry.
  3. 3。使用此注册表管理模块。
    Registry的build_func参数用于自定义如何实例化类实例或如何调用函数来获得结果,这里实现的默认参数是build_from_cfg。
mmcv.utils.build_from_cfg(cfg: Dict, registry: mmcv.utils.registry.Registry, default_args: Optional[Dict] = None) → Any
  • 1

当它是一个类配置时,从配置字典构建一个模块,或者当它是一个函数配置时,从配置字典调用一个函数。

1、exmple

>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>>     pass
>>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
>>> # Returns an instantiated object
>>> @MODELS.register_module()
>>> def resnet50():
>>>     pass
>>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
>>> # Return a result of the calling function
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

2、 Parameters

Parameters
	cfg (dict) – Config dict. It should at least contain the key “type”.
	
	registry (Registry) – The registry to search the type from.

	default_args (dict, optional) – Default initialization arguments.

Returns
	The constructed object.

Return type
	object
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

3、A Simple Example

这里我们展示了一个使用注册表来管理包中的模块的简单示例。您可以在OpenMMLab项目中找到更多实际的例子。假设我们希望实现一系列Dataset Converter,用于将不同格式的数据转换为预期的数据格式。我们创建了一个名为converters的目录作为包。在包中,我们首先创建一个文件来实现生成器,名为converters/builder.py,如下所示

from mmcv.utils import Registry
# create a registry for converters
CONVERTERS = Registry('converters')
  • 1
  • 2
  • 3

然后我们可以在包中实现不同的类或函数转换器。例如,在Converter1 .py中实现Converter1,在converter2.py中实现converter2。


from .builder import CONVERTERS

# use the registry to manage the module
@CONVERTERS.register_module()
class Converter1(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

4、下面是一个特定数据加载器的示例:

注意:在vo.24.1之前,除train、val test、samples_per_gpu和workers_per_gpu外,data中的其他键必须是pytorch中dataloader的输入关键字参数,用于模型训练、验证和测试的dataloader具有相同的输入参数。在vo24.1中,mmseg支持使用train_dataloader、test_dataloaser和val_dataloader来指定不同的关键字参数,并且仍然支持总体参数定义,但特定的数据loader设置具有更高的优先级。

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    shuffle=True,
    train=dict(type='xxx', ...),
    val=dict(type='xxx', ...),
    test=dict(type='xxx', ...),
    # Use different batch size during validation and testing.
    val_dataloader=dict(samples_per_gpu=1, workers_per_gpu=4, shuffle=False),
    test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=4, shuffle=False))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

假设只使用一个gpu进行模型训练和测试,因为整体参数定义的优先级较低,用于训练的batch_size为4,数据集将进行洗选,用于测试和验证的batch_size为1,数据集将不进行洗选

为了使数据配置更清晰,我们建议使用特定的数据加载器设置,而不是v0.24.1之后的整体数据加载器设置,就像:

data = dict(
    train=dict(type='xxx', ...),
    val=dict(type='xxx', ...),
    test=dict(type='xxx', ...),
    # Use specific dataloader setting
    train_dataloader=dict(samples_per_gpu=4, workers_per_gpu=4, shuffle=True),
    val_dataloader=dict(samples_per_gpu=1, workers_per_gpu=4, shuffle=False),
    test_dataloader=dict(samples_per_gpu=1, workers_per_gpu=4, shuffle=False))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

注意:在模型训练中,mmseg for dataloader的脚本默认值为shuffle=True, drop_last=True,在模型验证和测试中,默认值为shuffle=False, drop_last=False

5、通过重组数据自定义数据集

最简单的方法是将数据集转换为文件夹。文件结构示例如下所示。
├── data
│ ├── my_dataset
│ │ ├── img_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{img_suffix}
│ │ │ │ ├── yyy{img_suffix}
│ │ │ │ ├── zzz{img_suffix}
│ │ │ ├── val
│ │ ├── ann_dir
│ │ │ ├── train
│ │ │ │ ├── xxx{seg_map_suffix}
│ │ │ │ ├── yyy{seg_map_suffix}
│ │ │ │ ├── zzz{seg_map_suffix}
│ │ │ ├── val

6、通过混合数据集定制数据集

MMSegmentation也支持混合数据集进行训练。目前它支持连接、重复和多图像混合数据集

1、重复的数据集

我们使用RepeatDataset作为包装器来重复数据集。例如,假设原始数据集是Dataset_A,为了重复它,配置如下所示

dataset_A_train = dict(
        type='RepeatDataset',
        times=N,
        dataset=dict(  # This is the original config of Dataset_A
            type='Dataset_A',
            ...
            pipeline=train_pipeline
        )
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2、连接数据集

有两种方法连接数据集。如果您想要连接的数据集属于具有不同注释文件的同一类型,您可以像下面这样连接数据集配置。

1、You may concatenate two ann_dir.
dataset_A_train = dict(
    type='Dataset_A',
    img_dir = 'img_dir',
    ann_dir = ['anno_dir_1', 'anno_dir_2'],
    pipeline=train_pipeline
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
2、You may concatenate two split.
dataset_A_train = dict(
    type='Dataset_A',
    img_dir = 'img_dir',
    ann_dir = 'anno_dir',
    split = ['split_1.txt', 'split_2.txt'],
    pipeline=train_pipeline
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
3、You may concatenate two ann_dir and split simultaneously.
dataset_A_train = dict(
    type='Dataset_A',
    img_dir = 'img_dir',
    ann_dir = ['anno_dir_1', 'anno_dir_2'],
    split = ['split_1.txt', 'split_2.txt'],
    pipeline=train_pipeline
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在本例中,ann_dir_1和ann_dir_2对应split_1.txt和split_2.txt。
2. 如果要连接的数据集不同,可以像下面这样连接数据集配置

dataset_A_train = dict()
dataset_B_train = dict()

data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train = [
        dataset_A_train,
        dataset_B_train
    ],
    val = dataset_A_val,
    test = dataset_A_test
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

下面是一个更复杂的例子,它分别重复了Dataset_A和Dataset_B N次和M次,然后将重复的数据集连接起来。

dataset_A_train = dict(
    type='RepeatDataset',
    times=N,
    dataset=dict(
        type='Dataset_A',
        ...
        pipeline=train_pipeline
    )
)
dataset_A_val = dict(
    ...
    pipeline=test_pipeline
)
dataset_A_test = dict(
    ...
    pipeline=test_pipeline
)
dataset_B_train = dict(
    type='RepeatDataset',
    times=M,
    dataset=dict(
        type='Dataset_B',
        ...
        pipeline=train_pipeline
    )
)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train = [
        dataset_A_train,
        dataset_B_train
    ],
    val = dataset_A_val,
    test = dataset_A_test
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

3、多映像组合数据集(Multi-image Mix Dataset)

我们使用MultiImageMixDataset作为包装器来混合来自多个数据集的图像。MultiImageMixDataset可用于多个图像混合数据增强,如马赛克和混合。一个使用MultiImageMixDataset与马赛克数据增强的例子:

train_pipeline = [
    dict(type='RandomMosaic', prob=1),
    dict(type='Resize', img_scale=(1024, 512), keep_ratio=True),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

train_dataset = dict(
    type='MultiImageMixDataset',
    dataset=dict(
        classes=classes,
        palette=palette,
        type=dataset_type,
        reduce_zero_label=False,
        img_dir=data_root + "images/train",
        ann_dir=data_root + "annotations/train",
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
        ]
    ),
    pipeline=train_pipeline
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

2、定制数据管道

1、数据管道设计

遵循典型的约定,我们使用Dataset和DataLoader对多个worker进行数据加载。Dataset返回与模型的forward方法的参数相对应的数据项字典。由于语义分割中的数据可能大小不同,我们在MMCV中引入了一个新的DataContainer类型来帮助收集和分发不同大小的数据。详见这里data_container.py
对数据准备管道和数据集进行分解。通常,数据集定义如何处理注释,数据管道定义准备数据字典的所有步骤。管道由一系列操作组成。每个操作以一个字典作为输入,并输出一个字典用于下一个转换。操作分为数据加载、预处理、格式化和测试时间扩展。下面是PSPNet的管道示例。

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

对于每个操作,我们列出添加/更新/删除的相关dict字段

1、Data loading

LoadImageFromFile

add: img, img_shape, ori_shape

LoadAnnotations

add: gt_semantic_seg, seg_fields

2、Pre-processing

Resize

add: scale, scale_idx, pad_shape, scale_factor, keep_ratio

update: img, img_shape, *seg_fields

RandomFlip

add: flip

update: img, *seg_fields

Pad

add: pad_fixed_size, pad_size_divisor

update: img, pad_shape, *seg_fields

RandomCrop

update: img, pad_shape, *seg_fields

Normalize

add: img_norm_cfg

update: img

SegRescale

update: gt_semantic_seg

PhotoMetricDistortion

update: img

3、Formatting

ToTensor

update: specified by keys.

ImageToTensor

update: specified by keys.

Transpose

update: specified by keys.

ToDataContainer

update: specified by fields.

DefaultFormatBundle

update: img, gt_semantic_seg

Collect

add: img_meta (the keys of img_meta is specified by meta_keys)

remove: all other keys except for those specified by keys

4、Test time augmentation

2、Extend and use custom pipelines(拓展使用自定义数据增强)

mmsegmention中的transform.py包括class ResizeToMultiple、class Resize、class RandomFlip、class RandomRotate、class AdjustGamma、class PhotoMetricDistortion、 RandomCutOut、RandomMosaic等方法
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
mmdetection中的transform.py包括
在这里插入图片描述
在这里插入图片描述

先transform.py导入额外需要的库
try:
    from imagecorruptions import corrupt
except ImportError:
    corrupt = None

try:
    import albumentations
    from albumentations import Compose
except ImportError:
    albumentations = None
    Compose = None
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

1、Write a new pipeline in any file, e.g., my_pipeline.py. It takes a dict as input and return a dict.

  1. 在任何文件中写入一个新的管道,例如my_pipeline.py。它接受一个字典作为输入并返回一个字典。
from mmseg.datasets import PIPELINES

@PIPELINES.register_module()
class MyTransform:
      def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        assert 0 <= prob and prob <= 1
        assert isinstance(img_scale, tuple)
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def __call__(self, results):
        results['dummy'] = True
        return results
    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.pad_val})'
        return repr_str
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

新增class Albu(object)这个类

@PIPELINES.register_module()
class Albu(object):
    """Albumentation augmentation. Adds custom transformations from
    Albumentations library. Please, visit	
    `https://albumentations.readthedocs.io` to get more information. An example
    of ``transforms`` is as followed:
    .. code-block::
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]
    Args:
        transforms (list[dict]): A list of albu transformations
        keymap (dict): Contains {'input key':'albumentation-style key'}
    """
    def __init__(self, transforms, keymap=None, update_pad_shape=False):
        # Args will be modified later, copying it will be safer
        transforms = copy.deepcopy(transforms)
        if keymap is not None:
            keymap = copy.deepcopy(keymap)
        self.transforms = transforms
        self.filter_lost_elements = False
        self.update_pad_shape = update_pad_shape
        self.aug = Compose([self.albu_builder(t) for t in self.transforms])
        if not keymap:
            self.keymap_to_albu = {'img': 'image', 'gt_semantic_seg': 'mask'}
        else:
            self.keymap_to_albu = keymap
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    def albu_builder(self, cfg):
        """Import a module from albumentations.
        It inherits some of :func:`build_from_cfg` logic.
        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
        Returns:
            obj: The constructed object.
        """
        assert isinstance(cfg, dict) and 'type' in cfg
        args = cfg.copy()
        obj_type = args.pop('type')
        if mmcv.is_str(obj_type):
            obj_cls = getattr(albumentations, obj_type)
        else:
            raise TypeError(f'type must be str, but got {type(obj_type)}')
        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]
        return obj_cls(**args)
        
    @staticmethod
    def mapper(d, keymap):

        """Dictionary mapper.
        Renames keys according to keymap provided.
        Args:	
            d (dict): old dict	
            keymap (dict): {'old_key':'new_key'}
        Returns:
            dict: new dict.
        """
        updated_dict = {}
        for k, v in zip(d.keys(), d.values()):
            new_k = keymap.get(k, k)
            updated_dict[new_k] = d[k]
        return updated_dict

    def __call__(self, results):	
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)
        results = self.aug(**results)
        # back to the original format
        results = self.mapper(results, self.keymap_back)
        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101

2、在__init__.py中导入这个新类数据增强函数

from .compose import Compose
from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
                        Transpose, to_tensor)
from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
                         PhotoMetricDistortion, RandomCrop, RandomFlip,
                         RandomRotate, Rerange, Resize, RGB2Gray, SegRescale, Albu, Grid)#新增的Albu数据增强

__all__ = [
    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
    'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
    'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
    'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
    'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',  'Albu', 'Grid'
]#新增的Albu数据增强
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
from .my_pipeline import MyTransform
from .poly_transforms import (CorrectRBBox, PolyResize, PolyRandomFlip, PolyRandomRotate,
                              Poly_Mosaic_RandomPerspective, MixUp, PolyImgPlot)
  • 1
  • 2
  • 3

3、Use it in config files

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='MyTransform'),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
work_dir = 'work_dirs/swin_base_patch4_window12_dotav2/'
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)

model = dict(
    type='OrientedRepPointsDetector',
    pretrained='/checkpoints_torch1.4/swin_base_patch4_window12_384_22kto1k.pth',
    backbone=dict(
        type='SwinTransformer',
        embed_dim=128,         # tiny 96    small 96       base 128      large 192
        depths=[2, 2, 18, 2],  # tiny 2262  small 22 18 2  base 22 18 2  large 22 18 2
        num_heads=[4, 8, 16, 32],
        window_size=12,        # tiny 7     samall 7
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.3,  # 训练时间小于1×最好置为0.1
        ape=False,    # 是否需要对嵌入向量进行相对位置编码
        patch_norm=True,
        out_indices=(1, 2, 3),  # strides: [4, 8, 16, 32]  channel:[128, 256, 512, 1024]
        use_checkpoint=True
    ),
    neck=
        dict(
        type='FPN',
        in_channels=[256, 512, 1024],
        out_channels=256,
        #start_level=1,
        add_extra_convs=True,
        num_outs=5,
        norm_cfg=norm_cfg
        ),
    bbox_head=dict(
        type='OrientedRepPointsHead',
        num_classes=16,
        in_channels=256,
        feat_channels=256,
        point_feat_channels=256,
        stacked_convs=3,
        num_points=9,
        gradient_mul=0.3,
        point_strides=[8, 16, 32, 64, 128],
        point_base_scale=2,
        norm_cfg=norm_cfg,
        loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0),
        loss_rbox_init=dict(type='GIoULoss', loss_weight=0.375),
        loss_rbox_refine=dict(type='GIoULoss', loss_weight=1.0),
        loss_spatial_init=dict(type='SpatialBorderLoss', loss_weight=0.05),
        loss_spatial_refine=dict(type='SpatialBorderLoss', loss_weight=0.1),
        top_ratio=0.4,))
# training and testing settings
train_cfg = dict(
    init=dict(
        assigner=dict(type='PointAssigner', scale=4, pos_num=1),  # 每个gtbox仅选一个正样本
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    refine=dict(
        assigner=dict(
            type='MaxIoUAssigner', #pre-assign to select more samples for samples selection
            pos_iou_thr=0.1,
            neg_iou_thr=0.1,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        allowed_border=-1,
        pos_weight=-1,
        debug=False))

test_cfg = dict(
    nms_pre=2000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='rnms', iou_thr=0.4),
    max_per_img=2000)

# dataset settings
dataset_type = 'DotaDatasetv2'
data_root = '/media/test/4d846cae-2315-4928-8d1b-ca6d3a61a3c6/DOTA/DOTAv2.0/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='CorrectRBBox', correct_rbbox=True, refine_rbbox=True),
    dict(type='PolyResize',
        img_scale=[(1333, 768), (1333, 1280)],  # 建议根据显存来确定长边的值,在线多尺度缩放幅度控制在25%左右为佳
        keep_ratio=True,
        multiscale_mode='range',
        clamp_rbbox=False),
    dict(type='PolyRandomFlip', flip_ratio=0.5),
   # dict(type='HSVAugment', hgain=0.015, sgain=0.7, vgain=0.4),
    dict(type='PolyRandomRotate', rotate_ratio=0.5, angles_range=180, auto_bound=False),
    dict(type='Pad', size_divisor=32),
   # dict(type='Poly_Mosaic_RandomPerspective', mosaic_ratio=0, ifcrop=True, degrees=0, translate=0.1, scale=0.2, shear=0, perspective=0.0),
   # dict(type='MixUp', mixup_ratio=0.5),
    dict(type='PolyImgPlot', img_save_path=work_dir, save_img_num=16, class_num=18, thickness=2),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        flip=False,
        transforms=[
            dict(type='PolyResize', keep_ratio=True),
            dict(type='PolyRandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'trainval_split_1024/Train_dotav2_trainval1024_poly.json',
        img_prefix=data_root + 'trainval_split_1024/images/',
        pipeline=train_pipeline,
        Mosaic4=False,
        Mosaic9=False,
        Mixup=False),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'trainval_split_1024/Train_dotav2_trainval1024_poly.json',
        img_prefix=data_root + 'trainval_split_1024/images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'test-dev_split/Test_datav2_test1024.json',
        img_prefix=data_root + 'test-dev_split/images/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

# optimizer
optimizer = dict(type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))

# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3, # 
    step=[27, 33])

runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
total_epochs = 36

checkpoint_config = dict(interval=2)
# yapf:disable
log_config = dict(
    interval=20,          # 迭代n次时打印一次
    hooks=[
        dict(type='TextLoggerHook')
    ])
# yapf:enable
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None#'work_dirs/swin_tiny_patch4_window7_gradclip/latest.pth'
workflow = [('train', 1)]

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# optimizer_config = dict(
#  #   type="DistOptimizerHook",
#  #   update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#  #   use_fp16=True,
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/347440
推荐阅读
相关标签
  

闽ICP备14008679号