当前位置:   article > 正文

caffe学习(10):交通标志目标检测训练整体流程_gtsdb数据集

gtsdb数据集

下载数据集

GTSDB数据集介绍:用于目标检测的交通标志数据集。共900张图,格式为.ppm。共4大类(禁止,危险,强制,其他)42小类。
官方下载地址
在这里插入图片描述
FullIJCNN2013.zip是主要文件。里面包含00000.pmm~00899.ppm共900张.ppm格式的图片。还有一份真值文件gt.txt,里面记录900张图片的标注信息。

数据集处理

由于我用的SSD模型是基于VOC数据集格式的,我要把GTSDB转换成VOC格式。

VOC格式

在这里插入图片描述
根目录下三个目录。Annotations/放xml文件,JPEGImages/放jpg图片。JPEGImages/Main/放train.txt等。

.pmm转.jpg

VOC格式需要图片格式是.jpg。转换结果放入VOC指定位置。ppm2jpg.py代码如下

# implementation: /home/jqy/data/FullIJCNN2013/all/*.ppm  ->  /home/jqy/data/gtsdb_voc/all/*.jpg
# /home/jqy/data/FullIJCNN2013/00/*.ppm  ->  /home/jqy/data/gtsdb_voc/00/*.jpg
# and el.

from PIL import Image
import os

input_train_path="/home/jqy/data/FullIJCNN2013/"
output_train_path = "/home/jqy/data/gtsdbjpg/"
# input_test_path="./full/"
# output_test_path = "./full/"
 
def batch_image(in_dir, out_dir):
    if not os.path.exists(out_dir):
        print(out_dir, 'is not existed.')
        os.mkdir(out_dir)
 
    if not os.path.exists(in_dir):
        print(in_dir, 'is not existed.')
        return -1
    
    directories = [d for d in os.listdir(in_dir) if os.path.isdir(os.path.join(in_dir, d))]
    for d in directories:
        label_directory = os.path.join(in_dir, d)
        new_directory = os.path.join(out_dir, d)
        out_folder = os.path.exists(out_dir+d)
        if not out_folder:
            os.mkdir(new_directory)
        file_names = [os.path.join(label_directory, f) for f in os.listdir(label_directory) if f.endswith(".ppm")]
        # file_names is every photo which is end with ".ppm"
 
        count = 0
        for files in file_names:
            file_path, extfilename = os.path.split(files)
            filename, extname = os.path.splitext(extfilename)
            out_file = filename + '.jpg'
            # print(filepath,',',filename, ',', out_file)
            im = Image.open(files)
            new_path = os.path.join(new_directory, out_file)
            print(count, ',', new_path)
            count = count + 1
            im.save(new_path)
 
 
 
