赞
踩
利用开源火灾数据集训练yolov5模型,实现对照片、视频以及摄像头画面中火焰和烟雾的检测。
详细的环境配置见:目标检测-YOLOV5-口罩检测
链接:https://pan.baidu.com/s/1LLfnSOEBdlgOnL_iEyMIsA
提取码:ghaq
该数据集的标签是xml格式,训练yolo模型需要txt类型的标签。需要先将xml类型的标签转换为txt类型,再划分数据集。
需要新建几个文件夹,具体如下:
划分数据集的py文件
split_train_val.py
# coding=utf-8
"""Convert Pascal-VOC XML labels to YOLO txt labels and split train/val.

Expected layout, relative to the current working directory:

    data-fire-smoke/fire/Annotation/   input  *.xml annotations
    data-fire-smoke/fire/Image/        input  images
    data-fire-smoke/fire/YOLOLabels/   output converted *.txt labels
    data-fire-smoke/images/{train,val} output image copies
    data-fire-smoke/labels/{train,val} output label copies

The image paths of each subset are also written to ``yolov5_train.txt`` and
``yolov5_val.txt`` in the working directory.
"""
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
import random
from shutil import copyfile

classes = ["fire", "smoke"]
# Percentage of samples assigned to the training set; the rest go to val.
TRAIN_RATIO = 80


def clear_hidden_files(path):
    """Recursively delete files whose name starts with ``._`` under *path*."""
    for entry in os.listdir(path):
        abspath = os.path.join(os.path.abspath(path), entry)
        if os.path.isfile(abspath):
            if entry.startswith("._"):
                os.remove(abspath)
        else:
            clear_hidden_files(abspath)


def convert(size, box):
    """Convert a VOC box to a normalised YOLO box.

    Args:
        size: ``(width, height)`` of the image in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        ``(x_center, y_center, width, height)``, each divided by the
        corresponding image dimension.
    """
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    return (x * dw, y * dh, w * dw, h * dh)


def convert_annotation(image_id):
    """Write the YOLO label file for *image_id* from its VOC XML annotation.

    Objects whose class is not in ``classes`` or that are flagged
    ``difficult`` are skipped. A missing ``<difficult>`` tag is treated as
    "not difficult" instead of raising AttributeError.
    """
    in_path = 'data-fire-smoke/fire/Annotation/%s.xml' % image_id
    out_path = 'data-fire-smoke/fire/YOLOLabels/%s.txt' % image_id

    with open(in_path, 'rb') as in_file:
        root = ET.parse(in_file).getroot()

    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(out_path, 'w') as out_file:
        for obj in root.iter('object'):
            difficult_node = obj.find('difficult')
            has_flag = difficult_node is not None and difficult_node.text
            difficult = int(difficult_node.text) if has_flag else 0
            cls = obj.find('name').text
            if cls not in classes or difficult == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text),
                 float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text),
                 float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(
                str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


def choose_split(prob, train_ratio=TRAIN_RATIO):
    """Return ``"train"`` or ``"val"`` for a draw *prob* in ``1..100``.

    Uses ``<=`` so that exactly ``train_ratio`` of the 100 possible draws map
    to the training set (``<`` would give only ``train_ratio - 1``).
    """
    return "train" if prob <= train_ratio else "val"


def ensure_dir(path, clean=False):
    """Create *path* if needed, optionally purge ``._*`` files; return it."""
    os.makedirs(path, exist_ok=True)
    if clean:
        clear_hidden_files(path)
    return path


def main():
    """Create the folder layout, convert every label and split the dataset."""
    wd = os.getcwd()
    data_base_dir = ensure_dir(os.path.join(wd, "data-fire-smoke/"))
    work_space_dir = ensure_dir(os.path.join(data_base_dir, "fire/"))
    annotation_dir = ensure_dir(
        os.path.join(work_space_dir, "Annotation/"), clean=True)
    image_dir = ensure_dir(
        os.path.join(work_space_dir, "Image/"), clean=True)
    yolo_labels_dir = ensure_dir(
        os.path.join(work_space_dir, "YOLOLabels/"), clean=True)
    yolov5_images_dir = ensure_dir(
        os.path.join(data_base_dir, "images/"), clean=True)
    yolov5_labels_dir = ensure_dir(
        os.path.join(data_base_dir, "labels/"), clean=True)

    # subset name -> (image output dir, label output dir)
    split_dirs = {}
    for subset in ("train", "val"):
        split_dirs[subset] = (
            ensure_dir(
                os.path.join(yolov5_images_dir, subset + "/"), clean=True),
            ensure_dir(
                os.path.join(yolov5_labels_dir, subset + "/"), clean=True),
        )

    train_list_path = os.path.join(wd, "yolov5_train.txt")
    val_list_path = os.path.join(wd, "yolov5_val.txt")
    with open(train_list_path, 'w') as train_file, \
            open(val_list_path, 'w') as val_file:
        list_files = {"train": train_file, "val": val_file}

        for image_name in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_name)
            if not os.path.isfile(image_path):
                continue

            stem = os.path.splitext(image_name)[0]
            annotation_path = os.path.join(annotation_dir, stem + '.xml')
            label_name = stem + '.txt'
            label_path = os.path.join(yolo_labels_dir, label_name)

            prob = random.randint(1, 100)
            print("Probability: %d" % prob)

            # Images without an annotation are left out of both subsets.
            if not os.path.exists(annotation_path):
                continue

            subset = choose_split(prob)
            images_out_dir, labels_out_dir = split_dirs[subset]
            list_files[subset].write(image_path + '\n')
            convert_annotation(stem)  # writes label_path
            copyfile(image_path, os.path.join(images_out_dir, image_name))
            copyfile(label_path, os.path.join(labels_out_dir, label_name))


