赞
踩
GTSDB数据集介绍:用于目标检测的交通标志数据集。共900张图,格式为.ppm。共4大类(禁止,危险,强制,其他)43小类(类别编号0~42)。
官方下载地址
FullIJCNN2013.zip是主要文件。里面包含00000.ppm~00899.ppm共900张.ppm格式的图片。还有一份真值文件gt.txt,里面记录900张图片的标注信息。
由于我用的SSD模型是基于VOC数据集格式的,我要把GTSDB转换成VOC格式。
根目录下三个目录:Annotations/放xml文件,JPEGImages/放jpg图片,ImageSets/Main/放train.txt等。
VOC格式需要图片格式是.jpg。转换结果放入VOC指定位置。ppm2jpg.py代码如下
# implementation: /home/jqy/data/FullIJCNN2013/all/*.ppm -> /home/jqy/data/gtsdbjpg/all/*.jpg
# /home/jqy/data/FullIJCNN2013/00/*.ppm -> /home/jqy/data/gtsdbjpg/00/*.jpg
# and so on for every other subdirectory.
from PIL import Image
import os
# Directory whose immediate subdirectories hold the source .ppm images;
# passed to batch_image() as in_dir by the script entry point.
input_train_path="/home/jqy/data/FullIJCNN2013/"
# Directory under which the converted .jpg files are written, one
# subdirectory per input subdirectory; passed to batch_image() as out_dir.
output_train_path = "/home/jqy/data/gtsdbjpg/"
# input_test_path="./full/"
# output_test_path = "./full/"
def batch_image(in_dir, out_dir):
    """Convert the ``.ppm`` images in each subdirectory of *in_dir* to ``.jpg``.

    Only files that sit directly inside an immediate subdirectory of
    *in_dir* are converted; files placed directly in *in_dir* itself are
    ignored. Every ``<in_dir>/<sub>/<name>.ppm`` is written to
    ``<out_dir>/<sub>/<name>.jpg``. Missing output directories (including
    intermediate ones) are created.

    Args:
        in_dir: Directory containing one subdirectory per image group.
        out_dir: Directory that receives the mirrored subdirectories.

    Returns:
        -1 when *in_dir* does not exist, otherwise None.
    """
    if not os.path.exists(out_dir):
        print(out_dir, 'is not existed.')
        # makedirs (not mkdir) so a nested output path works on a clean disk.
        os.makedirs(out_dir)
    if not os.path.exists(in_dir):
        print(in_dir, 'is not existed.')
        return -1

    directories = [d for d in os.listdir(in_dir)
                   if os.path.isdir(os.path.join(in_dir, d))]
    for d in directories:
        label_directory = os.path.join(in_dir, d)
        new_directory = os.path.join(out_dir, d)
        # Test the same path that is about to be created. Checking
        # ``out_dir + d`` instead gave the wrong answer whenever out_dir had
        # no trailing slash, which made a second run crash in mkdir.
        if not os.path.isdir(new_directory):
            os.makedirs(new_directory)

        # Sorted so the progress counter and processing order are repeatable.
        ppm_names = sorted(f for f in os.listdir(label_directory)
                           if f.endswith(".ppm"))
        for count, ppm_name in enumerate(ppm_names):
            stem = os.path.splitext(ppm_name)[0]
            new_path = os.path.join(new_directory, stem + '.jpg')
            print(count, ',', new_path)
            # The context manager closes the source file after saving, so
            # converting many images does not accumulate open file handles.
            with Image.open(os.path.join(label_directory, ppm_name)) as im:
                im.save(new_path)
# Script entry point: convert the training images only when this file is run
# directly, not when it is imported.
if __name__ == '__main__':
    # batch_image(input_test_path, output_test_path)
    batch_image(input_train_path, output_train_path)
