当前位置:   article > 正文

RTDETR模型一键训练/预测(执行train.sh与detect.sh)_rtdert训练xml数据

rtdert训练xml数据

引言

本文章基于客户一键训练与测试需求,我使用u公司的yolov8集成的RTDETR模型改成较为保姆级的一键操作的训练/预测方式,也特别适合新手或想偷懒转换数据格式的朋友们。本文一键体现数据格式为图像与xml,调用train.sh与detect.sh可完成模型的训练与预测。而为完成该操作,模型内嵌入xml转RTDETR的txt格式、自动分配训练/验证集、自动切换环境等内容。接下来,我将介绍如何操作,并附修改源码。

源码链接:我已上传个人资源,请自行下载!

一、配置参数设置

该文件是RTDETR数据转换配置和模型使用参数,被我修改满足一键训练与测试文件的配置参数。包含将图像与xml文件数据格式转为模型训练格式数据,只需要提供xml与图像文件夹,可完成数据转换,详情如下:

# 设置img与xml的文件路径,也可为同一个文件,按照xml选择img
img_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data  #
xml_path: C:/Users/Administrator/Desktop/rtdetr/example_template/data

# 设置数据集训练与验证集测试的比率,和小于1,通常test比率不设置为0
train_rate: 0.8
val_rate: 0.2
test_rate:

path: C:/Users/Administrator/Desktop/rtdetr/example_template/rtdert_data  # 必填,转换存放数据集文件夹,必须设置
train: images/train  # 不设置
val: images/val  # 不设置
test:  
# Classes


names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  11: stop sign
  12: parking meter
  13: bench
  14: bird
  15: cat
  16: dog
  17: horse
  18: sheep
  19: cow
  20: elephant
  21: bear
  22: zebra
  23: giraffe
  24: backpack
  25: umbrella
  26: handbag
  27: tie
  28: suitcase
  29: frisbee
  30: skis
  31: snowboard
  32: sports ball
  33: kite
  34: baseball bat
  35: baseball glove
  36: skateboard
  37: surfboard
  38: tennis racket
  39: bottle
  40: wine glass
  41: cup
  42: fork
  43: knife
  44: spoon
  45: bowl
  46: banana
  47: apple
  48: sandwich
  49: orange
  50: broccoli
  51: carrot
  52: hot dog
  53: pizza
  54: donut
  55: cake
  56: chair
  57: couch
  58: potted plant
  59: bed
  60: dining table
  61: toilet
  62: tv
  63: laptop
  64: mouse
  65: remote
  66: keyboard
  67: cell phone
  68: microwave
  69: oven
  70: toaster
  71: sink
  72: refrigerator
  73: book
  74: clock
  75: vase
  76: scissors
  77: teddy bear
  78: hair drier
  79: toothbrush

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98

二、数据格式转换代码

该文件代码提供了xml格式转rtdetr模型需要格式,基本是属于逻辑,代码能力较为基础,我不在介绍,代码如下:

import pandas as pd
import cv2
from tqdm import tqdm
import os
import numpy as np
import json
import xml.etree.ElementTree as ET
from lxml.etree import Element, SubElement, tostring, ElementTree
from xml.dom.minidom import parseString
import random
import shutil
import yaml

img_format = ['.jpg', '.png', '.bmp']


def build_dir(root):
    import os
    if not os.path.exists(root):
        os.makedirs(root)
    return root


def del_dir(root):
    import os
    if os.path.exists(root):
        shutil.rmtree(root)
    return root


