当前位置:   article > 正文

使用DETR模型训练自己的数据集过程记录_detrx 训练自己的数据

detrx 训练自己的数据

记录使用DETR模型训练的过程(使用linux环境)。

官方源码和使用教程:

mirrors / facebookresearch / detr · GitCode

1 配置DETR环境

1.1 新建一个环境,并下载DETR源码文件

  1. conda create -n detr python=3.7
  2. # 进入环境
  3. conda activate detr
  4. cd %放置代码的路径%
  5. git clone https://github.com/facebookresearch/detr.git

1.2 安装pytorch

查看显卡支持的版本,安装相应版本的pytorch

  1. #查看版本
  2. nvidia-smi

Previous PyTorch Versions | PyTorch 在这里找到对应的版本安装,选择与CUDA版本相同及以下版本的cudatoolkit都可以

  1. #安装
  2. conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch

查看cuda是否可用

  1. python
  2. import torch
  3. torch.cuda.is_available()
  4. #显示True就说明可用

1.3 安装其他依赖库

然后按照官方的要求继续安装其他库包,如果pycocotools安装不上可以自己下载源码下来安装

源码在这里:GitHub - cocodataset/cocoapi: COCO API - Dataset @ http://cocodataset.org/

(更推荐用下面的pip安装方式,通过下载源码安装在训练评估时有可能会出现问题,若训练时出现NameError: name‘unicode’is not defined“ 这个问题可以看这里。)

  1. #(训练)安装pycocotools 和 scipy
  2. conda install cython scipy
  3. pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
  4. #(下面这个可选)安装panopticapi,应该是用于全景分割
  5. pip install git+https://github.com/cocodataset/panopticapi.git

1 预训练准备

2.1 预训练权重

下载官方提供的预训练模型,这里使用的是resnet50的backbone:

https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth

自己新建一个.py文件,在文件中写入:

  1. import torch
  2. pretrained_weights=torch.load('./detr-r50-e632da11.pth') # 注意修改成自己下载pth文件后放置的路径
  3. num_classes=2 # num_classes是类别数+1,我只有1类所以是2
  4. pretrained_weights["model"]["class_embed.weight"].resize_(num_classes+1,256)
  5. pretrained_weights["model"]["class_embed.bias"].resize_(num_classes+1)
  6. torch.save(pretrained_weights,"detr_r50_%d.pth"%num_classes)

注意num_classes数值要修改成自己的类别数+1(因为包括了背景),运行这个.py文件,会得到一个符合自己类别数的预训练模型。

2.2 自己的数据集

自己的数据集需要转化成coco数据集的格式,数据如下放置:

  1. +-- dataset
  2.     +-- train2017
  3.     +-- *.jpg
  4.     +-- val2017
  5.     +-- *.jpg
  6.     +-- annotations
  7.     +-- instances_train2017.json
  8.     +-- instances_val2017.json

关于数据如何转换成coco格式主要参考了这一篇:https://www.cnblogs.com/xiaochouk/p/15999696.html

2.3 修改代码

detr/models/detr.py 中第313行的num_classes修改为自己的类别数,因为coco数据集格式包含背景,所以在使用coco数据集时,类别数仍要记得+1。由于这里我只使用了实例分割,没有做全景分割,并不涉及到下面的coco_panoptic,因此,第317行的num_classes修不修改都可以。

上述工作完成后,就可以开始训练了。

3 训练

3.1 目标检测

使用detr/main.py文件进行训练,使用单张GPU训练时可以直接使用python命令。注意有一些参数需要指定,或者直接修改main.py中的默认值。例如:

  1. python /work/detr/main.py\
  2. --batch_size 4\
  3. --lr_drop 10\
  4. --output_dir /work/weights\
  5. --coco_path /work/dataset\
  6. --resume /detr/detr_r50_2.pth\
  7. --epochs 200

这里设置的batch_size 是一次训练4个样本,可以根据GPU的内存进行修改;lr_drop 设置为权重文件每10轮训练保存一次(在main.py中203行已经设置了每20轮固定保存一次权重,可以根据自己需求更改);output_dir是模型权重的保存路径;coco_path是数据集路径;resume是预训练文件路径(即2.1中制作的文件,后续训练按需求可以修改成自己训练好的其他权重文件);epoch是设置完整训练数据集200轮(根据自己的训练情况修改);还有其他参数都可以根据自己的情况设置。

当有多个GPU时,可以进行分布式训练,可以参考官方给出的命令,例如:

  1. python -m torch.distributed.launch --nproc_per_node=4\
  2. --use_env detr/main.py\
  3. --output_dir /work/weights\
  4. --coco_path /work/dataset\
  5. --resume /detr/detr_r50_2.pth\
  6. --epochs 200

这里nproc_per_node设置了利用4张GPU进行训练,其他参数设置与前面类似。训练完成后保存了评估文件、权重文件和训练日志文件log.txt。

至此,模型训练目标检测步骤完成。

3.2 目标分割

DETR模型除了可以进行目标检测,还提供了通过添加mask头的方式训练图像分割。官方也给出了相关示例。对于实例分割,需要利用3.1中训练好的目标检测的权重文件,冻结权重,训练分割头。

  1. python /home/work1/baiyw/DETR/code/detr/main.py\
  2. --masks\
  3. --coco_path /home/work1/baiyw/DETR/dataset/front_test\
  4. --batch_size 2\
  5. --epochs 40\
  6. --lr_drop 15\
  7. --frozen_weights /work/weights/checkpoint0199.pth\
  8. --output_dir /work/weights/segm_model

只需要在训练命令行添加"--mask"参数和"--frozen_weights"冻结权重路径(即选出一个合适的,3.1中训练好的权重文件),其他参数根据实际需求调整,分布式训练也类似。

因为我并没有使用全景分割的数据集格式,所以没有添加“--coco_panoptic_path”(全景分割数据集路径),“--dataset_file”数据格式也没有修改,使用的是默认的“coco”,有需要根据实际情况修改即可。

--coco_panoptic_path /path/to/coco_panoptic  --dataset_file coco_panoptic

在调试过程中参考的教程:

windows10复现DEtection TRansformers(DETR)并实现自己的数据集_detr复现_w1520039381的博客-CSDN博客

DETR训练自己的数据集_kyrie变相不减速的博客-CSDN博客

如何用DETR(detection transformer)训练自己的数据集_小小凡sir的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/337802
推荐阅读
相关标签
  

闽ICP备14008679号