VOC标志信息记录在xml中。转换结果放入VOC指定位置。我是分成4+1=5大类,背景为0。gt2xml.py代码如下:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
import sys
import os
import codecs
import cv2
# Python 2 only: make UTF-8 the default codec for implicit str/unicode
# conversions. reload() is needed because sys.setdefaultencoding is removed
# from the sys module at interpreter start-up. Python 3 has neither the
# reload() builtin nor this function, so the hack is skipped there.
if sys.version_info[0] < 3:
    reload(sys)
    sys.setdefaultencoding('utf8')
root = r'/home/jqy/data/gtsdbjpg/all/xml/' # output xml path
# Read all ground-truth lines up front; the with-block closes gt.txt as soon
# as the lines are in memory.
with open('gt.txt') as fp: # path of gt.txt
    uavinfo = fp.readlines()
#fp2 = open('train.txt', 'w') # path of train.txt
# Lookup table from a GTSDB class id to the merged detection category.
# Built once at import time instead of rebuilding three lists on every call.
_CATEGORY_BY_CLASS_ID = {}
# 1 = prohibitory (circular, white ground with red border)
_CATEGORY_BY_CLASS_ID.update(dict.fromkeys(
    [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 15, 16], 1))
# 2 = mandatory (circular, blue ground)
_CATEGORY_BY_CLASS_ID.update(dict.fromkeys(
    [33, 34, 35, 36, 37, 38, 39, 40], 2))
# 3 = danger (triangular, white ground with red border)
_CATEGORY_BY_CLASS_ID.update(dict.fromkeys(
    [11, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], 3))


def get_label(label):
    """Map a GTSDB class id to one of four merged categories.

    Args:
        label: Integer class id as found in the last column of gt.txt.

    Returns:
        1 for prohibitory, 2 for mandatory, 3 for danger, and 4 for every
        other id (including ids not listed in any of the three groups).
    """
    return _CATEGORY_BY_CLASS_ID.get(label, 4)
def _annotation_header_xml(filename, width, height, depth):
    """Return the opening part of a VOC annotation, up to <segmented>."""
    return ''.join([
        '<annotation>\n',
        '\t<folder>' + 'GTSDB' + '</folder>\n',
        '\t<filename>' + filename + '</filename>\n',
        '\t<source>\n',
        '\t\t<database>The GTSDB Database</database>\n',
        '\t</source>\n',
        '\t<size>\n',
        '\t\t<width>' + str(width) + '</width>\n',
        '\t\t<height>' + str(height) + '</height>\n',
        '\t\t<depth>' + str(depth) + '</depth>\n',
        '\t</size>\n',
        '\t\t<segmented>0</segmented>\n',
    ])


def _object_xml(label, xmin, ymin, xmax, ymax):
    """Return one VOC <object> element for a single bounding box.

    label is the merged category id; the coordinates are the strings read
    from gt.txt and are written unchanged.
    """
    return ''.join([
        '\t<object>\n',
        '\t\t<name>' + str(label) + '</name>\n',
        '\t\t<pose>Unspecified</pose>\n',
        '\t\t<truncated>0</truncated>\n',
        '\t\t<difficult>0</difficult>\n',
        '\t\t<bndbox>\n',
        '\t\t\t<xmin>' + xmin + '</xmin>\n',
        '\t\t\t<ymin>' + ymin + '</ymin>\n',
        '\t\t\t<xmax>' + r_unused_free(xmax) + '</xmax>\n',
        '\t\t\t<ymax>' + ymax + '</ymax>\n',
        '\t\t</bndbox>\n',
        '\t</object>\n',
    ]) if False else ''.join([
        '\t<object>\n',
        '\t\t<name>' + str(label) + '</name>\n',
        '\t\t<pose>Unspecified</pose>\n',
        '\t\t<truncated>0</truncated>\n',
        '\t\t<difficult>0</difficult>\n',
        '\t\t<bndbox>\n',
        '\t\t\t<xmin>' + xmin + '</xmin>\n',
        '\t\t\t<ymin>' + ymin + '</ymin>\n',
        '\t\t\t<xmax>' + xmax + '</xmax>\n',
        '\t\t\t<ymax>' + ymax + '</ymax>\n',
        '\t\t</bndbox>\n',
        '\t</object>\n',
    ])


