Today I'll share how to build a detection system as a PyQt5 desktop app: tentatively a cat-and-dog detector for public places, with YOLOX as the detection algorithm. A YOLOv8 + PyQt5 tutorial will follow later.
The system supports single-image detection, real-time camera detection, and video detection.
First, create a .py file and copy in the code below:
```python
from PIL import Image  # kept from the original; only cv2/numpy are actually used below
import numpy as np
import time
import os
import sys

import cv2
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *

from demo import main  # YOLOX inference wrapper, see demo.py below

# One shared stylesheet for every button, instead of nine copy-pasted copies
BUTTON_STYLE = '''
    QPushButton
    {text-align : center;
    background-color : white;
    font: bold;
    border-color: gray;
    border-width: 2px;
    border-radius: 10px;
    padding: 6px;
    height : 14px;
    border-style: outset;
    font : 14px;}
    QPushButton:pressed
    {text-align : center;
    background-color : lightgray;
    font: bold;
    border-color: gray;
    border-width: 2px;
    border-radius: 10px;
    padding: 6px;
    height : 14px;
    border-style: outset;
    font : 14px;}
    '''


'''Camera and video real-time detection window'''


class Ui_MainWindow(QWidget):
    def __init__(self, parent=None):
        super(Ui_MainWindow, self).__init__(parent)

        self.timer_camera1 = QtCore.QTimer()  # drives the raw camera preview (left panel)
        self.timer_camera2 = QtCore.QTimer()  # drives camera detection (right panel)
        self.timer_camera4 = QtCore.QTimer()  # drives video playback + detection
        self.cap = cv2.VideoCapture()

        self.CAM_NUM = 0

        self.__flag_work = 0
        self.x = 0
        self.count = 0
        self.setWindowTitle("公共场合猫狗检测系统")
        self.setWindowIcon(QIcon(os.getcwd() + '\\data\\source_image\\Detective.ico'))
        self.setFixedSize(1600, 900)
        self.yolo = main()  # load the YOLOX model once and reuse it for every frame

        self.button_open_camera = QPushButton(self)
        self.button_open_camera.setText(u'打开摄像头')
        self.button_open_camera.setStyleSheet(BUTTON_STYLE)
        self.button_open_camera.move(10, 40)
        self.button_open_camera.clicked.connect(self.button_open_camera_click)

        self.btn_detect_camera = QPushButton(self)
        self.btn_detect_camera.setText("检测摄像头")
        self.btn_detect_camera.setStyleSheet(BUTTON_STYLE)
        self.btn_detect_camera.move(10, 80)
        self.btn_detect_camera.clicked.connect(self.button_open_camera_click1)

        self.open_video = QPushButton(self)
        self.open_video.setText("打开视频")
        self.open_video.setStyleSheet(BUTTON_STYLE)
        self.open_video.move(10, 160)
        self.open_video.clicked.connect(self.open_video_button)

        # This was a second `self.btn1` in the original, which silently
        # overwrote the camera-detect button attribute; renamed for clarity.
        self.btn_detect_video = QPushButton(self)
        self.btn_detect_video.setText("检测视频文件")
        self.btn_detect_video.setStyleSheet(BUTTON_STYLE)
        self.btn_detect_video.move(10, 200)
        self.btn_detect_video.clicked.connect(self.detect_video)

        btn2 = QPushButton(self)
        btn2.setText("返回上一界面")
        btn2.setStyleSheet(BUTTON_STYLE)
        btn2.move(10, 240)
        btn2.clicked.connect(self.back_lastui)

        # Display areas: left label shows the source, right label shows detections
        self.label_show_camera = QLabel(self)
        self.label_move = QLabel()
        self.label_move.setFixedSize(100, 100)
        self.label_show_camera.setFixedSize(700, 500)
        self.label_show_camera.setAutoFillBackground(True)
        self.label_show_camera.move(110, 80)
        self.label_show_camera.setStyleSheet("QLabel{background:#F5F5DC;}"
                                             "QLabel{color:rgba(255,255,255,120);font-size:10px;font-weight:bold;font-family:宋体;}")
        self.label_show_camera1 = QLabel(self)
        self.label_show_camera1.setFixedSize(700, 500)
        self.label_show_camera1.setAutoFillBackground(True)
        self.label_show_camera1.move(850, 80)
        self.label_show_camera1.setStyleSheet("QLabel{background:#F5F5DC;}"
                                              "QLabel{color:rgba(255,255,255,120);font-size:10px;font-weight:bold;font-family:宋体;}")

        self.timer_camera1.timeout.connect(self.show_camera)
        self.timer_camera2.timeout.connect(self.show_camera1)
        self.timer_camera4.timeout.connect(self.show_camera2)
        self.timer_camera4.timeout.connect(self.show_camera3)
        self.clicked = False

        self.frame_s = 3
        # Background image
        palette1 = QPalette()
        palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('R-C.png')))
        self.setPalette(palette1)

    def back_lastui(self):
        # Stop everything and return to the single-image window; `cam_t` and
        # `ui_p` are the two windows created at module level in __main__.
        self.timer_camera1.stop()
        self.cap.release()
        self.label_show_camera.clear()
        self.timer_camera2.stop()
        self.label_show_camera1.clear()
        cam_t.close()
        ui_p.show()

    '''Camera'''

    def button_open_camera_click(self):
        if not self.timer_camera1.isActive():
            flag = self.cap.open(self.CAM_NUM)
            if not flag:
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请检查相机与电脑是否连接正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera1.start(30)
                self.button_open_camera.setText(u'关闭摄像头')
        else:
            self.timer_camera1.stop()
            self.cap.release()
            self.label_show_camera.clear()
            self.timer_camera2.stop()
            self.label_show_camera1.clear()
            self.button_open_camera.setText(u'打开摄像头')

    def show_camera(self):  # left panel: raw camera frame
        flag, self.image = self.cap.read()

        # Leftover from an earlier pipeline that re-read the frame from disk;
        # harmless, and safe to delete if you don't need the dumped frame.
        dir_path = os.getcwd()
        camera_source = dir_path + "\\data\\test\\2.jpg"
        cv2.imwrite(camera_source, self.image)

        width = self.image.shape[1]
        height = self.image.shape[0]

        # Target display resolution
        width_new = 700
        height_new = 500

        # Resize while preserving the aspect ratio
        if width / height >= width_new / height_new:
            show = cv2.resize(self.image, (width_new, int(height * width_new / width)))
        else:
            show = cv2.resize(self.image, (int(width * height_new / height), height_new))

        show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
        showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0], 3 * show.shape[1],
                                 QtGui.QImage.Format_RGB888)
        self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def button_open_camera_click1(self):
        if not self.timer_camera2.isActive():
            flag = self.cap.open(self.CAM_NUM)
            if not flag:
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请检查相机与电脑是否连接正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera2.start(30)
                self.button_open_camera.setText(u'关闭摄像头')
        else:
            self.timer_camera2.stop()
            self.cap.release()
            self.label_show_camera1.clear()
            self.button_open_camera.setText(u'打开摄像头')

    def show_camera1(self):  # right panel: camera frame with detections
        fps = 0.0
        t1 = time.time()
        flag, self.image = self.cap.read()
        self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
        im0, nums, ti = self.yolo.demoimg(self.image)  # run YOLOX on the frame
        im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
        width = im0.shape[1]
        height = im0.shape[0]

        # Target display resolution
        width_new = 640
        height_new = 640

        # Resize while preserving the aspect ratio
        if width / height >= width_new / height_new:
            show = cv2.resize(im0, (width_new, int(height * width_new / width)))
        else:
            show = cv2.resize(im0, (int(width * height_new / height), height_new))
        im0 = show  # the original discarded the resized frame here and displayed im0 unscaled
        if nums >= 1:
            fps = (fps + (1. / (time.time() - t1))) / 2
            im0 = cv2.putText(im0, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            im0 = cv2.putText(im0, "No pets allowed", (0, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

        showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
        self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))

    '''Video detection'''

    def open_video_button(self):
        if not self.timer_camera4.isActive():
            imgName, imgType = QFileDialog.getOpenFileName(self, "打开视频", "",
                                                           "*.mp4;;*.AVI;;*.rmvb;;All Files(*)")
            self.cap_video = cv2.VideoCapture(imgName)
            flag = self.cap_video.isOpened()
            if not flag:
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请检查视频文件是否选择正确",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.show_camera2()  # show the first frame right away
                self.open_video.setText(u'关闭视频')
        else:
            self.cap_video.release()
            self.label_show_camera.clear()
            self.timer_camera4.stop()
            self.frame_s = 3
            self.label_show_camera1.clear()
            self.open_video.setText(u'打开视频')

    def detect_video(self):
        if not self.timer_camera4.isActive():
            # Guard against "检测视频文件" being clicked before a video is opened
            if not hasattr(self, 'cap_video') or not self.cap_video.isOpened():
                QtWidgets.QMessageBox.warning(self, u"Warning", u"请先打开一个视频文件",
                                              buttons=QtWidgets.QMessageBox.Ok,
                                              defaultButton=QtWidgets.QMessageBox.Ok)
            else:
                self.timer_camera4.start(30)
        else:
            self.timer_camera4.stop()
            self.cap_video.release()
            self.label_show_camera1.clear()

    def show_camera2(self):  # left panel: raw video frame
        length = int(self.cap_video.get(cv2.CAP_PROP_FRAME_COUNT))
        print(self.frame_s, length)
        flag, self.image1 = self.cap_video.read()  # image1 is the current video frame
        if flag:
            width = self.image1.shape[1]
            height = self.image1.shape[0]

            # Target display resolution
            width_new = 700
            height_new = 500

            # Resize while preserving the aspect ratio
            if width / height >= width_new / height_new:
                show = cv2.resize(self.image1, (width_new, int(height * width_new / width)))
            else:
                show = cv2.resize(self.image1, (int(width * height_new / height), height_new))

            show = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
            showImage = QtGui.QImage(show.data, show.shape[1], show.shape[0], 3 * show.shape[1],
                                     QtGui.QImage.Format_RGB888)
            self.label_show_camera.setPixmap(QtGui.QPixmap.fromImage(showImage))
        else:
            # End of the video: stop the timer and reset the UI
            self.cap_video.release()
            self.label_show_camera.clear()
            self.timer_camera4.stop()
            self.label_show_camera1.clear()
            self.open_video.setText(u'打开视频')

    def show_camera3(self):  # right panel: video frame with detections
        flag, self.image1 = self.cap_video.read()
        self.frame_s += 1
        if flag:
            im0, nums, ti = self.yolo.demoimg(self.image1)
            width = im0.shape[1]
            height = im0.shape[0]

            # Target display resolution
            width_new = 700
            height_new = 500

            # Resize while preserving the aspect ratio
            if width / height >= width_new / height_new:
                show = cv2.resize(im0, (width_new, int(height * width_new / width)))
            else:
                show = cv2.resize(im0, (int(width * height_new / height), height_new))

            im0 = show
            if nums >= 1:
                im0 = cv2.putText(im0, "Warning", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                im0 = cv2.putText(im0, f"nums:{nums}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label_show_camera1.setPixmap(QtGui.QPixmap.fromImage(showImage))


'''Single-image detection window'''


class picture(QWidget):

    def __init__(self):
        super(picture, self).__init__()

        self.str_name = '0'
        self.yolo = main()  # each window keeps its own model instance, as in the original

        self.resize(1600, 900)
        self.setWindowIcon(QIcon(os.getcwd() + '\\data\\source_image\\Detective.ico'))
        self.setWindowTitle("公共场合猫狗检测系统")

        palette2 = QPalette()
        palette2.setBrush(self.backgroundRole(), QBrush(QPixmap('4.jpg')))
        self.setPalette(palette2)

        camera_or_video_save_path = 'data\\test'
        if not os.path.exists(camera_or_video_save_path):
            os.makedirs(camera_or_video_save_path)

        self.label1 = QLabel(self)
        self.label1.setText("   待检测图片")
        self.label1.setFixedSize(700, 500)
        self.label1.move(110, 80)
        self.label1.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgba(255,255,255,120);font-size:20px;font-weight:bold;font-family:宋体;}")

        self.label2 = QLabel(self)
        self.label2.setText("检测结果")
        self.label2.setFixedSize(700, 500)
        self.label2.move(850, 80)
        self.label2.setStyleSheet("QLabel{background:#7A6969;}"
                                  "QLabel{color:rgba(255,255,255,120);font-size:20px;font-weight:bold;font-family:宋体;}")

        self.label3 = QLabel(self)
        self.label3.setText("")
        self.label3.move(1200, 620)
        self.label3.setStyleSheet("font-size:20px;")
        self.label3.adjustSize()

        btn = QPushButton(self)
        btn.setText("打开图片")
        btn.setStyleSheet(BUTTON_STYLE)
        btn.move(10, 30)
        btn.clicked.connect(self.openimage)

        btn1 = QPushButton(self)
        btn1.setText("检测图片")
        btn1.setStyleSheet(BUTTON_STYLE)
        btn1.move(10, 80)
        btn1.clicked.connect(self.button1_test)

        btn3 = QPushButton(self)
        btn3.setText("视频和摄像头检测")
        btn3.setStyleSheet(BUTTON_STYLE)
        btn3.move(10, 160)
        btn3.clicked.connect(self.camera_find)

        self.imgname1 = '0'

    def camera_find(self):
        # Switch to the camera/video window
        ui_p.close()
        cam_t.show()

    def openimage(self):
        imgName, imgType = QFileDialog.getOpenFileName(self, "打开图片", "D://",
                                                       "Image files (*.jpg *.gif *.png *.jpeg)")
        if imgName != '':
            self.imgname1 = imgName
            self.im0 = cv2.imread(imgName)
            width = self.im0.shape[1]
            height = self.im0.shape[0]

            # Target display resolution
            width_new = 700
            height_new = 500

            # Resize while preserving the aspect ratio
            if width / height >= width_new / height_new:
                show = cv2.resize(self.im0, (width_new, int(height * width_new / width)))
            else:
                show = cv2.resize(self.im0, (int(width * height_new / height), height_new))
            im0 = cv2.cvtColor(show, cv2.COLOR_BGR2RGB)
            showImage = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label1.setPixmap(QtGui.QPixmap.fromImage(showImage))

    def button1_test(self):
        if self.imgname1 != '0':
            image = cv2.imread(self.imgname1)
            im0, nums, ti = self.yolo.demoimg(image)  # `ti` avoids shadowing the time module
            print(nums)
            width = im0.shape[1]
            height = im0.shape[0]

            # Target display resolution
            width_new = 700
            height_new = 700

            # Resize while preserving the aspect ratio
            if width / height >= width_new / height_new:
                im0 = cv2.resize(im0, (width_new, int(height * width_new / width)))
            else:
                im0 = cv2.resize(im0, (int(width * height_new / height), height_new))
            im0 = cv2.putText(im0, f"Infertime:{round(ti, 2)}s", (410, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            if nums >= 1:
                im0 = cv2.putText(im0, "Warning", (410, 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
                im0 = cv2.putText(im0, f"nums:{nums}", (410, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
            image_name = QtGui.QImage(im0, im0.shape[1], im0.shape[0], 3 * im0.shape[1], QtGui.QImage.Format_RGB888)
            self.label2.setPixmap(QtGui.QPixmap.fromImage(image_name))
        else:
            QMessageBox.information(self, '错误', '请先选择一个图片文件', QMessageBox.Yes, QMessageBox.Yes)


if __name__ == '__main__':
    app = QApplication(sys.argv)
    splash = QSplashScreen(QPixmap(".\\data\\source_image\\logo.png"))
    # Splash screen shown while the two windows (and two model copies) load
    splash.setFont(QFont('Microsoft YaHei UI', 12))
    splash.show()
    splash.showMessage("程序初始化中... 0%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
    time.sleep(0.3)

    splash.showMessage("正在加载模型配置文件...60%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)
    cam_t = Ui_MainWindow()
    splash.showMessage("正在加载模型配置文件...100%", QtCore.Qt.AlignLeft | QtCore.Qt.AlignBottom, QtCore.Qt.black)

    ui_p = picture()
    ui_p.show()
    splash.close()

    sys.exit(app.exec_())
```

Give the system whatever name you like, and swap in your own background image; the sketch below marks the lines to change.
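Both `Ui_MainWindow.__init__` and `picture.__init__` set these near the top; a minimal sketch of the relevant lines (the file name `my_background.jpg` is a placeholder for your own image):

```python
# Inside Ui_MainWindow.__init__ (and likewise picture.__init__):
self.setWindowTitle("公共场合猫狗检测系统")  # change the window title here

# Change the background image here; 'my_background.jpg' is a placeholder
palette1 = QPalette()
palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('my_background.jpg')))
self.setPalette(palette1)
```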
Next, move YOLOX's demo.py into the project root and replace its contents with the following:
```python
import argparse
import os
import time

import cv2
import torch
from loguru import logger

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "--path", default="./assets/dog.jpg", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default='exps/default/yolox_s.py',
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default='yolox_s.pth', type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.01, type=float, help="test conf")
    parser.add_argument("--nms", default=0.45, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=640, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)  # warm-up pass before swapping in the TRT module
            self.model = model_trt

    def inference(self, img):
        # The GUI hands in an in-memory frame, so the original file-path branch
        # was dropped; the incoming RGB frame is converted back to BGR here.
        img_info = {"id": 0}
        img_info["file_name"] = None
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info, time.time() - t0

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img, 0
        output = output.cpu()

        bboxes = output[:, 0:4]
        # scale the boxes back to the original image size
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        # the patched vis() below also returns the number of cat/dog boxes drawn
        vis_res, k = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res, k


def image_demo(predictor, current_time, image):
    # Single in-memory image instead of the original directory walk
    outputs, img_info, ti = predictor.inference(image)
    result_image, nums = predictor.visual(outputs[0], img_info, predictor.confthre)
    return result_image, nums, ti


def imageflow_demo(predictor, vis_folder, current_time, args):
    # Standalone webcam/video loop; the PyQt5 GUI does not use this path
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    while True:
        ret_val, frame = cap.read()
        if not ret_val:
            break
        outputs, img_info, ti = predictor.inference(frame)
        result_frame, nums = predictor.visual(outputs[0], img_info, predictor.confthre)
        cv2.namedWindow("yolox", cv2.WINDOW_NORMAL)
        cv2.imshow("yolox", result_frame)
        ch = cv2.waitKey(1)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
    cap.release()


class main(object):
    def __init__(self):
        args = make_parser().parse_args()
        exp = get_exp(args.exp_file, args.name)
        if not args.experiment_name:
            args.experiment_name = exp.exp_name

        file_name = os.path.join(exp.output_dir, args.experiment_name)
        os.makedirs(file_name, exist_ok=True)

        if args.trt:
            args.device = "gpu"

        logger.info("Args: {}".format(args))

        if args.conf is not None:
            exp.test_conf = args.conf
        if args.nms is not None:
            exp.nmsthre = args.nms
        if args.tsize is not None:
            exp.test_size = (args.tsize, args.tsize)

        model = exp.get_model()
        logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

        if args.device == "gpu":
            model.cuda()
            if args.fp16:
                model.half()  # to FP16
        model.eval()

        if not args.trt:
            if args.ckpt is None:
                ckpt_file = os.path.join(file_name, "best_ckpt.pth")
            else:
                ckpt_file = args.ckpt
            logger.info("loading checkpoint")
            ckpt = torch.load(ckpt_file, map_location="cpu")
            # load the model state dict
            model.load_state_dict(ckpt["model"])
            logger.info("loaded checkpoint done.")

        if args.fuse:
            logger.info("\tFusing model...")
            model = fuse_model(model)

        if args.trt:
            assert not args.fuse, "TensorRT model is not support model fusing!"
            trt_file = os.path.join(file_name, "model_trt.pth")
            assert os.path.exists(
                trt_file
            ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
            model.head.decode_in_inference = False
            decoder = model.head.decode_outputs
            logger.info("Using TensorRT to inference")
        else:
            trt_file = None
            decoder = None

        self.predictor = Predictor(
            model, exp, COCO_CLASSES, trt_file, decoder,
            args.device, args.fp16, args.legacy,
        )

    def demoimg(self, img):
        # Detect on one frame; returns (image with boxes, cat/dog count, infer time)
        current_time = time.localtime()
        return image_demo(self.predictor, current_time, img)

    def demovido(self, args):
        # Optional CLI video demo; unused by the GUI (the original called this
        # with undefined globals, fixed here)
        imageflow_demo(self.predictor, None, time.localtime(), args)


if __name__ == "__main__":
    main()  # arguments are parsed inside main() via make_parser()
```

Set any parameters you need directly in the parser defaults above. Since I didn't train a dedicated cat/dog dataset here, I use the stock YOLOX-s weights (yolox_s.pth) and simply filter the output down to the cat and dog class indices. To do that, modify the vis function in visualize.py:
```python
# in yolox/utils/visualize.py
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    k = 0  # number of cat/dog boxes actually drawn
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        if cls_id in (15, 16):  # COCO class indices: 15 = cat, 16 = dog
            score = scores[i]
            if score < conf:
                continue
            k += 1
            x0 = int(box[0])
            y0 = int(box[1])
            x1 = int(box[2])
            y1 = int(box[3])

            color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
            text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
            txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
            font = cv2.FONT_HERSHEY_SIMPLEX

            txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
            cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

            txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
            cv2.rectangle(
                img,
                (x0, y0 + 1),
                (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
                txt_bk_color,
                -1
            )
            cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img, k
```

Since I want to know how many cats and dogs were found, I added the counter k and return it alongside the annotated image.
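The hard-coded 15 and 16 are the `cat` and `dog` entries of the 80-class COCO label list (`COCO_CLASSES`) that YOLOX ships with; a quick check from the YOLOX root confirms the mapping:

```python
from yolox.data.datasets import COCO_CLASSES

# Indices used by the filter in vis(); this should print: cat dog
print(COCO_CLASSES[15], COCO_CLASSES[16])
```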
Finally, run the PyQt5 file from earlier and you can start detecting.
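If you'd like to check the model pipeline before launching the GUI, here is a minimal headless sketch (it assumes the `assets/dog.jpg` sample shipped with the YOLOX repo; any image path works):

```python
# Headless smoke test for demo.py; run it from the project root.
import cv2
from demo import main

yolo = main()                       # loads yolox_s.pth with the defaults from make_parser()
img = cv2.imread("assets/dog.jpg")  # placeholder path; substitute your own test image
im0, nums, ti = yolo.demoimg(img)   # annotated image, cat/dog count, inference time
print(f"cats/dogs detected: {nums}, inference time: {ti:.2f}s")
cv2.imwrite("result.jpg", im0)
```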
Sample output: