当前位置:   article > 正文

OpenMMLab-AI实战营第二期——3-2. MMPretrain代码实战_mmpretrain 中文字符

mmpretrain 中文字符


这里的重点是要掌握四种配置文件,整体可以参考:

1. 配置文件修改

mmpretrain/configs/resnet/resnet18_8xb32_in1k.py文件为例,这个文件其实是把其他四个配置文件汇总到一起

_base_ = [
    '../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
  • 1
  • 2
  • 3
  • 4

即关于mmpretrain的配置可以分为四部分,下面会分别说明:

  1. 模型类配置,可以参考:学习配置文件
  2. 数据类配置,可以参考:
  3. 训练schedule配置,可以参考:自定义训练优化策略
  4. 运行时配置,可以参考:自定义运行参数

其实这些配置汇总到一个文件,就和之前在mmpose和mmdetection中讲的配置,差不多,都是使用字典来组织属性和值。

1.1 模型设置

基本模板:

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,  # 分类数量1000
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ),
    # 如果要预训练,可以加上下面的参数:
    init_cfg=dict(
        type='Pretrained',
        checkpoint='https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth',)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

另外,之前在mmdetection中进行预训练使用的是load_from参数,根据load_from 和 init_cfg 之间的关系是什么?

load_from:

  • 如果resume=False,只导入模型权重,主要用于加载训练过的模型;
  • 如果 resume=True ,加载所有的模型权重、优化器状态和其他训练信息,主要用于恢复中断的训练。

init_cfg:
也可以指定init=dict(type=“Pretrained”, checkpoint=xxx)来加载权重, 表示在模型权重初始化时加载权重,通常在训练的开始阶段执行。 主要用于微调预训练模型,你可以在骨干网络的配置中配置它,还可以使用 prefix 字段来只加载对应的权重

即这两个参数都可以用来指定预训练模型

1.2 数据设置

1.2.1 基本配置模板

基本模板:

# dataset settings
dataset_type = 'ImageNet'
# 数据预处理,归一化部分,这里和模型有重叠,都给了`num_classes`这个属性,需要注意保持一致。
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)
# 训练数据、测试数据的数据增强pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ResizeEdge', scale=256, edge='short'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs'),
]
# 数据划分batch,以及数据集位置,读取问题
train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
# 评估
val_evaluator = dict(type='Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

1.2.2 自定义数据集

如果不想对数据集进行训练集和测试集的划分:

  • 比如:数据已经传到google云上了,懒得再改
  • 或者之后可能会修改数据分布,可能会有新的数据进来等

那么频繁对图像文件夹进行改动,就不是很方便,此时可以考虑添加一个标注文件。

以分类数据集为例,参考:准备数据集-标注文件方式

1.文件结构

|--fruit30_train
	|--哈密瓜
		|--1.jpg
		|--2.jpg
	|--苦瓜
		|--1.jpg
		|--2.jpg
	|--meta  # 存放标注文件
		|--train.txt
		|--val.txt
		|--test.txt
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

2.标注文件内容
比如: meta/train.txt文件的内容

苦瓜/63.jpg 0 # 这里的路径是相对于`数据集的根目录`而言的
苦瓜/189.jpg 0
苦瓜/77.jpg 0
苦瓜/200.jpg 0
苦瓜/51.jpeg 0
苦瓜/89.jpg 0
苦瓜/188.jpg 0
苦瓜/60.jpg 0
苦瓜/48.jpg 0
...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

3.配置文件

dataset=dict(
        type='CustomDataset', # mmpretrain支持两种自定义格式,一种是子目录,一种就是自己定义标签
        data_root='/content/fruit30_train', # 数据集的根目录,图像和标注的根目录
        ann_file='meta/train.txt', # 标注文件,相对于根目录
        classes=[
            '苦瓜', '圣女果', '芒果', '菠萝', '石榴', '山竹', '苹果-红', '猕猴桃', '哈密瓜', '柚子',
            '脐橙', '车厘子', '杨梅', '草莓', '椰子', '西瓜', '桂圆', '葡萄-红', '苹果-青', '荔枝',
            '香蕉', '柠檬', '胡萝卜', '梨', '葡萄-白', '砂糖橘', '黄瓜', '榴莲', '西红柿', '火龙果'
        ],
        # 分类的类别,注意是 0开始,注意序号和对应类别名称对应
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', scale=224),
            dict(type='RandomFlip', prob=0.5, direction='horizontal'),
            dict(type='PackInputs')
        ]),
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