# Pass 1: group the boxes of gt.txt by image. Each line is expected to be
# "<image file>;<xmin>;<ymin>;<xmax>;<ymax>;<class id>". Grouping first lets
# every XML file be written exactly once, so re-running the script overwrites
# the annotations instead of appending duplicate <object> entries.
boxes_by_stem = {}
stems_in_order = []  # first-appearance order; plain dicts are unordered on Python 2
for raw_line in uavinfo:
    if not raw_line.strip():
        continue  # tolerate blank lines, e.g. a trailing newline at end of file
    fields = raw_line.strip().split(';')
    # Image name without directory or extension.
    stem = os.path.splitext(os.path.basename(fields[0]))[0]
    if stem not in boxes_by_stem:
        boxes_by_stem[stem] = []
        stems_in_order.append(stem)
    boxes_by_stem[stem].append(
        _object_xml(get_label(int(fields[5])),
                    fields[1], fields[2], fields[3], fields[4]))

if not os.path.isdir(root):
    os.makedirs(root)

# Pass 2: one XML per annotated image. The converted .jpg is read (the same
# files the unannotated-image loop below reads) and recorded as <filename>,
# so annotated and unannotated images are described consistently.
for stem in stems_in_order:
    img_path = "/home/jqy/data/gtsdbjpg/all/" + stem + '.jpg'
    print(img_path)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise IOError('cannot read image: ' + img_path)
    sp = img.shape
    with codecs.open(root + stem + '.xml', 'w', 'utf-8') as xml:
        xml.write(_annotation_header_xml(stem + '.jpg', sp[1], sp[0], sp[2]))
        xml.writelines(boxes_by_stem[stem])
        xml.write('</annotation>')
# Give every image that has no annotation yet an XML file containing only
# the header (no <object> entries). Images are named 00000.jpg .. 00899.jpg.
for i in range(900):
    name = str(i).zfill(5)
    img_path = "/home/jqy/data/gtsdbjpg/all/" + name + '.jpg'
    print(img_path)
    xml_path = root + name + '.xml'
    if os.path.exists(xml_path):
        # Already annotated: skip before decoding the image, since its size
        # is only needed when a new file is written.
        continue
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise IOError('cannot read image: ' + img_path)
    sp = img.shape
    height = sp[0]
    width = sp[1]
    depth = sp[2]
    #fp2.writelines(name + '\n')
    with codecs.open(xml_path, 'w', 'utf-8') as xml:
        # xml.write('<?xml version="1.0" encoding="UTF-8"?>\n')
        xml.write(''.join([
            '<annotation>\n',
            '\t<folder>' + 'GTSDB' + '</folder>\n',
            '\t<filename>' + name + '.jpg' + '</filename>\n',
            '\t<source>\n',
            '\t\t<database>The GTSDB Database</database>\n',
            '\t</source>\n',
            '\t<size>\n',
            '\t\t<width>' + str(width) + '</width>\n',
            '\t\t<height>' + str(height) + '</height>\n',
            '\t\t<depth>' + str(depth) + '</depth>\n',
            '\t</size>\n',
            '\t\t<segmented>0</segmented>\n',
            '</annotation>',
        ]))
#fp2.close()
所有原始图片数据,训练集和测试集需要按比例分割开来。train_test_separation.py代码如下:
import os
import random
trainval_percent = 0.75  # share of all samples that goes into trainval
train_percent = 0.5      # share of trainval that goes into train
xmlfilepath = '/home/jqy/data/gtsdb_voc/Annotations'
txtsavepath = '/home/jqy/data/gtsdb_voc/ImageSets/Main'


