当前位置:   article > 正文

Windows下运行Deformable-DETR_windows训练deformable detr

windows训练deformable detr

环境

Win10 + CUDA10.1 + Python3.7 + Pytorch1.5

代码下载

官方代码

git clone https://github.com/fundamentalvision/Deformable-DETR
  • 1

官方权重-r50
推理demo

遇见的坑

由于官方代码是在Linux下开发,所以在Windows下运行需要做一定修改。

项目依赖安装按照官方说明即可。

1 项目编译

Compiling CUDA operators

cd ./models/ops
sh ./make.sh
  • 1
  • 2

在Windows可以通过GitBash来运行.sh文件,或者使用cmd在指定目录下(Deformable-DETR\models\ops)直接运行.sh文件中的指令

python setup.py build install
  • 1

2 开始训练

准备好数据集后,下面开始训练

在GitBash按照官方说明运行训练指令

GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/r50_deformable_detr.sh
  • 1

出现错误
Traceback (most recent call last):
File "./tools/launch.py", line 192, in <module>
main()
File "./tools/launch.py", line 181, in main
process = subprocess.Popen(cmd, env=current_env)
File "D:\Anaconda3\envs\deformable_detr\lib\subprocess.py", line 800, in __init__ restore_signals, start_new_session)
File "D:\Anaconda3\envs\deformable_detr\lib\subprocess.py", line 1207, in _execute_child startupinfo)
OSError: [WinError 193] %1 不是有效的 Win32 应用程序

原因应该是Windows中调用subprocess是对传入参数的处理和Linux有区别,在这尝试修改,没有成功。因为我自己的电脑只有一个GPU,不需要launch.py中关于多GPU和分布式训练的参数配置,所以索性跳过此步骤,直接使用main.py来进行训练。

3 加载权重

可以通过–resume来加载权重文件,此时需要注意设置–start_epoch和–epochs来控制训练次数,可直接修改默认值,也可以在运行main.py时传入。

parser.add_argument('--start_epoch', default=50, type=int, metavar='N', help='start epoch')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--resume', default='exp/r50/r50_deformable_detr-checkpoint.pth', help='resume from checkpoint')    
  • 1
  • 2
  • 3
python main.py --start_epoch 0 --epochs 100 --resume exp/r50/r50_deformable_detr-checkpoint.pth  
  • 1

这里还有个坑,由于是以–resume来加载权重文件,所以代码中首个epoch会对当前结果进行验证。如果是第一次运行,可能会出现验证时的分类和自己数据集分类不匹配的情况。

# main.py
# line 229
# if args.resume:
if args.resume and args.start_epoch is not 0:
# 可以通过epoch次数控制是否在首轮进行验证加载的权重文件
  • 1
  • 2
  • 3
  • 4
  • 5

4 自己的数据集

默认使用coco数据集的分类数,若要训练自己的数据集,需要修改几个参数:

根据自己的命名修改图像文件和标注文件的目录和地址:

# datasets/coco.py
# line 160-164
    mode = 'instances'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

分类数:

# modesl/deformable_detr.py
# line 445
# 将此处num_class修改为自己数据集的分类数
	# num_classes = 20 if args.dataset_file != 'coco' else 91
	num_classes = 20 if args.dataset_file != 'coco' else (num_class + 1)
  • 1
  • 2
  • 3
  • 4
  • 5

直接使用官方的权重文件会出现维度不匹配的情况,所以还需要根据自己的数据集修改权重文件,直接运行即可:

import torch
#加载官方提供的权重文件
pretrained_weights = torch.load('r50_deformable_detr-checkpoint.pth')

#修改相关权重
num_class = 5 # 自己数据集分类数
pretrained_weights['model']['class_embed.0.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.0.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.1.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.1.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.2.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.2.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.3.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.3.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.4.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.4.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.5.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.5.bias'].resize_(num_class+1)
pretrained_weights['model']['query_embed.weight'].resize_(50, 512) # 此处50对应生成queries的数量,根据main.py中--num_queries数量修改
torch.save(pretrained_weights, 'de_detr-r50_%d.pth'%num_class)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

5 运行时出错

Original Traceback (most recent call last):
File "D:\Anaconda3\envs\deformable_detr\Lib\site-packages\numpy\core\function_base.py", line 117, in linspace
num = operator.index(num)
TypeError: 'numpy.float64' object cannot be interpreted as an integer

可以修改该处代码为:

 # num = operator.index(num)
 num = operator.index(int(num))
  • 1
  • 2

目前采坑就这些
To Be Updated

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/337799
推荐阅读
相关标签
  

闽ICP备14008679号