另外,更正规的格式是:

----数据集文件结构------
|--fruit30_train
	|--data #分类图像文件夹的上级目录
		|--哈密瓜
			|--1.jpg
			|--2.jpg
		|--苦瓜
			|--1.jpg
			|--2.jpg
	|--meta  # 存放标注文件
		|--train.txt
		|--val.txt
		|--test.txt
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

则对应的meta里的标注内容:

/data/苦瓜/63.jpg 0 # 这里的路径是相对于`数据集的根目录`而言的
/data/苦瓜/189.jpg 0
/data/苦瓜/77.jpg 0
/data/苦瓜/200.jpg 0
/data/苦瓜/51.jpeg 0
...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

相应配置文件的内容:

dataset=dict(
        type='CustomDataset', # mmpretrain支持两种自定义格式,一种是子目录,一种就是自己定义标签
        data_root='/content/fruit30_train', # 数据集的根目录,图像和标注的根目录
        ann_file='meta/train.txt', # 标注文件,相对于根目录
        data_prefix='data/',    # 标注文件内容里 文件路径的前缀,相对于 `data_root`
        classes=[
            '苦瓜', '圣女果', '芒果', '菠萝', '石榴', '山竹', '苹果-红', '猕猴桃', '哈密瓜', '柚子',
            '脐橙', '车厘子', '杨梅', '草莓', '椰子', '西瓜', '桂圆', '葡萄-红', '苹果-青', '荔枝',
            '香蕉', '柠檬', '胡萝卜', '梨', '葡萄-白', '砂糖橘', '黄瓜', '榴莲', '西红柿', '火龙果'
        ])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

1.3 训练策略设置

基本模板

# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=256)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

注意:
在上面数据配置中,设置了batch_size=32,,而这里auto_scale_lr = dict(base_batch_size=256)是因为用了8卡去训练,如果batch_size是256,则学习率用lr=0.1,否则,根据实际的batch_size大小去等比例缩放学习率

train_dataloader = dict(
batch_size=32,
num_workers=5,
  • 1
  • 2
  • 3

其他有兴趣的可以看看:自定义训练优化策略

1.4 运行设置

# defaults to use registries in mmpretrain
default_scope = 'mmpretrain'

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),

    # print log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),

    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),

    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1),

    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type='DistSamplerSeedHook'),

    # validation results visualization, set True to enable it.
    visualization=dict(type='VisualizationHook', enable=False),
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

值得说明的是:

  • 可以在checkpoint中添加:checkpoint=dict(type='CheckpointHook', interval=1,max_keep_ckpts=5,save_best='auto'),来控制最多保存的权重数量,已经保存最优权重。
  • 其他的没有什么好说的,有兴趣的可以看看这个文档:自定义运行参数

2. 训练测试代码

大部分是命令行执行,很少是python脚本

2.1 mmpretrain环境安装

pip install openmim

git clone https://github.com/open-mmlab/mmpretrain.git
cd /mmpretrain
mim install -e ".[multimodal]"
  • 1
  • 2
  • 3
  • 4
  • 5

2.2 mmpretrain基本语法

from mmpretrain import get_model,list_models,inference_model

# 1.列举图像分类相关的resnet18的模型
list_models(task="Image Classification",pattern="resnet18")
>['resnet18_8xb16_cifar10', 'resnet18_8xb32_in1k']

# 2. 获取模型结构
model = get_model('resnet18_8xb16_cifar10')
type(model),type(model.backbone)
> (mmpretrain.models.classifiers.image.ImageClassifier,
 mmpretrain.models.backbones.resnet_cifar.ResNet_CIFAR)

# 3.默认get的model是没有权重的,此时进行推理会得到错误的结果
inference_model(model,"/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg",show=True)     

# 4. 指定要下载的权重名称,再去推理,就可以得到正确的结果
inference_model("blip-base_3rdparty_caption","/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg",show=True)
> {'pred_caption': 'a close up of a fruit on a white background'}

