当前位置:   article > 正文

MMdetection3.0 训练DETR问题分析_mmdetection 训练detr

mmdetection 训练detr

MMdetection3.0 训练DETR问题分析

针对在MMdetection3.0框架下训练DETR模型,验证集AP值一直为0.000的原因作出如下分析并得出结论。

条件:
1、NWPU-VHR-10数据集:共650张,训练:验证=611:39;
2、MMdetection3.0框架实验分析;
3、DETR原论文提供源代码实验分析;
4、已在代码中完成了数据类别定义(num_classes)等相关配置的修改。

分析:
1、在MMdetection3.0框架下,只是加载backbone的预训练权重,val上AP始终为0.0000.如下图所示:=》loss收敛较慢,val始终为0.0000.
在这里插入图片描述
在这里插入图片描述
2、在MMdetection3.0框架下,直接加载detr的完整预训练权重。如下图所示:=》存在警告(size mismatch for bbox_head.fc_cls.weight: copying a param with shape torch.Size([81, 256]) from checkpoint, the shape in current model is torch.Size([11, 256]).
size mismatch for bbox_head.fc_cls.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([11]).
),但训练测试指标还算正常。

=》警告原因:自定义数据集的类别是10+1,而MMdetection3.0提供的是coco数据集与训练权重80+1.
=》因此,需要修改预训练模型的全连接层输出(见下述第4点)。

在这里插入图片描述
在这里插入图片描述3、在MMdetection3.0框架下,直接加载修改后的detr的完整预训练权重训练测试结果见下图所示:=》警告消除,一切正常,并且修改证据权重类别后loss下降变快,val指标更好(不能说更好,只能说更正常)
在这里插入图片描述
在这里插入图片描述4、修改模型权重参数脚本
=》代码中的METAINFO不想修改 不修改也行。
=》主要是pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.weight’].resize_(11, 256)
pretrained_weights[‘state_dict’][‘bbox_head.fc_cls.bias’].resize_(11)

import torch
METAINFO = dict(
    CLASSES=(
        'airplane',
        'ship',
        'storage tank',
        'baseball diamond',
        'tennis court',
        'basketball court',
        'ground track field',
        'harbor',
        'bridge',
        'vehicle',
    ),
    PALETTE=[
        (
            120,
            120,
            120,
        ),
        (
            180,
            120,
            120,
        ),
        (
            6,
            230,
            230,
        ),
        (
            80,
            50,
            50,
        ),
        (
            4,
            200,
            3,
        ),
        (
            120,
            120,
            80,
        ),
        (
            140,
            140,
            140,
        ),
        (
            204,
            5,
            255,
        ),
        (
            230,
            230,
            230,
        ),
        (
            4,
            250,
            7,
        ),
    ])

pretrained_weights = torch.load('/home/admin1/pywork/data/weigh/resnet50-0676ba61.pth')
# 11 是指 数据类别 + 1
pretrained_weights['state_dict']['bbox_head.fc_cls.weight'].resize_(11, 256)
pretrained_weights['state_dict']['bbox_head.fc_cls.bias'].resize_(11)
pretrained_weights['meta']['experiment_name'] = 'detr_r50_8xb2-150e_coco_11'
pretrained_weights['meta']['dataset_meta'] = METAINFO
torch.save(pretrained_weights, "detr_r50_8xb2-150e_coco_%d.pth" % num_classes)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74

5、DETR原论文提供的源代码训练情况跟MMdetection3.0框架下的情况类似,都必须加载预训练模型,否则就是一直0.000000000000000.

总结分析:
1、NWPU-VHR-10数据量太小导致的问题(90%),等待进一步测试。
2、Transformer模型提出来的时候就已经说明很吃数据,所以没有足够的数据直接使用transformer训练往往效果不好,所以数据量不足的情况下,还是加载预训练权重吧。
3、backbone的权重在模型的比例其实很小,主要还是后面的编码、解码器,所以只加载backbone的权重也没什么用。

总之,数据、数据、数据要足够哇

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号