def split_dataset(names, trainval_percent, train_percent):
    """Randomly split *names* into trainval/train/val/test subsets.

    ``int(len(names) * trainval_percent)`` names are drawn for trainval, and
    ``int(that count * train_percent)`` of those for train. The rest of
    trainval becomes val and everything outside trainval becomes test.
    Within each subset the names keep their order from *names*.

    Args:
        names: Sequence of sample names (file names without extension).
        trainval_percent: Fraction of all names assigned to trainval.
        train_percent: Fraction of trainval assigned to train.

    Returns:
        Dict with the keys 'trainval', 'train', 'val' and 'test', each
        mapping to a list of names.
    """
    num = len(names)
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval_indices = random.sample(range(num), tv)
    # Sets give O(1) membership tests in the loop below.
    train_indices = set(random.sample(trainval_indices, tr))
    trainval_indices = set(trainval_indices)

    parts = {'trainval': [], 'train': [], 'val': [], 'test': []}
    for index, name in enumerate(names):
        if index not in trainval_indices:
            parts['test'].append(name)
            continue
        parts['trainval'].append(name)
        if index in train_indices:
            parts['train'].append(name)
        else:
            parts['val'].append(name)
    return parts


def main():
    """Split the annotation files in xmlfilepath and write the four lists."""
    # Only real annotation files count; sorting makes the listing order (and
    # so the order inside each output file) independent of the file system.
    names = sorted(os.path.splitext(entry)[0]
                   for entry in os.listdir(xmlfilepath)
                   if entry.endswith('.xml'))
    parts = split_dataset(names, trainval_percent, train_percent)
    if not os.path.isdir(txtsavepath):
        os.makedirs(txtsavepath)
    for subset in ('trainval', 'test', 'train', 'val'):
        with open(os.path.join(txtsavepath, subset + '.txt'), 'w') as out:
            out.writelines(name + '\n' for name in parts[subset])


if __name__ == '__main__':
    main()
会出现以下四个文件:这四个文件内容是图片的名称序号(例如00000.jpg,则txt记录一行00000)。trainval.txt记录训练集和验证集,train.txt记录训练集,val.txt记录验证集,test.txt记录测试集。
Caffe_Root/data/gtsdb_voc/create_list.sh
代码如下:
#!/bin/bash
# Build the list files next to this script: for each of trainval and test,
# a "<image path> <annotation path>" file with paths relative to root_dir.
# For test, also record image sizes; for trainval, shuffle the lines.

root_dir=/home/jqy/data/
sub_dir=ImageSets/Main
bash_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
for dataset in trainval test
do
  dst_file="$bash_dir/$dataset.txt"
  # Start from a clean file because the inner loop appends to it.
  rm -f "$dst_file"
  for name in gtsdb_voc
  do
    echo "Create list for $name $dataset..."
    dataset_file="$root_dir/$name/$sub_dir/$dataset.txt"

    # Turn each sample name into "<name>/JPEGImages/<sample>.jpg".
    img_file="$bash_dir/${dataset}_img.txt"
    cp "$dataset_file" "$img_file"
    sed -i "s/^/$name\/JPEGImages\//g" "$img_file"
    sed -i "s/$/.jpg/g" "$img_file"

    # Turn each sample name into "<name>/Annotations/<sample>.xml".
    label_file="$bash_dir/${dataset}_label.txt"
    cp "$dataset_file" "$label_file"
    sed -i "s/^/$name\/Annotations\//g" "$label_file"
    sed -i "s/$/.xml/g" "$label_file"

    # Join the two columns line by line, separated by a single space.
    paste -d' ' "$img_file" "$label_file" >> "$dst_file"

    rm -f "$label_file"
    rm -f "$img_file"
  done

  # Generate image name and size information.
  if [ "$dataset" = "test" ]
  then
    "$bash_dir/../../build/tools/get_image_size" "$root_dir" "$dst_file" "$bash_dir/${dataset}_name_size.txt"
  fi

  # Shuffle trainval file.
  if [ "$dataset" = "trainval" ]
  then
    rand_file="$dst_file.random"
    perl -MList::Util=shuffle -e 'print shuffle(<STDIN>);' < "$dst_file" > "$rand_file"
    mv "$rand_file" "$dst_file"
  fi