############################################生成xml方法##########################
def product_xml(name_img, boxes, codes, img=None, wh=None):
    '''
    :param img: 以读好的图片
    :param name_img: 图片名字,如'xxx.jpg'
    :param boxes: box为列表
    :param codes: 为列表
    :return:
    '''
    if img is not None:
        width = img.shape[0]
        height = img.shape[1]
    else:
        assert wh is not None
        width = wh[0]
        height = wh[1]

    node_root = Element('annotation')
    node_folder = SubElement(node_root, 'folder')
    node_folder.text = 'VOC2007'

    node_filename = SubElement(node_root, 'filename')
    node_filename.text = name_img  # 图片名字

    node_size = SubElement(node_root, 'size')
    node_width = SubElement(node_size, 'width')
    node_width.text = str(height)

    node_height = SubElement(node_size, 'height')
    node_height.text = str(width)

    node_depth = SubElement(node_size, 'depth')
    node_depth.text = '3'

    for i, code in enumerate(codes):
        box = [boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]]
        node_object = SubElement(node_root, 'object')
        node_name = SubElement(node_object, 'name')
        node_name.text = code
        node_difficult = SubElement(node_object, 'difficult')
        node_difficult.text = '0'
        node_bndbox = SubElement(node_object, 'bndbox')
        node_xmin = SubElement(node_bndbox, 'xmin')
        node_xmin.text = str(int(box[0]))
        node_ymin = SubElement(node_bndbox, 'ymin')
        node_ymin.text = str(int(box[1]))
        node_xmax = SubElement(node_bndbox, 'xmax')
        node_xmax.text = str(int(box[2]))
        node_ymax = SubElement(node_bndbox, 'ymax')
        node_ymax.text = str(int(box[3]))

    xml = tostring(node_root, pretty_print=True)  # 格式化显示,该换行的换行
    dom = parseString(xml)

    name = name_img[:-4] + '.xml'

    tree = ElementTree(node_root)

    print('name:{},dom:{}'.format(name, dom))
    return tree, name


def product_xml_demo():
    '''
    通过box与cat信息为图片产生xml文件
    '''
    img_root = r'C:\Users\Administrator\Desktop\123\1.jpg'
    write_img_name = 'hhhaaa.jpg'
    bboxes_lst = [[22, 32, 46, 89]]
    cat_lst = ['cat']
    img = cv2.imread(img_root)
    tree, xml_name = product_xml(write_img_name, bboxes_lst, cat_lst, img=img)
    tree.write(os.path.join('./', xml_name))


############################################xml转yolo的txt##########################
def read_xml(xml_root):
    '''
    :param xml_root: .xml文件
    :return: dict('cat':['cat1',...],'bboxes':[[x1,y1,x2,y2],...],'whd':[w ,h,d])
    '''
    dict_info = {'cat': [], 'bboxes': [], 'box_wh': [], 'whd': []}
    if os.path.splitext(xml_root)[-1] == '.xml':
        tree = ET.parse(xml_root)  # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
        root = tree.getroot()  # 获取根节点
        whd = root.find('size')
        whd = [whd.find('width').text, whd.find('height').text, whd.find('depth').text]

        for obj in root.findall('object'):  # 找到根节点下所有“object”节点
            cat = str(obj.find('name').text)  # 找到object节点下name子节点的值(字符串)
            bbox = obj.find('bndbox')
            x1, y1, x2, y2 = [int(bbox.find('xmin').text),
                              int(bbox.find('ymin').text),
                              int(bbox.find('xmax').text),
                              int(bbox.find('ymax').text)]
            b_w = x2 - x1 + 1
            b_h = y2 - y1 + 1

            dict_info['cat'].append(cat)
            dict_info['bboxes'].append([x1, y1, x2, y2])
            dict_info['box_wh'].append([b_w, b_h])
            dict_info['whd'].append(whd)
    else:
        print('[inexistence]:{} suffix is not xml '.format(xml_root))
    return dict_info

def write_txt(text_lst, out_txt=None):
    '''
    每行内容为列表,将其写入text中
    '''
    out_dir = out_txt if out_txt is not None else 'classes.txt'
    file_write_obj = open(out_dir, 'w', encoding='utf-8')  # 以写的方式打开文件,如果文件不存在,就会自动创建
    for text in text_lst:
        file_write_obj.writelines(str(text))
        file_write_obj.write('\n')
    file_write_obj.close()