# 5. 运行时修改配置文件
from mmengine import Config
cfg = Config.fromfile('/content/mmpretrain/configs/resnet/resnet18_8xb32_in1k.py')
cfg.model
>{'type': 'ImageClassifier',
 'backbone': {'type': 'ResNet',
  'depth': 18,
  'num_stages': 4,
  'out_indices': (3,),
  'style': 'pytorch'},
 'neck': {'type': 'GlobalAveragePooling'},
 'head': {'type': 'LinearClsHead',
  'num_classes': 1000,
  'in_channels': 512,
  'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
  'topk': (1, 5)}}
cfg.model.head.num_classes=2
cfg.model
>{'type': 'ImageClassifier',
 'backbone': {'type': 'ResNet',
  'depth': 18,
  'num_stages': 4,
  'out_indices': (3,),
  'style': 'pytorch'},
 'neck': {'type': 'GlobalAveragePooling'},
 'head': {'type': 'LinearClsHead',
  'num_classes': 2,
  'in_channels': 512,
  'loss': {'type': 'CrossEntropyLoss', 'loss_weight': 1.0},
  'topk': (1, 5)}}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

2.3 训练、测试和推理命令

  1. 训练
    # min train mmpretrain 配置文件路径 --work-dir(没有等号) workdir路径
    mim train mmpretrain /content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py \
    --work-dir /content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits
    
    • 1
    • 2
    • 3
  2. 测试
    # min test mmpretrain 配置文件路径 --checkpoint 权重路径  --out 后续分析基于的pkl文件,指定一个保存路径
    mim test mmpretrain /content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py \
     --checkpoint /content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/best_accuracy_top1_epoch_9.pth\
     --out /content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/result.pkl
    
    • 1
    • 2
    • 3
    • 4
  3. 分析结果(基于测试时得到的result.pkl文件)
    # mim run mmpretrain analyze_results 配置文件路径 测试结果pkl文件路径 --out-dir 输出存放路径
    mim run mmpretrain analyze_results /content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py\
    /content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/result.pkl\
    --out-dir /content/drive/MyDrive/OpenMMLab/Exercise_2/outputs/fruit30
    
    • 1
    • 2
    • 3
    • 4
    输出的目录包含:failsuccess,分别存放推理错误和正确的图像
  4. 混淆矩阵查看
    # mim run mmpretrain confusion_matrix 配置文件路径 测试结果pkl文件路径 是否显示 图像上是否包含TP等的确切数值
    mim run mmpretrain confusion_matrix /content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py\
    /content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/result.pkl\
    --show --include-values
    
    • 1
    • 2
    • 3
    • 4
  5. 推理:
    from mmpretrain import ImageClassificationInferencer
    # ImageClassificationInferencer类传入配置文件路径和最优权重路径
    inferencer = ImageClassificationInferencer('/content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py',
    pretrained='/content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/best_accuracy_top1_epoch_9.pth')
    
    # 如果所在编程环境可以显示图像,则可以使用show=True
    # inferencer("/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg",show=True)
    
    # 如果所在编程环境无法显示图像,则可以使用show_dir指定保存输出结果目录
    inferencer("/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg",show_dir="/content/drive/MyDrive/OpenMMLab/Exercise_2/")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

3. 结果显示支持中文

写文件
根据ImageClassificationInferencer类代码进行定位:

  • mmpretrain/apis/image_classification.py文件中,类ImageClassificationInferencervisualize()方法调用了mmpretrain/visualization/visualizer.py文件中的visualize_cls()方法,这个方法中保存文件用了mmcv.imwrite(),根据mmcv/mmcv/image/io.py,mmcv写图像时其实还是调用了cv2
  • 所以结合其他部分的代码,推测如果要写图像的话,用的是cv2,所以不支持中文,只能借用pillow等其他库

例如:

from mmpretrain import ImageClassificationInferencer

inferencer = ImageClassificationInferencer('/content/drive/MyDrive/OpenMMLab/Exercise_2/resnet50_finetune_fruits.py',
 pretrained='/content/drive/MyDrive/OpenMMLab/workdir/resnet50_finetune_fruits/best_accuracy_top1_epoch_9.pth')