if __name__ == '__main__':
    main()
运行成功后会出现下图的文件夹
预训练权重下载 本实验使用的是 yolov5s.pt
预训练权重:预先设置的训练权重,可以缩短训练时间。
将yolov5s.pt放置到weights文件夹中
训练自己的数据集需要更改数据配置文件(data中的yaml文件)和模型配置文件(models中的yaml文件)
复制data中的voc.yaml命名为fire.yaml
fire.yaml
# Dataset configuration for the fire/smoke detector.
# This file was copied from data/voc.yaml; the header below is the original one.
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC/
# Train command: python train.py --data voc.yaml
# NOTE(review): since this copy is saved as fire.yaml, the command above
# presumably should pass --data fire.yaml - confirm.
# Default dataset location is next to /yolov5:
#   /parent_folder
#     /VOC
#     /yolov5

# download command/URL (optional)
# NOTE(review): inherited from voc.yaml and points at the VOC download script;
# presumably not wanted for a custom dataset - confirm before training.
download: bash data/scripts/get_voc.sh

# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: data-fire-smoke/images/train/  # training images produced by the split script
val: data-fire-smoke/images/val/  # validation images produced by the split script

# number of classes
nc: 2  # must equal the number of entries in `names`

# class names
names: [ 'fire', 'smoke']  # label names
预训练权重选择的是yolov5s.pt
所以models模型应该选择yolov5s.yaml(不同的预训练权重对应不同的网络层数,用错会报错)
复制models中的yolov5s.yaml命名为fire.yaml
只需要修改一处地方:将nc(类别数)改为2,与data中fire.yaml的类别数保持一致。
打开train.py文件(train.py是训练自己数据集的函数)
需要修改
需要修改 | 配置文件路径 |
---|---|
预训练权重 | weights/yolov5s.pt |
数据集 | data/fire.yaml |
模型 | models/fire.yaml |
大概在train.py文件的458行左右
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='models/fire.yaml', help='model.yaml path')
parser.add_argument('--data', type=str, default='data/fire.yaml', help='data.yaml path')
训练结束会出现runs文件夹
在测试的时候我们选取最好的权重
将前面训练得到的最好的权重(best.pt)进行测试(detect.py测试文件)
pip3 install PyQt5
pip install PyQt5-tools
新建两个py文件
login.py
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'login.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_login_MainWindow(object):
    """Widget layout of the login window (title, two inputs, login button)."""

    def setupUi(self, login_MainWindow):
        """Create all child widgets on *login_MainWindow* and lay them out."""

        def make_font(point_size, bold=False, italic=False):
            # Bold fonts also get the explicit weight 75, as in the .ui file.
            new_font = QtGui.QFont()
            new_font.setPointSize(point_size)
            if bold:
                new_font.setBold(True)
            if italic:
                new_font.setItalic(True)
            if bold:
                new_font.setWeight(75)
            return new_font

        login_MainWindow.setObjectName("login_MainWindow")
        login_MainWindow.setEnabled(True)
        login_MainWindow.resize(575, 392)
        login_MainWindow.setAnimated(True)

        central = QtWidgets.QWidget(login_MainWindow)
        self.centralwidget = central
        central.setEnabled(True)
        central.setObjectName("centralwidget")

        # Title banner.
        self.label_3 = QtWidgets.QLabel(central)
        self.label_3.setGeometry(QtCore.QRect(90, 30, 800, 79))
        self.label_3.setFont(make_font(30, bold=True, italic=True))
        self.label_3.setObjectName("label_3")

        # Caption in front of the password field.
        self.label_2 = QtWidgets.QLabel(central)
        self.label_2.setGeometry(QtCore.QRect(100, 200, 71, 41))
        self.label_2.setFont(make_font(15))
        self.label_2.setObjectName("label_2")

        # Caption in front of the administrator field.
        self.label = QtWidgets.QLabel(central)
        self.label.setGeometry(QtCore.QRect(100, 150, 81, 41))
        self.label.setFont(make_font(15))
        self.label.setObjectName("label")

        # Password input (echo mode hides the typed characters).
        self.lineEdit = QtWidgets.QLineEdit(central)
        self.lineEdit.setGeometry(QtCore.QRect(180, 200, 221, 41))
        self.lineEdit.setFont(make_font(15))
        self.lineEdit.setText("")
        self.lineEdit.setEchoMode(QtWidgets.QLineEdit.Password)
        self.lineEdit.setObjectName("lineEdit")

        # Administrator name input.
        self.lineEdit_2 = QtWidgets.QLineEdit(central)
        self.lineEdit_2.setGeometry(QtCore.QRect(180, 150, 221, 41))
        self.lineEdit_2.setFont(make_font(15))
        self.lineEdit_2.setObjectName("lineEdit_2")

        # Login button.
        self.pushButton = QtWidgets.QPushButton(central)
        self.pushButton.setGeometry(QtCore.QRect(180, 280, 171, 51))
        self.pushButton.setFont(make_font(20, bold=True))
        self.pushButton.setObjectName("pushButton")

        login_MainWindow.setCentralWidget(central)

        self.menubar = QtWidgets.QMenuBar(login_MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 575, 22))
        self.menubar.setObjectName("menubar")
        login_MainWindow.setMenuBar(self.menubar)

        self.retranslateUi(login_MainWindow)
        QtCore.QMetaObject.connectSlotsByName(login_MainWindow)

    def retranslateUi(self, login_MainWindow):
        """Set the window title and every visible caption."""
        context = "login_MainWindow"
        tr = QtCore.QCoreApplication.translate

        login_MainWindow.setWindowTitle(tr(context, "火灾火焰烟雾识别系统"))
        captions = (
            (self.label_3, "欢迎使用火灾火焰烟雾识别系统"),
            (self.label_2, "密 码"),
            (self.label, "管理员"),
            (self.lineEdit_2, ""),
            (self.pushButton, "登录"),
        )
        for widget, text in captions:
            widget.setText(tr(context, text))
        # An "Enter" shortcut for the login button is intentionally not set.