def xml2yolotxt(xml_root, img_root=None, save_txt=None, labels_name_lst=None):
    '''
    :param xml_root: xml的路径
    :param img_root:图像路径,可提供也可不提供,提供主要获得图像的高宽
    :param out_file:保存txt路径的文件夹
    :param labels_name_lst:提供训练列表,xml中出现类别与列表对应,如['pedes', 'elec', 'car', 'truck', 'bus', 'tricycle']
    pedes表示0,elec表示1,car表示2:return:
    '''

    if labels_name_lst is None:
        raise ValueError("lack labels list  ")
    if save_txt is None:
        raise ValueError("lack saving root for txt file  ")

    xml_info = read_xml(xml_root)

    if img_root is not None:
        # 从中提取W与H
        img = cv2.imread(img_root)
        H, W = img.shape[:2]
    else:
        whd = xml_info['whd'][0]
        W, H = float(whd[0]), float(whd[1])

    boxes_lst = xml_info['bboxes']
    labels_lst = xml_info['cat']

    yolotxt_lst = []
    for i, b in enumerate(boxes_lst):
        label = labels_lst[i]
        if label in labels_name_lst:
            label_idx = list(labels_name_lst).index(label)

            bw, bh = b[2] - b[0], b[3] - b[1]
            x, y = b[0] + bw / 2, b[1] + bh / 2
            x, y, w, h = x / W, y / H, bw / W, bh / H
            # yolotxt = str(cat_lst[i]) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h)
            yolotxt = str(label_idx) + ' ' + str(x) + ' ' + str(y) + ' ' + str(w) + ' ' + str(h)
            yolotxt_lst.append(yolotxt)
    if len(yolotxt_lst) > 0:
        write_txt(yolotxt_lst, save_txt)

def convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs):
    '''
    xml_path:xml文件夹的路径
    img_path:图片文件夹的路径
    out_file_path:模型训练的文件夹,用于yolo模型训练
    labels_name_lst:标签列表,模型只转换与训练的标签列表
    kwargs:其它参数

    '''

    print('\n convert data...')

    img_suffix = kwargs.get('img_suffix') if kwargs.get('img_suffix') else 4
    img_names = [name for name in os.listdir(img_path) if name[-4:] in img_format]
    img_names_no_suffix = [name[:-img_suffix] for name in img_names]

    xml_names_temp = [name for name in os.listdir(xml_path) if name[-3:] == 'xml']
    N = len(xml_names_temp)
    N_idx = [i for i in range(N)]
    random.shuffle(N_idx)
    xml_names = [xml_names_temp[i] for i in N_idx]

    train_N = N * kwargs.get('train_rate') if kwargs.get('train_rate') else 0.7 * N
    val_N = N * kwargs.get('val_rate') if kwargs.get('val_rate') else 0.3 * N
    test_N = N * kwargs.get('test_rate') if kwargs.get('test_rate') else 0

    if (train_N / N + val_N / N + test_N / N) > 1:
        raise ValueError(
            "rate of datasets error,sum>1, train_rate:{}\tval_rate:{}\ttest_rate{}".format(train_N / N, val_N / N,
                                                                                           test_N / N))

    # 构建训练文件
    images_path = os.path.join(out_file_path, 'images')
    labels_path = os.path.join(out_file_path, 'labels')

    del_dir(images_path)
    del_dir(labels_path)

    build_dir(images_path)
    build_dir(labels_path)

    train_img_path = build_dir(os.path.join(images_path, 'train'))
    val_img_path = build_dir(os.path.join(images_path, 'val'))
    test_img_path = build_dir(os.path.join(images_path, 'test'))

    train_label_path = build_dir(os.path.join(labels_path, 'train'))
    val_label_path = build_dir(os.path.join(labels_path, 'val'))
    test_label_path = build_dir(os.path.join(labels_path, 'test'))
    problem_xmls=[]

    for i in tqdm(range(int(train_N))):
        xml_name = xml_names[i]
        xml_root = os.path.join(xml_path, xml_name)
        if xml_name[:-4] in list(img_names_no_suffix):
            img_idx = list(img_names_no_suffix).index(xml_name[:-4])
            img_name = img_names[img_idx]
            img_root = os.path.join(img_path, img_name)
            save_txt = os.path.join(train_label_path, xml_name[:-3] + 'txt')
            try:
                xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
            except:
                problem_xmls.append(xml_root)
                break
            shutil.copy(img_root, os.path.join(train_img_path, img_name))
    print('\nfinishing vonvert of train data,train_rate:\t{}\t    train count:\t{} \n'.format(train_N / N, int(train_N)))

    for i in tqdm(range(int(train_N), int(train_N + val_N))):
        xml_name = xml_names[i]
        xml_root = os.path.join(xml_path, xml_name)
        if xml_name[:-4] in list(img_names_no_suffix):
            img_idx = list(img_names_no_suffix).index(xml_name[:-4])
            img_name = img_names[img_idx]
            img_root = os.path.join(img_path, img_name)
            save_txt = os.path.join(val_label_path, xml_name[:-3] + 'txt')
            try:
                xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
            except:
                problem_xmls.append(xml_root)
                break

            # xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
            shutil.copy(img_root, os.path.join(val_img_path, img_name))
    print('\nfinishing vonvert of val data,    val_rate:\t{}\t      val count:\t{} \n'.format(val_N / N, int(val_N)))

    for i in tqdm(range(int(train_N + val_N), int(train_N + val_N + test_N))):
        xml_name = xml_names[i]
        xml_root = os.path.join(xml_path, xml_name)
        if xml_name[:-4] in list(img_names_no_suffix):
            img_idx = list(img_names_no_suffix).index(xml_name[:-4])
            img_name = img_names[img_idx]
            img_root = os.path.join(img_path, img_name)
            save_txt = os.path.join(test_label_path, xml_name[:-3] + 'txt')
            try:
                xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
            except:
                problem_xmls.append(xml_root)
                break
            # xml2yolotxt(xml_root, img_root=img_root, save_txt=save_txt, labels_name_lst=labels_name_lst)
            shutil.copy(img_root, os.path.join(test_img_path, img_name))
    print('\nfinishing vonvert of test data,  test_rate:\t{}\t      test count:\t{} \n'.format(test_N / N, int(test_N)))

    print( '\n problem xml:{}\n'.format(len(problem_xmls))  )
    for probel_path in problem_xmls:
        print(probel_path)

