An In-Depth Guide: Training YOLOv8 on Your Own Dataset | Training | Testing | Deployment


1: Setting Up the Source Code

My training environment:
Ubuntu 18.04
Python: 3.9
torch: 2.2.2+cu121
torchvision: 0.17.2+cu121
GPU: 2x NVIDIA GeForce RTX 3090

For convenience, here is the official documentation (available in Chinese): Ultralytics official docs.
Please have your Python environment set up before you start.

1.1: Download the Source Code

GitHub repository: https://github.com/ultralytics/ultralytics
You can either download the zip archive from GitHub or clone the source with git:

git clone https://github.com/ultralytics/ultralytics.git

1.2: Install the Runtime Environment

  pip install ultralytics==8.2.9 -i https://pypi.tuna.tsinghua.edu.cn/simple
  pip install hub-sdk -i https://pypi.tuna.tsinghua.edu.cn/simple
  
  # Install the GPU build of PyTorch
  conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda=12.1 -c pytorch -c nvidia

1.3: Download the Pretrained Models

You can download them from the official site; I extracted the direct download links below:

Model downloads:
YOLOv8n download
YOLOv8s download
YOLOv8m download
YOLOv8l download
YOLOv8x download

1.4: Test

In the project root, create a backbone folder and put the downloaded pretrained models into it.
Test with the following code; the results are printed to the console.

from ultralytics import YOLO

# Load the pretrained YOLOv8s model (placed in the backbone folder created above;
# the paths are relative to the project root)
model = YOLO('backbone/yolov8s.pt')

# Run inference on 'bus.jpg' with extra arguments
model.predict('ultralytics/assets/bus.jpg', save=True, imgsz=640, conf=0.5)

1.5: Possible Errors

1: ... has no attribute 'FigureCanvas'
Downgrade matplotlib:

pip install matplotlib==3.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

2: Building the Training Dataset

The data was annotated with labelimg (see: installing labelimg without the command line). The training set consists of images plus xml annotation files, which are converted to the YOLO data format in one pass.

2.1: Convert VOC-Format Data to YOLO Format

Everything you need to modify is marked with a todo comment in the code; change those to your actual paths.

# -*- coding:utf-8 -*-

"""
筛选公开数据集中,只包含目标数据集的图片和标签,自己标注的数据集也可以用。可以实现以下几点功能
1、筛选出自己感兴趣的标签,剔除不感兴趣的标签
2、按比例切分训练集、测试集、验证集
3、将xml标签文件转为yolo格式的标签
4、过滤无标签的标注文件
5、重命名转换后的文件名
"""

import os, sys, shutil
import cv2
import numpy as np

# Batch-read the xml files under Annotations and keep only the labels of interest
def read_xml(path_xml, input_path_img_dir, out_path_xml_dir, out_path_img_dir, class_count, pic_number, img_num, image_list):
    '''
    :param path_xml: absolute path of the xml file to process
    :return: the updated pic_number, image_list and img_num counters
    '''
    import xml.etree.ElementTree as ET
    global classes
    with open(path_xml, 'r', encoding='UTF-8') as f:
        root = ET.parse(f).getroot()
    img_flag = False
    objects = root.findall('object')  # Get a list of all objects in this image.
    # Get the file name and extension (rsplit tolerates dots inside the name)
    filename = str(root.find("filename").text).rsplit(".", 1)[0]  # image name
    img_suffix = str(root.find("filename").text).rsplit(".", 1)[1]  # image extension
    w = int(root.find("size").find('width').text)
    h = int(root.find("size").find('height').text)
    input_img_path = os.path.join(input_path_img_dir, filename + '.' + img_suffix)
    # If the annotation has no matching image file, return immediately
    if not os.path.exists(input_img_path):
        return pic_number, image_list, img_num
    print(path_xml)
    out_file = open(os.path.join(out_path_xml_dir, 'data%06d' % pic_number + '.txt'), 'w')
    # Parse the data for each object.
    for obj in objects:
        class_name = obj.find('name').text
        if class_name in classes:
            # Apply the label replacement first
            if class_name in replace_class.keys():
                class_name = replace_class[class_name]
            # Count this label
            class_count[class_name] = class_count[class_name] + 1
            img_flag = True
            # Get the bounding box coordinates.
            bndbox = obj.find('bndbox')
            b1 = float(bndbox.find('xmin').text)
            b3 = float(bndbox.find('ymin').text)
            b2 = float(bndbox.find('xmax').text)
            b4 = float(bndbox.find('ymax').text)
            # Clamp out-of-bounds annotations
            if b2 > w:
                b2 = w
            if b4 > h:
                b4 = h
            b = (b1, b2, b3, b4)
            bb = convert((w, h), b)
            cls_id = classes.index(class_name)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')  # write the annotation as a YOLO-format line
        else:
            # Skip labels we are not interested in
            pass
    out_file.close()
    out_file.close()

    if img_flag:
        # Copy the image to the output folder
        out_img_path = os.path.join(out_path_img_dir, 'data%06d' % pic_number + '.' + img_suffix)
        # todo This hard-coded prefix makes the paths in train.txt etc. match the server once the dataset is moved to /home/taoxifa/data_collection (relative paths were not tested; feel free to try them).
        txt_img_path = "/home/taoxifa/data_collection/images/" + 'data%06d' % pic_number + '.' + img_suffix
        shutil.copy(input_img_path, out_img_path)
        # Record the image size
        img = cv2.imdecode(np.fromfile(out_img_path, dtype=np.uint8), -1)
        size = img.shape
        if size not in class_count["size"]:
            class_count["size"].append(size)
        image_list.append(txt_img_path)  # If you train locally and do not move the data to a server, replace txt_img_path with out_img_path
        pic_number += 1
        img_num += 1
    return pic_number, image_list, img_num



def split_train_cal(txtsavepath, img_list, train_ratio, val_ratio, test_ratio):
    # Split the dataset into training/validation/test sets
    # Compute the split indices
    num_files = len(img_list)
    num_train = int(num_files * train_ratio)
    num_val = int(num_files * val_ratio)
    num_test = num_files - num_train - num_val

    # Training set
    train_files = img_list[:num_train]

    # Validation set
    val_files = img_list[num_train:num_train + num_val]

    # Test set
    test_files = img_list[num_train + num_val:]

    file_test = open(txtsavepath + '/test.txt', 'a')
    file_train = open(txtsavepath + '/train.txt', 'a')
    file_val = open(txtsavepath + '/val.txt', 'a')

    for i in train_files:
        file_train.write(i + '\n')

    for i in test_files:
        file_test.write(i + '\n')

    for i in val_files:
        file_val.write(i + '\n')

    file_train.close()
    file_val.close()
    file_test.close()



def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h



def read_xml2txt():
    # todo Handle multiple source folders; here we need ambulance, slag-truck and mud-tanker data (edit this dict with your own xml dataset paths)
    dir_all = [{
        "input_path_xml_dir": r"E:\date_collection\Objects21-vehicel数据集\annotations\train",  # ambulance dataset
        "input_path_img_dir": r"E:\date_collection\Objects21-vehicel数据集\images\train"
    }, {
        "input_path_xml_dir": r"E:\date_collection\渣土车\Z-渣土车数据集2400VOC\Annotations",  # slag-truck dataset
        "input_path_img_dir": r"E:\date_collection\渣土车\Z-渣土车数据集2400VOC\JPEGImages"
    }, {
        "input_path_xml_dir": r"E:\date_collection\渣土车\自制渣土车数据",  # self-annotated slag-truck dataset
        "input_path_img_dir": r"E:\date_collection\渣土车\自制渣土车数据"
    }, {
        "input_path_xml_dir": r"E:\date_collection\槽罐车\泥浆槽罐车",  # mud-tanker dataset
        "input_path_img_dir": r"E:\date_collection\槽罐车\泥浆槽罐车"
    }]

    class_count = {"ambulance": 0, "mud_tank_truck": 0, "slag_car": 0, "size": []}
    pic_number = 7830  # starting index for the renamed output files
    for dir_path in dir_all:
        image_list = []
        input_path_xml_dir = dir_path["input_path_xml_dir"]
        input_path_img_dir = dir_path["input_path_img_dir"]
        img_num = 1
        for xml_name in os.listdir(input_path_xml_dir):
            xml_path = os.path.join(input_path_xml_dir, xml_name)
            # Skip anything that is not an xml file
            if not xml_name.endswith(".xml"):
                continue
            # Volume cap: take at most 4000 images per source for training
            if img_num >= 4000:
                continue
            pic_number, image_list, img_num = read_xml(xml_path, input_path_img_dir, out_path_xml_dir, out_path_img_dir, class_count, pic_number, img_num, image_list)

        # Split this source into train/val/test
        split_train_cal(txtsavepath, image_list, train_ratio, val_ratio, test_ratio)

    # Print the final label statistics
    print("class_count", class_count)


if __name__ == "__main__":
    # Dataset split ratios
    train_ratio = 0.7  # training set ratio
    val_ratio = 0.2  # validation set ratio
    test_ratio = 0.1  # test set ratio
    classes = ["ambulance", "mud_tank_truck", "slag_car", "muck-truck"]  # todo labels of interest; labels not in this list are dropped
    replace_class = {"muck-truck": "slag_car"}  # todo label replacement; if different sources name the same class differently, map them here as {"old label": "new label"}
    out_path_xml_dir = r"E:\date_collection\工业区训练数据\数据增强\labels"  # todo folder where the labels are saved
    out_path_img_dir = r"E:\date_collection\工业区训练数据\数据增强\images"  # todo folder where the images are saved
    txtsavepath = r"E:\date_collection\工业区训练数据\数据增强"  # todo folder where the split txt files are saved
    if not os.path.exists(out_path_xml_dir):
        os.makedirs(out_path_xml_dir)  # makedirs also creates missing parent folders
    if not os.path.exists(out_path_img_dir):
        os.makedirs(out_path_img_dir)
    if not os.path.exists(txtsavepath):
        os.makedirs(txtsavepath)

    with open(os.path.join(out_path_xml_dir, "classes.txt"), 'w') as f:
        for i in classes:
            f.write(i + '\n')

    read_xml2txt()
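
To make the label format concrete, here is a hand-checked example of what convert() writes (a minimal sketch reusing the function above; the box values are made up for illustration):

# A 640x480 image with a VOC box xmin=100, ymin=120, xmax=300, ymax=360
# for class id 0 becomes one YOLO label line of normalized (cx, cy, w, h):
x, y, w, h = convert((640, 480), (100, 300, 120, 360))
print(0, x, y, w, h)  # -> 0 0.3109375 0.4979166666666667 0.3125 0.5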

After the script finishes, the following appear in the target folder:
images: the images
labels: the label files
test.txt: paths of the test images
train.txt: paths of the training images
val.txt: paths of the validation images
!!! Make absolutely sure the paths inside the txt files match where the files actually are !!!
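
Since mismatched paths are the most common cause of "no labels found" errors at training time, a quick check like the sketch below can save a lot of debugging (a minimal sketch; adjust 'train.txt' to your own txtsavepath):

import os

# Report every image path listed in train.txt that does not exist on disk
with open('train.txt') as f:
    missing = [p for p in (line.strip() for line in f) if p and not os.path.exists(p)]
print(len(missing), 'missing paths')
for p in missing[:10]:  # show the first few offenders
    print(p)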

3: Edit the Configuration Files

3.1: Configure Your Own Data File

Under the project path \ultralytics\cfg\datasets\, create a new my_data.yaml file with the following content:


path: /home/taoxifa/data_collection # dataset root dir
train: train.txt # train images (relative to 'path')
val: val.txt # val images (relative to 'path')
test: test.txt # test images (relative to 'path')

# Classes
names:
  0: ambulance
  1: mud_tank_truck
  2: slag_car

3.2: Edit yolov8.yaml

File path: \ultralytics\cfg\models\v8\yolov8.yaml. Set nc to the number of classes in your own dataset, as in the excerpt below.
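
For reference, the relevant part of the stock yolov8.yaml looks roughly like this after the change (only the nc value is edited; 3 classes are used in this tutorial):

# ultralytics/cfg/models/v8/yolov8.yaml (excerpt)
# Parameters
nc: 3  # number of classes (ambulance, mud_tank_truck, slag_car)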

4: Start Training

The training code is as follows:

import os
# os.environ["OMP_NUM_THREADS"]='2'

from ultralytics import YOLO

if __name__ == "__main__":
    # Load a model
    model_yaml = r"/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/models/v8/yolov8s.yaml"
    data_yaml = r"/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/datasets/my_data.yaml"
    pre_model = r"/home/taoxifa/Ai_project/yolov8/backbone/yolov8s.pt"

    model = YOLO(model_yaml, task='detect').load(pre_model)  # build from YAML and transfer weights
    # model = YOLO(pre_model, task='detect')  # load a pretrained model (recommended for training)

    # Train the model
    results = model.train(data=data_yaml, epochs=2000, imgsz=640, batch=16, workers=8, device=[0,1],
                          cos_lr=True, close_mosaic=200, warmup_epochs=10)


Parameter explanations:
data: path to the dataset configuration file used for training.
epochs: how many times the whole dataset is iterated during training.
imgsz: the input image size. For training it must be an integer multiple of 32; for testing and prediction a size such as [1920, 1080] can be used.
batch: the number of images per batch.
workers: the number of worker threads for data loading.
device: which device(s) to run on; use a list for multiple GPUs.
cos_lr: cosine learning-rate scheduler. Setting it to True makes the learning rate follow a cosine curve: higher early on to speed up convergence, then gradually decaying to fine-tune the parameters late in training.
close_mosaic: disables mosaic data augmentation for the last N training epochs.
warmup_epochs: the number of warm-up epochs; the learning rate ramps from a low value up to the initial rate to stabilize early training.
For more parameters, see: ultralytics training arguments. An equivalent command-line invocation is shown below.
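
The same run can also be launched from the command line, mirroring the yolo detect val command used later (a sketch using the paths from the script above):

yolo detect train data=/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/datasets/my_data.yaml model=/home/taoxifa/Ai_project/yolov8/backbone/yolov8s.pt epochs=2000 imgsz=640 batch=16 workers=8 device=0,1 cos_lr=True close_mosaic=200 warmup_epochs=10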

5: View the Training Results

The training results are stored under runs/detect/train, where you can review the evaluation results and the best model.
The best model is best.pt under the weights folder. (The metrics shown here came from a 15-epoch run, not the full 2000 epochs.)

6: Model Evaluation

Evaluating in code:

from ultralytics import YOLO

if __name__ == '__main__':

    # Load a model
    model_yaml = r"/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/models/v8/yolov8s.yaml"
    data_yaml = r"/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/datasets/my_data.yaml"
    pre_model = r"/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best.pt"

    # model = YOLO(model_yaml, task='detect').load(pre_model)  # build from YAML and transfer weights
    model = YOLO(pre_model, task='detect')  # load a pretrained model (recommended for training)

    # Validate the model
    results = model.val(data=data_yaml, imgsz=640, device=[0, 1])

Evaluating from the command line:

yolo detect val model=/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best.pt data=/home/taoxifa/Ai_project/yolov8/ultralytics/cfg/datasets/my_data.yaml

The evaluation results (per-class precision, recall and mAP) are printed to the console.

7: Model Prediction

from ultralytics import YOLO



if __name__ == '__main__':

    pth_path = r"/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best.pt"

    test_path = r"/home/taoxifa/Ai_project/yolov8/images"
    # Load a model
    # model = YOLO('yolov8n.pt')  # load an official model
    model = YOLO(pth_path)  # load a custom model

    # Predict with the model
    results = model(test_path, save=True, conf=0.6, iou=0.8)  # predict on an image
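
Besides save=True, the returned results can also be read programmatically (a minimal sketch using the Ultralytics Results API):

# Iterate over the Results objects returned by model(...)
for r in results:
    for box in r.boxes:
        cls_id = int(box.cls[0])                # class index
        conf = float(box.conf[0])               # confidence score
        x1, y1, x2, y2 = box.xyxy[0].tolist()   # corner coordinates in pixels
        print(r.names[cls_id], f'{conf:.2f}', (x1, y1, x2, y2))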

The annotated output images are saved automatically (by default under runs/detect/predict).

8: Model Export (to ONNX, usable for TensorRT or OpenVINO inference deployment)

Install the additional dependencies:

  pip install openvino-dev==2024.0.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  pip install onnx==1.16.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  pip install nncf==2.9.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  pip install onnxruntime==1.17.3 -i https://pypi.tuna.tsinghua.edu.cn/simple
from ultralytics import YOLO

if __name__ == "__main__":

    pth_path = r"/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best.pt"

    # Load a model
    # model = YOLO("/home/taoxifa/Ai_project/yolov8/backbone/yolov8s.pt")  # load the official model
    model = YOLO(pth_path)  # load the custom-trained model

    # Export the model
    model.export(format='openvino', int8=True)  # with int8=True, the coco128 dataset is downloaded automatically for calibration and the model is quantized to int8
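
If you only need the ONNX file (for example for TensorRT), the export call looks like this instead (a short sketch; same best.pt as above):

from ultralytics import YOLO

model = YOLO(r"/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best.pt")
model.export(format='onnx', imgsz=640, simplify=True)  # writes best.onnx next to best.pt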

9: OpenVINO or ONNX Inference

OpenVINO documentation: official OpenVINO tutorials

# Official reference: https://docs.openvino.ai/2024/notebooks/004-hello-detection-with-output.html

import argparse
import time
import cv2
import numpy as np
from openvino.runtime import Core  # pip install openvino -i https://pypi.tuna.tsinghua.edu.cn/simple
import onnxruntime as ort  # used for onnxruntime inference; pip install onnxruntime installs the CPU build by default

# The custom classes trained above (replace with your own)
CLASSES = ['ambulance', 'mud_tank_truck', 'slag_car']


class OpenvinoInference(object):
    def __init__(self, onnx_path):
        self.onnx_path = onnx_path
        ie = Core()
        self.model_onnx = ie.read_model(model=self.onnx_path)
        self.compiled_model_onnx = ie.compile_model(model=self.model_onnx, device_name="CPU", config={})
        self.output_layer_onnx = self.compiled_model_onnx.output(0)

    def predict(self, datas):
        predict_data = self.compiled_model_onnx([datas])[self.output_layer_onnx]
        return predict_data


class YOLOv8:
    """YOLOv8 object detection model class for handling inference and visualization."""

    def __init__(self, onnx_model, imgsz=(640, 640), infer_tool='openvino'):
        """
        Initialization.

        Args:
            onnx_model (str): Path to the ONNX model.
        """
        self.infer_tool = infer_tool
        if self.infer_tool == 'openvino':
            # Build the OpenVINO inference engine
            self.openvino = OpenvinoInference(onnx_model)
            self.ndtype = np.single
        else:
            # Build the onnxruntime inference engine
            self.ort_session = ort.InferenceSession(onnx_model,
                                                    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                                                    if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])

            # Numpy dtype: support both FP32 and FP16 onnx model
            self.ndtype = np.half if self.ort_session.get_inputs()[0].type == 'tensor(float16)' else np.single

        self.classes = CLASSES  # model class names
        self.model_height, self.model_width = imgsz[0], imgsz[1]  # size the input image is resized to
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))  # a color per class

    def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45):
        """
        The whole pipeline: pre-process -> inference -> post-process.

        Args:
            im0 (Numpy.ndarray): original input image.
            conf_threshold (float): confidence threshold for filtering predictions.
            iou_threshold (float): iou threshold for NMS.

        Returns:
            boxes (List): list of bounding boxes.
        """
        # Pre-process
        t1 = time.time()
        im, ratio, (pad_w, pad_h) = self.preprocess(im0)
        print('Pre-process time: {:.3f}s'.format(time.time() - t1))

        # Inference
        t2 = time.time()
        if self.infer_tool == 'openvino':
            preds = self.openvino.predict(im)
        else:
            preds = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: im})[0]
        print('Inference time: {:.2f}s'.format(time.time() - t2))

        # Post-process
        t3 = time.time()
        boxes = self.postprocess(preds,
                                 im0=im0,
                                 ratio=ratio,
                                 pad_w=pad_w,
                                 pad_h=pad_h,
                                 conf_threshold=conf_threshold,
                                 iou_threshold=iou_threshold,
                                 )
        print('Post-process time: {:.3f}s'.format(time.time() - t3))

        return boxes

    # Pre-processing: resize, pad, HWC to CHW, BGR to RGB, normalize, add a batch dim (CHW -> BCHW)
    def preprocess(self, img):
        """
        Pre-processes the input image.

        Args:
            img (Numpy.ndarray): image about to be processed.

        Returns:
            img_process (Numpy.ndarray): image preprocessed for inference.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
        """
        # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
        shape = img.shape[:2]  # original image shape
        new_shape = (self.model_height, self.model_width)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        ratio = r, r
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # pad with gray
        # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
        img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
        img_process = img[None] if len(img.shape) == 3 else img
        return img_process, ratio, (pad_w, pad_h)

    # Post-processing: confidence filtering and NMS
    def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold):
        """
        Post-process the prediction.

        Args:
            preds (Numpy.ndarray): predictions come from ort.session.run().
            im0 (Numpy.ndarray): [h, w, c] original input image.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
            conf_threshold (float): conf threshold.
            iou_threshold (float): iou threshold.

        Returns:
            boxes (List): list of bounding boxes.
        """
        x = preds  # outputs: predictions (1, 4 + num_classes, num_anchors), e.g. (1, 7, 8400) for the 3 classes here
        # Transpose the first output: (Batch_size, xywh_conf_cls, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls)
        x = np.einsum('bcn->bnc', x)  # (1, 8400, 4 + num_classes)

        # Predictions filtering by conf-threshold
        x = x[np.amax(x[..., 4:], axis=-1) > conf_threshold]

        # Create a new matrix which merge these(box, score, cls) into one
        # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html
        x = np.c_[x[..., :4], np.amax(x[..., 4:], axis=-1), np.argmax(x[..., 4:], axis=-1)]

        # NMS filtering
        # Values after NMS: np.array([[x, y, w, h, conf, cls], ...]), shape=(-1, 4 + 1 + 1)
        x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)]

        # Rescale the boxes back to the original image for drawing
        if len(x) > 0:
            # Bounding boxes format change: cxcywh -> xyxy
            x[..., [0, 1]] -= x[..., [2, 3]] / 2
            x[..., [2, 3]] += x[..., [0, 1]]

            # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image
            x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
            x[..., :4] /= min(ratio)

            # Bounding boxes boundary clamp
            x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
            x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])

            return x[..., :6]  # boxes
        else:
            return []


    # Draw boxes
    def draw_and_visualize(self, im, bboxes, vis=False, save=True):
        """
        Draw and visualize results.

        Args:
            im (np.ndarray): original image, shape [h, w, c].
            bboxes (numpy.ndarray): [n, 6], n is number of bboxes.
            vis (bool): imshow using OpenCV.
            save (bool): save image annotated.

        Returns:
            None
        """
        # Draw rectangles
        for (*box, conf, cls_) in bboxes:
            # draw bbox rectangle
            cv2.rectangle(im, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                          self.color_palette[int(cls_)], 1, cv2.LINE_AA)
            cv2.putText(im, f'{self.classes[int(cls_)]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.color_palette[int(cls_)], 2, cv2.LINE_AA)

        # # Show image
        # if vis:
        #     cv2.imshow('demo', im)
        #     cv2.waitKey(1)


        # Save image
        if save:
            cv2.imwrite('demo.png', im)
            print("Saved demo.png")
        return im

def read_Video(video_path, model, conf_threshold, iou_threshold):
    cap = cv2.VideoCapture(video_path)

    while True:
        start_time = time.time()
        ret, img = cap.read()
        if not ret:
            # Stop when the video ends or a frame cannot be read
            break
        boxes = model(img, conf_threshold, iou_threshold)
        if len(boxes) > 0:
            img = model.draw_and_visualize(img, boxes, vis=False, save=True)
        # Per-frame FPS: one frame divided by its wall-clock time
        fps = 1.0 / (time.time() - start_time)
        print("fps", fps)
        cv2.putText(img, f'{fps:.3f}', (20, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow("img", img)
        # Press Esc (keycode 27) to quit
        if cv2.waitKey(1) == 27:
            break
    cap.release()


if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=r'/home/taoxifa/Ai_project/yolov8/runs/detect/train/weights/best_int8_openvino_model/best.xml', help='Path to ONNX model')
    parser.add_argument('--source', type=str, default=str(r'/home/taoxifa/Ai_project/yolov8/images/data007753.jpg'), help='Path to input image')
    parser.add_argument('--imgsz', type=tuple, default=(640, 640), help='Image input size')
    parser.add_argument('--conf', type=float, default=0.6, help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.1, help='NMS IoU threshold')
    parser.add_argument('--infer_tool', type=str, default='openvino', choices=("openvino", "onnxruntime"),
                        help='inference engine to use')
    args = parser.parse_args()

    # Build model
    model = YOLOv8(args.model, args.imgsz, args.infer_tool)

    # Read image by OpenCV
    img = cv2.imread(args.source)
    # Video inference (uncomment to run on a video instead of an image)
    # video_path = r"E:\date_collection\安全帽\工地全景视频.mp4"
    # read_Video(video_path, model, conf_threshold=args.conf, iou_threshold=args.iou)


    # Inference
    boxes = model(img, conf_threshold=args.conf, iou_threshold=args.iou)  # image inference

    # Visualize
    if len(boxes) > 0:
        model.draw_and_visualize(img, boxes, vis=False, save=True)
    # cv2.destroyAllWindows()

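A typical invocation of the script above looks like this (the file name infer_openvino.py is hypothetical; the argparse defaults are used when arguments are omitted):

python infer_openvino.py --model runs/detect/train/weights/best_int8_openvino_model/best.xml --source images/data007753.jpg --conf 0.6 --iou 0.1 --infer_tool openvino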

Writing this up took real effort, so a like and a bookmark are much appreciated. If you have questions, raise them in the comments; I will reply to every one. Thanks.