if __name__ == '__main__':
    # batch_image(input_test_path, output_test_path)
    batch_image(input_train_path, output_train_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

gt.txt转xml

VOC标志信息记录在xml中。转换结果放入VOC指定位置。我是分成4+1=5大类,背景为0。gt2xml.py代码如下:

#!/usr/bin/env python
#-*- coding:utf-8 -*-
import sys
import os
import codecs
import cv2
reload(sys)
sys.setdefaultencoding('utf8')

root = r'/home/jqy/data/gtsdbjpg/all/xml/' # output xml path
fp = open('gt.txt') # path of gt.txt
#fp2 = open('train.txt', 'w') # path of train.txt
uavinfo = fp.readlines()

def get_label(label):
    prohibitory = [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 15, 16]  # (circular, white ground with red border)
    mandatory = [33, 34, 35, 36, 37, 38, 39, 40]  # (circular, blue ground)
    danger = [11, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]  # (triangular, white ground with red border)

    if label in prohibitory:
        new_label = 1
    elif label in mandatory:
        new_label = 2
    elif label in danger:
        new_label = 3
    else:
        new_label = 4

    return new_label
    
for i in range(len(uavinfo)):
    line = uavinfo[i]
    line = line.strip().split(';') 
    line[0] = "/home/jqy/data/gtsdbjpg/all/"+str(line[0]) # need to write image path
    img = cv2.imread(line[0])
    print line[0]
    sp = img.shape
    height = sp[0]
    width = sp[1]
    depth = sp[2]
    info1 = line[0].split('/')[-1]
    info2 = info1.split('.')[0]

    l_pos1 = line[1]
    l_pos2 = line[2]
    r_pos1 = line[3]
    r_pos2 = line[4]
    name = int(line[5])
    lable = get_label(name)
    if(os.path.exists(root + info2 + '.xml') == False):
        #fp2.writelines(info2 + '\n')
        with codecs.open(root + info2 + '.xml', 'w', 'utf-8') as xml:
            # xml.write('<?xml version="1.0" encoding="UTF-8"?>\n')
            xml.write('<annotation>\n')
            xml.write('\t<folder>' + 'GTSDB' + '</folder>\n')
            xml.write('\t<filename>' + info1 + '</filename>\n')
            xml.write('\t<source>\n')
            xml.write('\t\t<database>The GTSDB Database</database>\n')
            xml.write('\t</source>\n')
            xml.write('\t<size>\n')
            xml.write('\t\t<width>'+ str(width) + '</width>\n')
            xml.write('\t\t<height>'+ str(height) + '</height>\n')
            xml.write('\t\t<depth>' + str(depth) + '</depth>\n')
            xml.write('\t</size>\n')
            xml.write('\t\t<segmented>0</segmented>\n')
            xml.write('\t<object>\n')
            xml.write('\t\t<name>' + str(lable) + '</name>\n')
            xml.write('\t\t<pose>Unspecified</pose>\n')
            xml.write('\t\t<truncated>0</truncated>\n')
            xml.write('\t\t<difficult>0</difficult>\n')
            xml.write('\t\t<bndbox>\n')
            xml.write('\t\t\t<xmin>' + l_pos1 + '</xmin>\n')
            xml.write('\t\t\t<ymin>' + l_pos2 + '</ymin>\n')
            xml.write('\t\t\t<xmax>' + r_pos1 + '</xmax>\n')
            xml.write('\t\t\t<ymax>' + r_pos2 + '</ymax>\n')
            xml.write('\t\t</bndbox>\n')
            xml.write('\t</object>\n')
            xml.write('</annotation>')
    else:
        with codecs.open(root + info2 + '.xml', 'r', 'utf-8') as xml:
            lines = xml.readlines()
        with codecs.open(root + info2 + '.xml', 'w', 'utf-8') as xml:
            for line in lines:
                if "</annotation>" in line:
                    continue
                xml.write(line)
            xml.write('\t<object>\n')
            xml.write('\t\t<name>' + str(lable) + '</name>\n')
            xml.write('\t\t<pose>Unspecified</pose>\n')
            xml.write('\t\t<truncated>0</truncated>\n')
            xml.write('\t\t<difficult>0</difficult>\n')
            xml.write('\t\t<bndbox>\n')
            xml.write('\t\t\t<xmin>' + l_pos1 + '</xmin>\n')
            xml.write('\t\t\t<ymin>' + l_pos2 + '</ymin>\n')
            xml.write('\t\t\t<xmax>' + r_pos1 + '</xmax>\n')
            xml.write('\t\t\t<ymax>' + r_pos2 + '</ymax>\n')
            xml.write('\t\t</bndbox>\n')
            xml.write('\t</object>\n')
            xml.write('</annotation>')
for i in range(900):
    name = str(i).zfill(5)
    img_path = "/home/jqy/data/gtsdbjpg/all/" + name + '.jpg'
    print(img_path)
    img = cv2.imread(str(img_path))
    sp = img.shape
    height = sp[0]
    width = sp[1]
    depth = sp[2]
    info1 = img_path.split('/')[-1]
    info2 = info1.split('.')[0]
    if(os.path.exists(root + info2 + '.xml') == False):
        #fp2.writelines(info2 + '\n')
        with codecs.open(root + info2 + '.xml', 'w', 'utf-8') as xml:
            # xml.write('<?xml version="1.0" encoding="UTF-8"?>\n')
            xml.write('<annotation>\n')
            xml.write('\t<folder>' + 'GTSDB' + '</folder>\n')
            xml.write('\t<filename>' + info1 + '</filename>\n')
            xml.write('\t<source>\n')
            xml.write('\t\t<database>The GTSDB Database</database>\n')
            xml.write('\t</source>\n')
            xml.write('\t<size>\n')
            xml.write('\t\t<width>'+ str(width) + '</width>\n')
            xml.write('\t\t<height>'+ str(height) + '</height>\n')
            xml.write('\t\t<depth>' + str(depth) + '</depth>\n')
            xml.write('\t</size>\n')
            xml.write('\t\t<segmented>0</segmented>\n')
            xml.write('</annotation>')
#fp2.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128

训练前数据准备

分割训练集和测试集数据

所有原始图片数据,训练集和测试集需要按比例分割开来。train_test_separation.py代码如下:

import os 
import random 
trainval_percent = 0.75 
train_percent = 0.5 
xmlfilepath = '/home/jqy/data/gtsdb_voc/Annotations'
txtsavepath = '/home/jqy/data/gtsdb_voc/ImageSets/Main'
total_xml = os.listdir(xmlfilepath) 
num=len(total_xml) 
list=range(num) 
tv=int(num*trainval_percent) 
tr=int(tv*train_percent) 
trainval= random.sample(list,tv)
train=random.sample(trainval,tr)
ftrainval = open('/home/jqy/data/gtsdb_voc/ImageSets/Main/trainval.txt', 'w')
ftest = open('/home/jqy/data/gtsdb_voc/ImageSets/Main/test.txt', 'w')
ftrain = open('/home/jqy/data/gtsdb_voc/ImageSets/Main/train.txt', 'w')
fval = open('/home/jqy/data/gtsdb_voc/ImageSets/Main/val.txt', 'w')
for i in list: 
	name=total_xml[i][:-4]+'\n'
	if i in trainval:
		ftrainval.write(name)
		if i in train:
			ftrain.write(name)
		else:
			fval.write(name)
	else:
		ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

会出现以下四个文件:这四个文件内容是图片的名称序号(例如000000.jpg,则txt记录一行000000)。trainval.txt记录训练集和验证集,train.txt记录训练集,val.txt记录验证集,test.txt记录测试集。
在这里插入图片描述

生成(图片+标注的txt)

Caffe_Root/data/gtsdb_voc/create_list.sh代码如下:

#!/bin/bash

root_dir=/home/jqy/data/
sub_dir=ImageSets/Main
bash_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
for dataset in trainval test
do
  dst_file=$bash_dir/$dataset.txt
  if [ -f $dst_file ]
  then
    rm -f $dst_file
  fi
  for name in gtsdb_voc
  do
    echo "Create list for $name $dataset..."
    dataset_file=$root_dir/$name/$sub_dir/$dataset.txt

    img_file=$bash_dir/$dataset"_img.txt"
    cp $dataset_file $img_file
    sed -i "s/^/$name\/JPEGImages\//g" $img_file
    sed -i "s/$/.jpg/g" $img_file

    label_file=$bash_dir/$dataset"_label.txt"
    cp $dataset_file $label_file
    sed -i "s/^/$name\/Annotations\//g" $label_file
    sed -i "s/$/.xml/g" $label_file

    paste -d' ' $img_file $label_file >> $dst_file

    rm -f $label_file
    rm -f $img_file
  done

  # Generate image name and size infomation.
  if [ $dataset == "test" ]
  then
    $bash_dir/../../build/tools/get_image_size $root_dir $dst_file $bash_dir/$dataset"_name_size.txt"
  fi

  # Shuffle trainval file.
  if [ $dataset == "trainval" ]
  then
    rand_file=$dst_file.random
    cat $dst_file | perl -MList::Util=shuffle -e 'print shuffle(<STDIN>);' > $rand_file
    mv $rand_file $dst_file
  fi
done

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

生成文件trainval.txt、test.txt,记录图片和对应的xml标注文件。
在这里插入图片描述
生成test_name_size.txt,记录测试集图片长宽大小。
在这里插入图片描述

创建lmdb格式文件

需要Caffe_Root/data/gtsdb_voc/下的trainval.txt、test.txt、test_name_size.txt、labelmap_voc.prototxt。
其中labelmap_voc.prototxt需要根据自己的数据集写(注意背景0也要写进去):
在这里插入图片描述
Caffe_Root/data/gtsdb_voc/create_data.sh代码如下:

cur_dir=$(cd $( dirname ${BASH_SOURCE[0]} ) && pwd )
root_dir=$cur_dir/../..
cd $root_dir

redo=1
data_root_dir="/home/jqy/data/"
dataset_name="gtsdb_voc"
mapfile="$root_dir/data/$dataset_name/labelmap_voc.prototxt"
anno_type="detection"
db="lmdb"
min_dim=0
max_dim=0
width=300
height=300

extra_cmd="--encode-type=jpg --encoded"
if [ $redo ]
then
  extra_cmd="$extra_cmd --redo"
fi
for subset in test trainval
do
  python2 $root_dir/scripts/create_annoset.py --anno-type=$anno_type --label-map-file=$mapfile --min-dim=$min_dim --max-dim=$max_dim --resize-width=$width --resize-height=$height --check-label $extra_cmd $data_root_dir $root_dir/data/$dataset_name/$subset.txt $data_root_dir/$dataset_name/$db/$dataset_name"_"$subset"_"$db examples/$dataset_name
done
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

生成的lmdb文件在:/home/jqy/data/gtsdb_voc/lmdb在这里插入图片描述

开始训练

整体概览

/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/examples/ssd/ssd_gtsdb.py生成caffe训练所需文件。
/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/examples/ssd/score_ssd_gtsdb.py生成caffe评估所需文件。
/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/jobs/VGGNet/gtsdb_voc下:
在这里插入图片描述

cd SSD_300x300
./VGG_gtsdb_voc_SSD_300x300.sh # 开始训练
  • 1
  • 2
cd SSD_300x300_score
./VGG_gtsdb_voc_SSD_300x300.sh # 开始评估
  • 1
  • 2

训练好的模型在/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/models/VGGNet/gtsdb_voc
在这里插入图片描述
在这里插入图片描述
result在/home/jqy/data/gtsdb_voc/results/5classes/SSD_300x300/Main
在这里插入图片描述

具体操作

修改ssd_gtsdb.py

运行相关

在这里插入图片描述
run_soon设置是否要立即开始训练,我一般都False。之后可以去jobs文件夹下手动运行。
resume_training:没仔细研究
remove_old_models:没仔细研究

lmdb文件路径

在这里插入图片描述

均值

以下两个均值需要根据不同数据集修改,暂时我还没改,我今天才意识到这个问题
在这里插入图片描述在这里插入图片描述

一些路径设置

在这里插入图片描述
save_dir:权值文件.caffemodel存储路径。
snapshot_dir:快照文件.solverstate(快照是指训练中断后继续训练用的文件)存储路径。
job_dir:开始训练的.sh文件和.log文件存储路径。
output_result_dir:results路径。这里面文件我还没细看是啥。
pretrain_model:注意如果没有已经训练的模型,运行ssd_gtsdb.py后会自动选用此预训练模型。

类别相关

在这里插入图片描述

批处理设置

在这里插入图片描述
GPU选择,我有俩GPU
在这里插入图片描述
根据显存容量设置,这边我的俩2080ti能跑48。

base_lr(基础学习率)

注意GTSDB数据集训练的时候base_lr需要缩小点,SSD官方源码源码base_lr=0.001,我这边是调整成了0.0001。
ssd_gtsdb.py可以设置,在这里插入图片描述乘了0.1。
当然ssd_gtsdb.py运行后生成的solver.prototxt中也能单独修改。

训练参数设置

在这里插入图片描述
修改完成后,

cd Caffe_Root
python examples/ssd/ssd_gtsdb.py
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/373375
推荐阅读
相关标签
  

闽ICP备14008679号