def product_yolo_dataset(yaml_path):
    f = open(yaml_path, 'rb')
    cfg = yaml.load(f, Loader=yaml.FullLoader)
    img_path = cfg['img_path']
    xml_path = cfg['xml_path']
    out_file_path = cfg['path']
    labels_name_lst = [v for k,v in cfg['names'].items()]
    kwargs = {"train_rate": cfg['train_rate'], "val_rate": cfg['val_rate'], "test_rate": cfg['test_rate']}

    convert_data_train(xml_path, img_path, out_file_path, labels_name_lst, **kwargs)
    return cfg

def yolo_dataset_demo():
    '''
    将xml数据格式转换为yolo格式的方法
    '''
    yaml_path = 'coco128_auto.yaml'

    product_yolo_dataset(yaml_path)




def read_yaml(yaml_path):
    f = open(yaml_path, 'rb')
    cfg = yaml.load(f, Loader=yaml.FullLoader)

    return cfg



def del_runsfile():
    from pathlib import Path
    import sys
    FILE = Path(__file__).resolve()
    ROOT = FILE.parents[0]  # YOLOv5 root directory
    if str(ROOT) not in sys.path:
        sys.path.append(str(ROOT))  # add ROOT to PATH
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
    del_dir(ROOT/'runs/detect/train')




if __name__ == '__main__':
    yolo_dataset_demo()
    del_runsfile() # 帮忙删除runs文件
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342

注:该代码只需图像文件与对应xml文件,即可按照比列转换train、val、test数据。

三、一键训练/预测的sh内容

1、训练sh文件(train.sh)内容

训练文件为sh文件,只需通过以下命令,实现训练。

sh train.sh
  • 1

该文件包含虚拟环境切换与自动调用模型训练,其详情如下:

# train.sh

train_weight=/home/ubuntu/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt

echo -e "\n"train time $(date "+%Y-%m-%d")"\n"


# 更换虚拟环境