done
生成文件trainval.txt、test.txt,记录图片和对应的xml标注文件。
生成test_name_size.txt,记录测试集图片长宽大小。
需要Caffe_Root/data/gtsdb_voc/
下的trainval.txt、test.txt、test_name_size.txt、labelmap_voc.prototxt。
其中labelmap_voc.prototxt需要根据自己的数据集写(注意背景0也要写进去):
Caffe_Root/data/gtsdb_voc/create_data.sh
代码如下:
# Convert the trainval/test list files into LMDB databases by calling
# scripts/create_annoset.py from the Caffe root (two levels above this file).
cur_dir=$(cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
root_dir=$cur_dir/../..

# Stop if the Caffe root cannot be entered; the relative "examples/..."
# argument below would otherwise resolve against the wrong directory.
cd "$root_dir" || exit 1

# Set to 1 to pass --redo to create_annoset.py; any other value omits it.
redo=1
data_root_dir="/home/jqy/data/"
dataset_name="gtsdb_voc"
mapfile="$root_dir/data/$dataset_name/labelmap_voc.prototxt"
anno_type="detection"
db="lmdb"
min_dim=0
max_dim=0
width=300
height=300

extra_cmd="--encode-type=jpg --encoded"
# Compare the value explicitly: a bare [ $redo ] is true for every non-empty
# string, so redo=0 would still have added --redo.
if [ "$redo" = "1" ]
then
  extra_cmd="$extra_cmd --redo"
fi
for subset in test trainval
do
  # $extra_cmd is left unquoted on purpose so it splits into separate flags.
  python2 "$root_dir/scripts/create_annoset.py" \
    --anno-type=$anno_type \
    --label-map-file="$mapfile" \
    --min-dim=$min_dim \
    --max-dim=$max_dim \
    --resize-width=$width \
    --resize-height=$height \
    --check-label \
    $extra_cmd \
    "$data_root_dir" \
    "$root_dir/data/$dataset_name/$subset.txt" \
    "$data_root_dir/$dataset_name/$db/${dataset_name}_${subset}_$db" \
    "examples/$dataset_name"
done
生成的lmdb文件在:/home/jqy/data/gtsdb_voc/lmdb
/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/examples/ssd/ssd_gtsdb.py
生成caffe训练所需文件。
/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/examples/ssd/score_ssd_gtsdb.py
生成caffe评估所需文件。
在/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/jobs/VGGNet/gtsdb_voc
下:
cd SSD_300x300
./VGG_gtsdb_voc_SSD_300x300.sh # 开始训练
cd SSD_300x300_score
./VGG_gtsdb_voc_SSD_300x300.sh # 开始评估
训练好的模型在/home/jqy/jqy_caffe/caffe-gpu/caffe-ssd/models/VGGNet/gtsdb_voc
下
result在/home/jqy/data/gtsdb_voc/results/5classes/SSD_300x300/Main
下
run_soon设置是否要立即开始训练,我一般都False。之后可以去jobs文件夹下手动运行。
resume_training:没仔细研究
remove_old_models:没仔细研究
以下两个均值需要根据不同数据集修改,暂时我还没改,我今天才意识到这个问题
save_dir:权值文件.caffemodel存储路径。
snapshot_dir:快照文件.solverstate(快照是指训练中断后继续训练用的文件)存储路径。
job_dir:开始训练的.sh文件和.log文件存储路径。
output_result_dir:results路径。这里面文件我还没细看是啥。
pretrain_model:注意如果没有已经训练的模型,运行ssd_gtsdb.py后会自动选用此预训练模型。
GPU选择,我有俩GPU
根据显存容量设置,这边我的俩2080ti能跑48。
注意GTSDB数据集训练的时候base_lr需要缩小点,SSD官方源码base_lr=0.001,我这边是调整成了0.0001。
ssd_gtsdb.py可以设置,乘了0.1。
当然ssd_gtsdb.py运行后生成的solver.prototxt中也能单独修改。
修改完成后,
cd Caffe_Root
python examples/ssd/ssd_gtsdb.py
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。