visual_img = inferencer("/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg",show_dir="/content/drive/MyDrive/OpenMMLab/Exercise_2/")


visual_img[0]
> {'pred_scores': array([9.8617036e-08, 4.6559649e-09, 4.7238853e-09, 6.1176209e-08,
        1.1758071e-06, 1.9768613e-05, 6.5803279e-06, 2.8776577e-09,
        1.9613218e-09, 4.2568220e-09, 5.1932891e-10, 1.3199590e-07,
        1.7486315e-07, 2.1758628e-05, 4.7297416e-07, 3.6769190e-08,
        1.7304191e-06, 6.3718785e-06, 4.1835184e-09, 9.9994004e-01,
        5.3369670e-10, 1.9838148e-11, 1.4331825e-06, 1.3142658e-08,
        1.2287903e-09, 1.3498119e-07, 3.3858927e-10, 1.2507003e-09,
        4.5590962e-09, 2.9508143e-09], dtype=float32),
 'pred_label': 19,
 'pred_score': 0.999940037727356,
 'pred_class': '荔枝'}

from PIL import ImageFont, ImageDraw, Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

img_path = "/content/drive/MyDrive/OpenMMLab/Exercise_2/litchi_test.jpeg"
fontpath = "/content/drive/MyDrive/OpenMMLab/SimHei.ttf" # <== 这里是宋体路径 
font = ImageFont.truetype(fontpath, 32)

img2=cv2.imread(img_path)
img = Image.fromarray(cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))

draw = ImageDraw.Draw(img)
pred_class = visual_img[0]['pred_class']
pred_label=visual_img[0]['pred_label']
pred_score=visual_img[0]['pred_score']
draw.text((50,50 ), 'label:{} socre:{:.2f} class:{}'.format(pred_label,pred_score,pred_class), (0,0,0), font=font)

img = np.asarray(img)
plt.imshow(img)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

在这里插入图片描述
参考:How to draw Chinese text on the image using cv2.putTextcorrectly? (Python+OpenCV)


读文件
根据以下内容:

  • SOURCE CODE FOR MMENGINE.VISUALIZATION.VISUALIZER

    font_properties (Union[FontProperties, List[FontProperties]], optional):The font properties of texts. FontProperties is a font_manager.FontProperties() object. If you want to draw Chinese texts, you need to prepare a font file that can show Chinese characters properly.For example: simhei.ttf, simsun.ttc, simkai.ttf and so on. Then set font_properties=matplotlib.font_manager.FontProperties(fname='path/to/font_file') font_properties can have the same length with texts or just single value. If font_properties is single value, all the texts will have the same font properties. Defaults to None. New in version 0.6.0.
    这一属性属于: draw_texts()函数,是mmocr里用的,mmpretrain虽然也调用了mmengine,但是没有用这个函数,

  • [Feature] add a new argument font_properties to set a specific font file in order to draw Chinese characters properly

    修改配置文件:
    vis_backends = [dict(type='LocalVisBackend')]
    visualizer = dict(
        type='TextRecogLocalVisualizer',
        name='visualizer',
        vis_backends=vis_backends,
        font_properties='xxx/SimHei.ttf')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
  • ImageClassificationInferencervisualize()方法调用了mmpretrain/visualization/visualizer.py文件中的visualize_cls()方法,在visualize_cls()中,有参数:
    """
    text_cfg (dict): Extra text setting, which accepts
      arguments of :meth:`mmengine.Visualizer.draw_texts`.
      Defaults to an empty dict.
    """
    # 但是visualize()在调用时,并没有传递一个mmengine.Visualizer.draw_texts参数,所以无法配置中文
    def visualize(self,
                  ori_inputs: List[InputType],
                  preds: List[DataSample],
                  show: bool = False,
                  wait_time: int = 0,
                  resize: Optional[int] = None,
                  rescale_factor: Optional[float] = None,
                  draw_score=True,
                  show_dir=None):
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
  • 所以如果想支持中文的话,需要mmpretrain的人去支持改一下代码,或者我自己改了再去尝试。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/457727
推荐阅读
相关标签
  

闽ICP备14008679号