__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
	        eval "$__conda_setup"
		    else
			                if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
						                    . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
								                        else
												                            export PATH="/home/ubuntu/miniconda3/bin:$PATH"
															                                fi
fi
unset __conda_setup
conda activate yolov8

cur_dir=$(cd `dirname $0`;pwd)  # 获得当前路径
echo -e  "\ncur_dir:"${cur_dir}"\n"

yaml_dir=$cur_dir/coco128_auto.yaml
echo -e  "\nyaml_dir:"${yaml_dir}"\n"

#save_dir=$cur_dir/runs/train
#echo -e "\nsave_dir:"$save_dir"\n"
#
#
#if [ -d ${save_dir} ];then
#	    echo "save_dir 文件存在"
#    else
#	    echo "save_dir文件不存在-->创建文件"
#	    mkdir -p  $save_dir
#fi

cd ${cur_dir}
ls
echo -e "\n\n\n\t\t\t start train  ... \n\n\n"
# xml数据转txt数据格式
python auto_tools.py
yolo train model=$train_weight data=$yaml_dir epochs=300 imgsz=640  batch=24 amp=False  name=train/exp

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

2、train.sh内容说明

1、开头有一个重要预训练权重路径,确定使用rtdetr哪个模型,默认为l模型
train_weight=/home/oem/Project/tj/auto_project/RTDETR/model_rtdetr/rtdetr-l.pt

2、最后一句模型运行命令,默认参数命令如下:
yolo train model= t r a i n w e i g h t d a t a = train_weight data= trainweightdata=yaml_dir epochs=300 imgsz=640 batch=12 amp=False name=train/exp

3、添加参数
显卡选择参数device,添加 device=0,1或device=0等形式

3、预测sh文件(detect.sh)介绍

预测文件为sh文件,只需通过以下命令,实现训练。

sh detect.sh
  • 1

该文件包含虚拟环境切换与自动调用模型预测,其详情如下:


# detect.sh

echo -e "\n"detect time $(date "+%Y-%m-%d")"\n"

# 更换虚拟环境

__conda_setup="$('/home/ubuntu/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
	        eval "$__conda_setup"
		    else
			                if [ -f "/home/ubuntu/miniconda3/etc/profile.d/conda.sh" ]; then
						                    . "/home/ubuntu/miniconda3/etc/profile.d/conda.sh"
								                        else
												                            export PATH="/home/ubuntu/miniconda3/bin:$PATH"
															                                fi
fi
unset __conda_setup
conda activate yolov8

cur_dir=$(cd `dirname $0`;pwd)  # 获得当前路径
echo -e  "\ncur_dir:"${cur_dir}"\n"

yaml_dir=$cur_dir/coco128_auto.yaml
echo -e  "\nyaml_dir:"${yaml_dir}"\n"
save_dir=$cur_dir/runs/detect
echo -e "\nsave_dir:"$save_dir"\n"

if [ -d ${save_dir} ];then
	    echo "save_dir 文件存在"
    else
	    echo "save_dir文件不存在-->创建文件"
	    mkdir -p  $save_dir
fi

cd ${cur_dir}

ls

echo -e "\n\n\n\t\t\t start detect  ... \n\n\n"

python  predect.py  --conf_thres 0.25


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

4、detect.sh内容说明

1、最后一句模型运行命令,默认参数命令如下:
python predect.py --conf_thres 0.25

2、添加权重与图片保存路径,如下格式
–weights /home/ubuntu/runs/detect/train/exp/weights/best.pt
–save_dir /home/ubuntu/runs/detect/predect/exp

四、训练、预测运行结果显示

1、训练效果展示

在这里插入图片描述

2、预测效果展示

在这里插入图片描述

总结

本文一个目的,傻瓜式训练与预测,通过sh脚本实现3个任务,
①、虚拟环境自动切换
②、数据格式自动转换,输入为图像文件与对应xml文件自动完成rtdetr模型训练与预测数据格式
③、模型自动训练与预测,且只需执行sh train.sh或 sh detect.sh即可实现

整体脚本:点击这里

文件整体格式如下图:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/941166
推荐阅读
相关标签
  

闽ICP备14008679号