login_mian.py
"""Login window: checks the credentials, then opens the detection window."""
import sys

from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import Qt
from PyQt5.QtCore import QTimer
from PyQt5.QtGui import QImage, QPixmap, QKeyEvent
from PyQt5.QtWidgets import QApplication, QMainWindow
from PyQt5.QtWidgets import QMessageBox

from login import Ui_login_MainWindow
from main import Ui_MainWindow


class login_window(QtWidgets.QMainWindow, Ui_login_MainWindow):
    """Login form that opens the main window after a successful login.

    Args:
        main_window: optional, already constructed main window to show after
            login. When omitted, a ``Ui_MainWindow`` is created on the first
            successful login, so the class no longer depends on a module
            global that only exists when this file is run as a script.
    """

    def __init__(self, main_window=None):
        super(login_window, self).__init__()
        # NOTE(security): credentials are hard-coded in plain text; move them
        # to a protected store before using this outside of a demo.
        self.admin = "美女"
        self.Password = "123456"
        self.main_window = main_window
        self.setupUi(self)  # build the widgets declared in login.py
        self.init()

    def init(self):
        """Connect the login button to its handler."""
        self.pushButton.clicked.connect(self.login_button)

    def login_button(self):
        """Validate the entered credentials and switch to the main window."""
        password = self.lineEdit.text()
        admin = self.lineEdit_2.text()

        if password == "":
            QMessageBox.warning(self, '警告', '密码不能为空,请输入!')
            return None

        if password == self.Password and admin == self.admin:
            if self.main_window is None:
                # Created only after authentication; constructing it is what
                # triggers the main window's own initialisation.
                self.main_window = Ui_MainWindow()
            # Show the new window first, then close this one.
            self.main_window.show()
            self.close()
        else:
            # Either field may be wrong, so the message names both.
            QMessageBox.critical(self, '错误', '管理员或密码错误!')
            self.lineEdit.clear()
        return None


