
Training DETR on a Custom Dataset: Converting a YOLO Dataset to COCO Format

1. Dataset Preparation

1.1 DETR data format

|--- dataset
	|--- train
		|--- 1.jpg
		|--- 2.jpg
	|--- val
		|--- 1.jpg
		|--- 2.jpg
	|--- annotations
		|--- instances_train.json
		|--- instances_val.json

The files instances_train.json and instances_val.json record the image annotations, for example:

{
	"images": [
	{"file_name":"/home/shares/detr/datasets/val_xml/208.xml", 
	"height": 720, "width": 1280, "id": 208}, 
	{"file_name":"/home/shares/detr/datasets/val_xml/468.xml", 
	"height": 720, "width": 1280, "id": 468},
	...
	],
	 "type": "instances",
	 "annotations": [
	 {"area": 14151, "iscrowd": 0, "image_id": 208, "bbox": [360, 521, 159, 89], "category_id": 3, "id": 1, "ignore": 0, "segmentation": []},
	 {"area": 21890, "iscrowd": 0, "image_id": 468, "bbox": [209, 382, 110, 199], "category_id": 2, "id": 2, "ignore": 0, "segmentation": []},
	...
	],
	"categories": [
	{"supercategory": "none", "id": 1, "name": "cigaretteface"}, 
	{"supercategory": "none", "id": 2, "name": "smokeface"}, 
	{"supercategory": "none", "id": 3, "name": "normalface"}, 
	{"supercategory": "none", "id": 4, "name": "callface"}
	]
}
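Before training, it is worth sanity-checking that an annotation file of this shape is internally consistent (every annotation points at an existing image and category, boxes have positive size). A minimal sketch using only the standard library; the inline sample dict mirrors the example above:

```python
import json

REQUIRED_TOP_KEYS = {"images", "annotations", "categories"}

def check_coco_dict(coco):
    """Lightweight structural check for a COCO-style detection dict."""
    missing = REQUIRED_TOP_KEYS - coco.keys()
    if missing:
        raise ValueError("missing top-level keys: %s" % sorted(missing))
    image_ids = {img["id"] for img in coco["images"]}
    cat_ids = {c["id"] for c in coco["categories"]}
    for ann in coco["annotations"]:
        assert ann["image_id"] in image_ids, "annotation refers to an unknown image"
        assert ann["category_id"] in cat_ids, "annotation refers to an unknown category"
        x, y, w, h = ann["bbox"]  # COCO boxes are [x, y, width, height]
        assert w > 0 and h > 0, "boxes must have a positive size"
    return len(coco["images"]), len(coco["annotations"])

# In practice, load the real file instead:
#   with open(".../annotations/instances_val.json") as f: coco = json.load(f)
coco = {
    "images": [{"file_name": "0.jpg", "height": 720, "width": 1280, "id": 208}],
    "annotations": [{"image_id": 208, "category_id": 3, "bbox": [360, 521, 159, 89], "id": 1}],
    "categories": [{"id": 3, "name": "normalface"}],
}
print(check_coco_dict(coco))  # (1, 1)
```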

1.2 Converting a YOLO dataset to COCO format

Skip this step if your dataset is already in COCO format.

1.2.1 The source dataset is organized in the Pascal VOC directory layout (note that the annotations here are VOC XML files, which is what the conversion script parses):

|--- /home/shares/datasets/my_voc_dataset
	|--- Annotations
		|--- 1.xml
		|--- 2.xml
	|--- ImageSets
		|--- Main
			|--- train.txt
			|--- val.txt
	|--- JPEGImages
		|--- 1.jpg
		|--- 2.jpg
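Each file under Annotations is a standard Pascal VOC XML annotation. A minimal sketch of reading one with the standard library (the field names are the ones the conversion script expects; the sample XML is illustrative):

```python
import xml.etree.ElementTree as ET

def read_voc_xml(xml_text):
    """Parse a VOC annotation string into (width, height, [(name, xmin, ymin, xmax, ymax), ...])."""
    root = ET.fromstring(xml_text)
    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)
    objects = []
    for obj in root.findall("object"):
        name = obj.find("name").text
        box = obj.find("bndbox")
        objects.append((name,
                        int(box.find("xmin").text), int(box.find("ymin").text),
                        int(box.find("xmax").text), int(box.find("ymax").text)))
    return width, height, objects

sample = """<annotation>
  <size><width>1280</width><height>720</height><depth>3</depth></size>
  <object>
    <name>normalface</name>
    <bndbox><xmin>361</xmin><ymin>522</ymin><xmax>520</xmax><ymax>611</ymax></bndbox>
  </object>
</annotation>"""
print(read_voc_xml(sample))  # (1280, 720, [('normalface', 361, 522, 520, 611)])
```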

1.2.2 The following script converts this directory structure into COCO format:

import os
import json
import glob
import shutil
import xml.etree.ElementTree as ET

# Annotation ids start from 1; map each category name to a numeric label
START_BOUNDING_BOX_ID = 1
PRE_DEFINE_CATEGORIES = {"cigaretteface": 1, "smokeface": 2, "normalface": 3, "callface": 4}

def get(root, name):
    vars = root.findall(name)
    return vars

def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise ValueError("Can not find %s in %s." % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise ValueError(
            "The size of %s is supposed to be %d, but is %d."
            % (name, length, len(vars))
        )
    if length == 1:
        vars = vars[0]
    return vars

def get_filename_as_int(filename):
    try:
        filename = filename.replace("\\", "/")
        filename = os.path.splitext(os.path.basename(filename))[0]
        return int(filename)
    except ValueError:
        raise ValueError(
            "Filename %s is supposed to be an integer." % (filename))

def get_categories(xml_files):
    """Generate category name to id mapping from a list of xml files.

    Arguments:
        xml_files {list} -- A list of xml file paths.

    Returns:
        dict -- category name to id mapping.
    """
    classes_names = []
    for xml_file in xml_files:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall("object"):
            classes_names.append(member[0].text)
    classes_names = list(set(classes_names))
    classes_names.sort()
    return {name: i for i, name in enumerate(classes_names, 1)}  # ids start at 1, matching PRE_DEFINE_CATEGORIES

def convert(xml_files, json_file):
    json_dict = {"images": [], "type": "instances",
                 "annotations": [], "categories": []}
    if PRE_DEFINE_CATEGORIES is not None:
        categories = PRE_DEFINE_CATEGORIES
    else:
        categories = get_categories(xml_files)
    bnd_id = START_BOUNDING_BOX_ID
    nums = len(xml_files)
    i = 1
    for xml_file in xml_files:
        print('\r converting xml to json : {}/{}'.format(i, nums), end = "")
        i += 1
        tree = ET.parse(xml_file)
        root = tree.getroot()
        path = get(root, "path")

        # The filename must be a number
        image_id = get_filename_as_int(xml_file)
        size = get_and_check(root, "size", 1)
        width = int(get_and_check(size, "width", 1).text)
        height = int(get_and_check(size, "height", 1).text)
        image = {
            "file_name": xml_file,
            "height": height,
            "width": width,
            "id": image_id,
        }
        json_dict["images"].append(image)
        # Currently we do not support segmentation.
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        for obj in get(root, "object"):
            category = get_and_check(obj, "name", 1).text
            if category not in categories:
                new_id = len(categories) + 1  # ids are 1-based, so the next free id is len + 1
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, "bndbox", 1)
            xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
            ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
            xmax = int(get_and_check(bndbox, "xmax", 1).text)
            ymax = int(get_and_check(bndbox, "ymax", 1).text)
            assert xmax > xmin
            assert ymax > ymin
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {
                "area": o_width * o_height,
                "iscrowd": 0,
                "image_id": image_id,
                "bbox": [xmin, ymin, o_width, o_height],
                "category_id": category_id,
                "id": bnd_id,
                "ignore": 0,
                "segmentation": [],
            }
            json_dict["annotations"].append(ann)
            bnd_id = bnd_id + 1
    print()
    for cate, cid in categories.items():
        cat = {"supercategory": "none", "id": cid, "name": cate}
        json_dict["categories"].append(cat)

    # Write the COCO json (create the annotations directory if needed)
    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    with open(json_file, "w") as json_fp:
        json.dump(json_dict, json_fp)


if __name__ == "__main__":
    voc_path = "/home/shares/datasets/my_voc_dataset"
    
    #  Root directory where the COCO-format dataset will be saved
    save_coco_path = "/home/shares/detr/datasets"
    
    #  The VOC split only defines train and val sets (train.txt and val.txt)
    data_type_list = ["train", "val"]
    for data_type in data_type_list:
        os.makedirs(os.path.join(save_coco_path, data_type), exist_ok=True)
        os.makedirs(os.path.join(save_coco_path, data_type + "_xml"), exist_ok=True)
        with open(os.path.join(voc_path, "ImageSets", "Main", data_type + ".txt"), "r") as f:
            txt_ls = [line.strip() for line in f]
        idx = 0
        for img_name in os.listdir(os.path.join(voc_path, "JPEGImages")):
            print('\rcopying imgs', end="")
            stem = os.path.splitext(img_name)[0]
            if stem in txt_ls:
                # Renumber the copied files from 0 so every filename is an integer
                shutil.copy(os.path.join(voc_path, "JPEGImages", img_name),
                            os.path.join(save_coco_path, data_type, str(idx) + ".jpg"))
                shutil.copy(os.path.join(voc_path, "Annotations", stem + ".xml"),
                            os.path.join(save_coco_path, data_type + "_xml", str(idx) + ".xml"))
                idx += 1
        print()
        xml_path = os.path.join(save_coco_path, data_type+"_xml")
        xml_files = glob.glob(os.path.join(xml_path, "*.xml"))
        convert(xml_files, os.path.join(save_coco_path,
                "annotations", "instances_"+data_type+".json"))
        shutil.rmtree(xml_path)
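The coordinate handling inside convert() boils down to one rule: VOC stores corner coordinates (xmin, ymin, xmax, ymax, 1-based pixel indices), while COCO stores [x, y, width, height] with 0-based coordinates. Isolated as a small helper for clarity, this sketch mirrors the script's arithmetic:

```python
def voc_box_to_coco(xmin, ymin, xmax, ymax):
    """Convert a 1-based VOC corner box to a 0-based COCO [x, y, w, h] box."""
    x = xmin - 1          # shift 1-based VOC indices to 0-based, as the script does
    y = ymin - 1
    w = xmax - x          # matches the script: o_width = xmax - (xmin - 1)
    h = ymax - y
    assert w > 0 and h > 0
    return [x, y, w, h]

print(voc_box_to_coco(361, 522, 520, 611))  # [360, 521, 160, 90]
```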

After the script finishes, the following paths are created automatically:
a. save_coco_path
b. save_coco_path/train
c. save_coco_path/val
and the images under train and val are renumbered starting from 0.

2. Modifying the Training Parameters

Open detr/main.py and edit get_args_parser():

# dataset parameters
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
	...
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')

coco_path: the dataset path; change it to the save_coco_path used above.
output_dir: where the training results are saved, e.g. runs/result1.
Add a default value to each of these arguments; after the change:

# dataset parameters
    parser.add_argument('--coco_path', type=str, default="/home/shares/detr/datasets")
    parser.add_argument('--coco_panoptic_path', type=str)
	...
    parser.add_argument('--output_dir', default="/home/shares/detr/runs/result1",
                        help='path where to save, empty for no saving')

3. Training

Run the following command in the terminal:

python main.py
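Alternatively, instead of editing the defaults in main.py, the same two parameters can be passed on the command line (these are the flags shown in get_args_parser() above):

```shell
python main.py --coco_path /home/shares/detr/datasets --output_dir runs/result1
```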

Once it starts, the terminal prints the training progress.

4. Training Results

The following files and folders are generated under runs/result1.
Here, checkpoint.pth holds the trained weights, and log.txt records the training history.