if __name__ == '__main__':
    # Scale the UI on high-DPI displays.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app = QtWidgets.QApplication(sys.argv)
    window = login_window()
    window.show()
    sys.exit(app.exec_())
运行login_mian.py文件会出现登陆界面,
登录成功会跳转主界面
管理员:美女
密码:123456
新建main.py文件
main.py
"""Main window: runs YOLOv5 fire/smoke detection on images, videos, camera."""
import sys
import cv2
import argparse
import random
import torch
import numpy as np
import torch.backends.cudnn as cudnn
from PyQt5 import QtCore, QtGui, QtWidgets
from utils.torch_utils import select_device
from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.datasets import letterbox
from utils.plots import plot_one_box


class Ui_MainWindow(QtWidgets.QMainWindow):
    """Window with three source buttons and a label that shows the result.

    The model is loaded once in ``__init__``. Still images are processed on
    demand; videos and the camera are polled by ``timer_video`` and every
    annotated frame is also written to ``prediction.avi``.
    """

    def __init__(self, parent=None):
        super(Ui_MainWindow, self).__init__(parent)
        self.timer_video = QtCore.QTimer()
        self.setupUi(self)
        self.init_logo()
        self.init_slots()
        self.cap = cv2.VideoCapture()
        self.out = None  # cv2.VideoWriter, created when a stream starts

        self.opt = self._parse_options()
        print(self.opt)

        weights, imgsz = self.opt.weights, self.opt.img_size
        self.device = select_device(self.opt.device)
        # half precision only supported on CUDA
        self.half = self.device.type != 'cpu'
        cudnn.benchmark = True

        # Load model
        self.model = attempt_load(
            weights, map_location=self.device)  # load FP32 model
        stride = int(self.model.stride.max())  # model stride
        self.imgsz = check_img_size(imgsz, s=stride)  # check img_size
        if self.half:
            self.model.half()  # to FP16

        # Get names and colors
        self.names = self.model.module.names if hasattr(
            self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255)
                        for _ in range(3)] for _ in self.names]

    @staticmethod
    def _parse_options():
        """Build the detect.py-style argument parser and parse ``sys.argv``."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--weights', nargs='+', type=str,
                            default='weights/best.pt', help='model.pt path(s)')
        parser.add_argument('--source', type=str,
                            default='data/images/', help='source')
        parser.add_argument('--img-size', type=int,
                            default=640, help='inference size (pixels)')
        parser.add_argument('--conf-thres', type=float,
                            default=0.25, help='object confidence threshold')
        parser.add_argument('--iou-thres', type=float,
                            default=0.45, help='IOU threshold for NMS')
        parser.add_argument('--device', default='',
                            help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
        parser.add_argument('--view-img', action='store_true',
                            help='display results')
        parser.add_argument('--save-txt', action='store_true',
                            help='save results to *.txt')
        parser.add_argument('--save-conf', action='store_true',
                            help='save confidences in --save-txt labels')
        parser.add_argument('--nosave', action='store_true',
                            help='do not save images/videos')
        parser.add_argument('--classes', nargs='+', type=int,
                            help='filter by class: --class 0, or --class 0 2 3')
        parser.add_argument('--agnostic-nms', action='store_true',
                            help='class-agnostic NMS')
        parser.add_argument('--augment', action='store_true',
                            help='augmented inference')
        parser.add_argument('--update', action='store_true',
                            help='update all models')
        parser.add_argument('--project', default='runs/detect',
                            help='save results to project/name')
        parser.add_argument('--name', default='exp',
                            help='save results to project/name')
        parser.add_argument('--exist-ok', action='store_true',
                            help='existing project/name ok, do not increment')
        return parser.parse_args()

    def _add_side_button(self, object_name, vertical_policy):
        """Create one fixed-size 150x100 button in the left-hand column."""
        button = QtWidgets.QPushButton(self.centralwidget)
        sizePolicy = QtWidgets.QSizePolicy(
            QtWidgets.QSizePolicy.Minimum, vertical_policy)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(
            button.sizePolicy().hasHeightForWidth())
        button.setSizePolicy(sizePolicy)
        button.setMinimumSize(QtCore.QSize(150, 100))
        button.setMaximumSize(QtCore.QSize(150, 100))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(12)
        button.setFont(font)
        button.setObjectName(object_name)
        self.verticalLayout.addWidget(button, 0, QtCore.Qt.AlignHCenter)
        return button

    def setupUi(self, MainWindow):
        """Build the button column on the left and the display label."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.centralwidget)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setSizeConstraint(
            QtWidgets.QLayout.SetNoConstraint)
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setContentsMargins(-1, -1, 0, -1)
        self.verticalLayout.setSpacing(80)
        self.verticalLayout.setObjectName("verticalLayout")

        self.pushButton_img = self._add_side_button(
            "pushButton_img", QtWidgets.QSizePolicy.MinimumExpanding)
        self.pushButton_camera = self._add_side_button(
            "pushButton_camera", QtWidgets.QSizePolicy.Expanding)
        self.pushButton_video = self._add_side_button(
            "pushButton_video", QtWidgets.QSizePolicy.Expanding)

        self.verticalLayout.setStretch(2, 1)
        self.horizontalLayout.addLayout(self.verticalLayout)
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setObjectName("label")
        self.horizontalLayout.addWidget(self.label)
        self.horizontalLayout.setStretch(0, 1)
        self.horizontalLayout.setStretch(1, 3)
        self.horizontalLayout_2.addLayout(self.horizontalLayout)
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and the button captions."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "火灾火焰烟雾识别系统"))
        self.pushButton_img.setText(_translate("MainWindow", "图片检测"))
        self.pushButton_camera.setText(_translate("MainWindow", "摄像头检测"))
        self.pushButton_video.setText(_translate("MainWindow", "视频检测"))
        self.label.setText(_translate("MainWindow", "TextLabel"))

    def init_slots(self):
        """Connect the buttons and the frame timer to their handlers."""
        self.pushButton_img.clicked.connect(self.button_image_open)
        self.pushButton_video.clicked.connect(self.button_video_open)
        self.pushButton_camera.clicked.connect(self.button_camera_open)
        self.timer_video.timeout.connect(self.show_video_frame)

    def init_logo(self):
        """Show the placeholder picture in the display label."""
        pix = QtGui.QPixmap('R-C.jpg')
        self.label.setScaledContents(True)
        self.label.setPixmap(pix)

    def _detect(self, frame):
        """Run the model on a BGR *frame* and draw the boxes onto it in place.

        Returns:
            ``(pred, labels)``: the per-image detections after NMS and the
            ``"<class> <confidence>"`` caption of every drawn box.
        """
        labels = []
        with torch.no_grad():
            # Use the stride-checked size computed in __init__.
            img = letterbox(frame, new_shape=self.imgsz)[0]
            # BGR to RGB, HWC to CHW
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img)
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            pred = self.model(img, augment=self.opt.augment)[0]
            # Apply NMS
            pred = non_max_suppression(
                pred, self.opt.conf_thres, self.opt.iou_thres,
                classes=self.opt.classes, agnostic=self.opt.agnostic_nms)

            # Process detections
            for det in pred:
                if det is None or not len(det):
                    continue
                # Rescale boxes from the letterboxed size to the frame size
                det[:, :4] = scale_coords(
                    img.shape[2:], det[:, :4], frame.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    label = '%s %.2f' % (self.names[int(cls)], conf)
                    labels.append(label)
                    plot_one_box(xyxy, frame, label=label,
                                 color=self.colors[int(cls)],
                                 line_thickness=2)
        return pred, labels

    def _start_stream(self):
        """Open the output writer for ``self.cap`` and start the timer."""
        # Properties 3 and 4 are the capture's frame width and height.
        frame_size = (int(self.cap.get(3)), int(self.cap.get(4)))
        self.out = cv2.VideoWriter(
            'prediction.avi', cv2.VideoWriter_fourcc(*'MJPG'), 20, frame_size)
        self.timer_video.start(30)

    def _release_streams(self):
        """Release the capture and, if one is open, the output writer."""
        self.cap.release()
        if self.out is not None:
            self.out.release()
            self.out = None

    def _stop_stream(self):
        """Stop polling, release resources and restore the idle UI state."""
        self.timer_video.stop()
        self._release_streams()
        self.label.clear()
        self.init_logo()
        self.pushButton_video.setDisabled(False)
        self.pushButton_img.setDisabled(False)
        self.pushButton_camera.setDisabled(False)
        # Also reset the caption when the camera stream ended by itself.
        self.pushButton_camera.setText(u"摄像头检测")

    def button_image_open(self):
        """Let the user pick an image, run detection and display the result."""
        print('button_image_open')
        img_name, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "打开图片", "", "*.jpg;;*.png;;All Files(*)")
        if not img_name:
            return

        showimg = cv2.imread(img_name)
        print(img_name)
        if showimg is None:
            # imread yields no image when the file cannot be read or decoded.
            QtWidgets.QMessageBox.warning(
                self, u"Warning", u"打开图片失败",
                buttons=QtWidgets.QMessageBox.Ok,
                defaultButton=QtWidgets.QMessageBox.Ok)
            return

        pred, _labels = self._detect(showimg)
        print(pred)

        cv2.imwrite('prediction.jpg', showimg)
        self.result = cv2.cvtColor(showimg, cv2.COLOR_BGR2BGRA)
        self.result = cv2.resize(
            self.result, (640, 480), interpolation=cv2.INTER_AREA)
        # self.result is kept on the instance because QImage wraps its buffer.
        self.QtImg = QtGui.QImage(
            self.result.data, self.result.shape[1], self.result.shape[0],
            QtGui.QImage.Format_RGB32)
        self.label.setPixmap(QtGui.QPixmap.fromImage(self.QtImg))

    def button_video_open(self):
        """Let the user pick a video file and start frame-by-frame detection."""
        video_name, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "打开视频", "", "*.mp4;;*.avi;;All Files(*)")
        if not video_name:
            return

        if not self.cap.open(video_name):
            QtWidgets.QMessageBox.warning(
                self, u"Warning", u"打开视频失败",
                buttons=QtWidgets.QMessageBox.Ok,
                defaultButton=QtWidgets.QMessageBox.Ok)
            return

        self._start_stream()
        self.pushButton_video.setDisabled(True)
        self.pushButton_img.setDisabled(True)
        self.pushButton_camera.setDisabled(True)

    def button_camera_open(self):
        """Toggle camera detection: start it when idle, stop it when running."""
        if self.timer_video.isActive():
            self._stop_stream()
            return

        # Use the first local camera by default.
        if not self.cap.open(0):
            QtWidgets.QMessageBox.warning(
                self, u"Warning", u"打开摄像头失败",
                buttons=QtWidgets.QMessageBox.Ok,
                defaultButton=QtWidgets.QMessageBox.Ok)
            return

        self._start_stream()
        self.pushButton_video.setDisabled(True)
        self.pushButton_img.setDisabled(True)
        self.pushButton_camera.setText(u"关闭摄像头")

    def show_video_frame(self):
        """Timer slot: detect on the next frame, or stop when none is left."""
        _flag, showimg = self.cap.read()
        if showimg is None:
            self._stop_stream()
            return

        _pred, labels = self._detect(showimg)
        for label in labels:
            print(label)

        if self.out is not None:
            self.out.write(showimg)
        show = cv2.resize(showimg, (640, 480))
        self.result = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
        showImage = QtGui.QImage(
            self.result.data, self.result.shape[1], self.result.shape[0],
            QtGui.QImage.Format_RGB888)
        self.label.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def closeEvent(self, event):
        """Release the capture and writer when the window is closed."""
        self.timer_video.stop()
        self._release_streams()
        super(Ui_MainWindow, self).closeEvent(event)


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    ui = Ui_MainWindow()
    ui.show()
    sys.exit(app.exec_())
选择本地图片进行识别
选择本地的视频进行识别
识别结果
可以成功调用摄像头,这里就不再进行展示!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。