当前位置:   article > 正文

yolov8目标检测

yolov8目标检测

项目背景与目的

项目背景

近年来,计算机视觉技术,特别是基于深度学习的目标检测方法,在许多应用领域得到了广泛的关注和应用。YOLO(You Only Look Once)作为一种高效的实时目标检测算法,凭借其速度快、精度高的优势,在学术界和工业界都取得了显著的成果。YOLOv8是YOLO系列最新版本,融合了许多先进的技术和优化策略,进一步提升了目标检测的性能。

本项目旨在利用YOLOv8模型,开发一个集成图像和视频目标检测的桌面应用程序,帮助用户方便地进行目标检测任务。该应用程序主要针对研究人员、工程师以及其他需要进行目标检测任务的用户,提供了一种简单易用的工具,能够快速检测图像或视频中的目标,并展示检测结果。

项目目的

1. **开发一个用户友好的桌面应用程序**:
   - 利用PyQt5开发图形用户界面(GUI),使用户可以方便地导入图像和视频,进行目标检测,并查看检测结果。
   - 提供基本的图像和视频导航功能,如加载图像/视频、查看上一张/下一张图像或视频帧等。

2. **集成YOLOv8模型进行目标检测**:
   - 使用YOLOv8预训练模型进行目标检测,确保高效、准确地检测图像和视频中的目标。
   - 实现目标检测结果的可视化,将检测框和标签叠加在图像或视频帧上,帮助用户直观地了解检测结果。

3. **支持模型训练与优化**:
   - 提供训练YOLOv8模型的功能,允许用户使用自定义数据集训练模型。
   - 在训练过程中保存最佳模型,确保用户能够获取最优的目标检测模型。

4. **提供结果保存功能**:
   - 允许用户将检测结果保存为图像或视频文件,方便后续分析和处理。

项目实现

项目由两个主要部分组成:前端和后端。

前端

前端部分主要是图形用户界面(GUI)的实现,使用PyQt5开发。主要功能包括:

1. **图像和视频的加载与显示**:
   - 导入图像:允许用户选择文件夹并加载其中的所有图像。
   - 导入视频:允许用户选择视频文件并逐帧加载。
   - 显示图像和视频帧:在GUI中显示原始图像或视频帧,以及检测后的结果。

2. **检测结果的导航和查看**:
   - 上一个/下一个按钮:用于导航图像列表或视频帧。
   - 开始检测按钮:触发目标检测过程,显示检测结果。

3. **结果的保存与移除**:
   - 保存按钮:允许用户将检测结果保存为图像或视频文件。
   - 移除按钮:从当前加载的列表中移除图像或视频,清除显示内容。

4. **应用退出功能**:
   - 退出按钮:关闭应用程序。

后端

后端部分主要是YOLOv8模型的集成与目标检测的实现。主要功能包括:

1. **模型加载与预测**:
   - 加载YOLOv8预训练模型。
   - 对导入的图像或视频帧进行目标检测,生成检测结果。

2. **模型训练与优化**:
   - 使用指定的数据集和超参数进行模型训练。
   - 在训练过程中保存最佳模型。

3. **检测结果的处理与保存**:
   - 将检测结果(包括检测框和标签)叠加在图像或视频帧上。
   - 保存处理后的图像或视频文件。

 结论

通过本项目,我们希望为需要进行目标检测任务的用户提供一个高效、易用的工具。该工具不仅能够快速、准确地检测图像和视频中的目标,还支持用户根据自定义数据集训练模型,满足不同应用场景的需求。未来,我们可以进一步扩展该项目的功能,如支持更多的模型类型、优化检测速度和精度、增加更多的图像和视频处理功能等。

数据预处理方法

数据预处理是为了提高模型训练效果和加速训练过程,对数据进行的一系列处理操作。常用的数据预处理方法包括:

  1. 图像缩放

    • 将图像调整到模型输入所需的尺寸。例如,对于YOLOv8,通常将图像缩放到640x640像素。
  2. 归一化

    • 将图像像素值归一化到[0, 1]范围,提高模型的收敛速度和稳定性

标注工具

labelme

标签格式

YOLO模型的标签格式非常简单,每个标注文件对应一个图像,包含多个标注信息。每行代表一个目标,格式如下:

  1. class_id x_center y_center width height
  2. 49 0.642859 0.0792187 0.148063 0.148062
  • class_id:目标类别的ID,从0开始。
  • x_center:目标边界框中心点的x坐标,相对于图像宽度进行归一化,范围为[0, 1]。
  • y_center:目标边界框中心点的y坐标,相对于图像高度进行归一化,范围为[0, 1]。
  • width:目标边界框的宽度,相对于图像宽度进行归一化,范围为[0, 1]。
  • height:目标边界框的高度,相对于图像高度进行归一化,范围为[0, 1]。

数据来源:自行标注200条+https://github.com/ultralytics/ultralytics在gihub上面公开数据集128条

loss变化:

模型训练:

yolo detect train data=datasets/mubiao/my_data.yaml model=yolov8n.yaml pretrained=ultralytics/yolov8n.pt epochs=50 batch=4 lr0=0.01 resume=True  

模型评估效果:

运行后参数的变化:精度,召回率......可视化

在终端中运行了训练代码,转换成一般python代码后:

  1. from ultralytics import YOLO
  2. def train_yolov8():
  3. # 加载预训练模型
  4. model = YOLO('ultralytics/yolov8n.pt') #这些代码是根据终端训练模型写出详细的训练代码
  5. # 开始训练
  6. model.train(
  7. data='datasets/mubiao/my_data.yaml', # 数据集配置文件路径
  8. model='yolov8n.yaml', # 模型配置文件路径
  9. epochs=50, # 训练的轮数
  10. batch=4, # 批处理大小
  11. lr0=0.01, # 初始学习率
  12. resume=True, # 是否从上次中断处恢复训练
  13. save_period=1 # 每个epoch保存一次模型
  14. )
  15. # 另存为最优模型
  16. best_model_path = model.ckpt_path.replace('last', 'best')
  17. model.save(best_model_path)
  18. if __name__ == "__main__":
  19. train_yolov8

以上代码就是训练完之后选择最好的模型(best)作为我们的一个项目的使用

后端:

  1. import cv2 # OpenCV库,用于图像处理
  2. import numpy as np # NumPy库,用于数组操作
  3. from ultralytics import YOLO # YOLO库,用于加载YOLO模型
  4. class YOLOApp(QWidget): # 定义YOLOApp类,继承自QWidget
  5. def __init__(self): # 初始化方法
  6. super().__init__() # 调用父类的初始化方法
  7. self.model = YOLO('best.pt') # 加载YOLO模型
  8. self.classnameList = self.model.names # 获取类别名列表
  9. self.imagePaths = [] # 图像路径列表
  10. self.currentImageIndex = -1 # 当前图像索引
  11. self.videoPath = None # 视频路径
  12. self.cap = None # 视频捕获对象
  13. self.timer = None # 定时器
  14. self.frameIndex = 0 # 当前帧索引
  15. self.frames = [] # 视频帧列表
  16. self.detectedFrames = [] # 检测后的帧列表
  17. self.initUI() # 初始化UI
  18. def detectObjects(self): # 检测对象的方法
  19. if self.videoPath: # 如果有视频路径
  20. self.detectVideo() # 检测视频
  21. elif hasattr(self, 'imagePath'): # 如果有图像路径
  22. img = cv2.imread(self.imagePath) # 读取图像
  23. self.detectFrame(img) # 检测图像中的对象
  24. def detectFrame(self, img): # 检测图像中的对象
  25. results = self.model.predict(img, stream=True) # 使用模型进行预测
  26. used_labels = [] # 已使用标签的位置
  27. for result in results:
  28. boxes = result.boxes.cpu().numpy() # 获取检测框
  29. for box in boxes:
  30. r = box.xyxy[0].astype(int) # 获取框的坐标
  31. cv2.rectangle(img, (r[0], r[1]), (r[2], r[3]), (0, 0, 255), 2) # 绘制红色矩形框
  32. classID = int(box.cls[0]) # 获取类别ID
  33. label = self.classnameList[classID] # 获取标签名称
  34. label_pos = (r[0], r[1] - 10) # 标签位置
  35. while any(self.overlap(label_pos, used_label_pos) for used_label_pos in used_labels): # 检查是否重叠
  36. label_pos = (label_pos[0], label_pos[1] - 15) # 调整标签位置
  37. used_labels.append(label_pos) # 保存已使用标签位置
  38. cv2.putText(img, label, label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2) # 绘制标签
  39. height, width, channel = img.shape # 获取图像形状
  40. bytesPerLine = 3 * width # 每行字节数
  41. qImg = QImage(img.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped() # 转换图像格式
  42. self.labelDetected.setPixmap(QPixmap.fromImage(qImg)) # 显示检测后的图像
  43. self.detectedImage = img # 保存检测后的图像
  44. def detectVideo(self): # 检测视频中的对象
  45. self.cap = cv2.VideoCapture(self.videoPath) # 打开视频文件
  46. self.detectedFrames = [] # 初始化检测后的帧列表
  47. while True:
  48. ret, frame = self.cap.read() # 读取视频帧
  49. if not ret:
  50. break # 如果读取失败,则退出循环
  51. detected_frame = self.detectFrameForVideo(frame) # 检测帧中的对象
  52. self.detectedFrames.append(detected_frame) # 保存检测后的帧
  53. self.cap.release() # 释放视频捕获对象
  54. self.playDetectedVideo() # 播放检测后的视频
  55. def detectFrameForVideo(self, frame): # 对视频帧进行对象检测
  56. results = self.model.predict(frame, stream=True) # 使用模型进行预测
  57. used_labels = [] # 已使用标签的位置
  58. for result in results:
  59. boxes = result.boxes.cpu().numpy() # 获取检测框
  60. for box in boxes:
  61. r = box.xyxy[0].astype(int) # 获取框的坐标
  62. cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), (0, 0, 255), 2) # 绘制红色矩形框
  63. classID = int(box.cls[0]) # 获取类别ID
  64. label = self.classnameList[classID] # 获取标签名称
  65. label_pos = (r[0], r[1] - 10) # 标签位置
  66. while any(self.overlap(label_pos, used_label_pos) for used_label_pos in used_labels): # 检查是否重叠
  67. label_pos = (label_pos[0], label_pos[1] - 15) # 调整标签位置
  68. used_labels.append(label_pos) # 保存已使用标签位置
  69. cv2.putText(frame, label, label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2) # 绘制标签
  70. return frame # 返回检测后的帧
调整的超参数
  1. 学习率 (lr0=0.01)

    • 初始学习率设置为0.01。
    • 在开始训练时有较大的步长,可以快速下降损失。
  2. 批次大小 (batch=4)

    • 批次大小设置为4。
    • 如果显存较小,可以减小批次大小以避免显存不足问题。
  3. 训练轮数 (epochs=50)

    • 训练50个轮次。
    • 训练时间和效果之间的折中,通常需要根据验证集性能来决定是否增加轮次。
  4. 预训练权重 (pretrained=ultralytics/yolov8n.pt)

    • 使用YOLOv8预训练模型。
    • 加快收敛速度,并能在小数据集上取得较好的性能。

前端代码:

  1. import sys
  2. from PyQt5.QtWidgets import QApplication, QLabel, QWidget, QVBoxLayout, QPushButton, QFileDialog, QHBoxLayout, QDesktopWidget
  3. from PyQt5.QtGui import QPixmap, QImage
  4. import requests
  5. import cv2
  6. import os
  7. import numpy as np
  8. class YOLOApp(QWidget): # 定义主窗口类YOLOApp,继承自QWidget
  9. def __init__(self):
  10. super().__init__()
  11. self.imagePaths = [] # 存储图像路径的列表
  12. self.currentImageIndex = -1 # 当前图像索引,初始值为-1表示没有图像
  13. self.videoPath = None # 视频路径,初始为None
  14. self.cap = None # 视频捕获对象,初始为None
  15. self.timer = None # 定时器对象,初始为None
  16. self.frameIndex = 0 # 当前视频帧索引,初始值为0
  17. self.frames = [] # 存储视频帧的列表
  18. self.detectedFrames = [] # 存储检测后的视频帧的列表
  19. self.initUI() # 调用initUI方法初始化UI
  20. def initUI(self): # 初始化UI的方法
  21. self.setWindowTitle('YOLOv8 目标检测') # 设置窗口标题
  22. # 获取屏幕的大小
  23. screen = QDesktopWidget().screenGeometry()
  24. self.setGeometry(0, 0, screen.width(), screen.height()) # 设置窗口为全屏大小
  25. self.layout = QVBoxLayout() # 创建一个垂直布局
  26. self.btnLayout = QHBoxLayout() # 创建一个水平布局
  27. self.btnLoadImage = QPushButton('导入图片', self) # 创建“导入图片”按钮
  28. self.btnLoadImage.clicked.connect(self.loadImages) # 连接按钮点击信号到loadImages方法
  29. self.btnLayout.addWidget(self.btnLoadImage) # 将按钮添加到水平布局中
  30. self.btnLoadVideo = QPushButton('导入视频', self) # 创建“导入视频”按钮
  31. self.btnLoadVideo.clicked.connect(self.loadVideo) # 连接按钮点击信号到loadVideo方法
  32. self.btnLayout.addWidget(self.btnLoadVideo) # 将按钮添加到水平布局中
  33. self.btnPrev = QPushButton('上一个', self) # 创建“上一个”按钮
  34. self.btnPrev.clicked.connect(self.showPrev) # 连接按钮点击信号到showPrev方法
  35. self.btnLayout.addWidget(self.btnPrev) # 将按钮添加到水平布局中
  36. self.btnNext = QPushButton('下一个', self) # 创建“下一个”按钮
  37. self.btnNext.clicked.connect(self.showNext) # 连接按钮点击信号到showNext方法
  38. self.btnLayout.addWidget(self.btnNext) # 将按钮添加到水平布局中
  39. # 将按钮布局添加到主布局中
  40. self.layout.addLayout(self.btnLayout)
  41. # 创建一个水平布局用于放置图像
  42. self.imageLayout = QHBoxLayout()
  43. # 创建一个标签用于显示原始图像
  44. self.labelOriginal = QLabel(self)
  45. self.imageLayout.addWidget(self.labelOriginal) # 将标签添加到图像布局中
  46. # 创建一个标签用于显示检测后的图像
  47. self.labelDetected = QLabel(self)
  48. self.imageLayout.addWidget(self.labelDetected) # 将标签添加到图像布局中
  49. # 将图像布局添加到主布局中
  50. self.layout.addLayout(self.imageLayout)
  51. # 设置窗口的主布局
  52. self.setLayout(self.layout)
  53. def loadImages(self): # 加载图像的方法
  54. options = QFileDialog.Options() # 创建文件对话框选项
  55. folder = QFileDialog.getExistingDirectory(self, "选择文件夹", "", options=options) # 打开文件夹选择对话框
  56. if folder: # 如果选择了文件夹
  57. # 获取文件夹中所有图像文件的路径并排序
  58. self.imagePaths = [os.path.join(folder, file) for file in os.listdir(folder) if file.endswith(('.png', '.jpg', '.jpeg'))]
  59. self.imagePaths.sort()
  60. self.currentImageIndex = 0 # 将当前图像索引设置为0
  61. self.videoPath = None # 清空视频路径
  62. self.frames = [] # 清空视频帧列表
  63. self.showImage() # 显示当前图像
  64. def loadVideo(self): # 加载视频的方法
  65. options = QFileDialog.Options() # 创建文件对话框选项
  66. fileName, _ = QFileDialog.getOpenFileName(self, "选择视频文件", "", "视频文件 (*.mp4 *.avi *.mkv)", options=options) # 打开视频文件选择对话框
  67. if fileName: # 如果选择了视频文件
  68. self.videoPath = fileName # 设置视频路径
  69. self.cap = cv2.VideoCapture(self.videoPath) # 创建视频捕获对象
  70. self.frames = [] # 清空视频帧列表
  71. while True: # 循环读取视频帧
  72. ret, frame = self.cap.read() # 读取一帧
  73. if not ret: # 如果读取失败则退出循环
  74. break
  75. self.frames.append(frame) # 将帧添加到视频帧列表
  76. self.frameIndex = 0 # 将当前帧索引设置为0
  77. self.cap.release() # 释放视频捕获对象
  78. self.imagePaths = [] # 清空图像路径列表
  79. self.labelOriginal.clear() # 清空原始图像标签
  80. self.showFrame() # 显示当前视频帧
  81. def showImage(self): # 显示图像的方法
  82. if 0 <= self.currentImageIndex < len(self.imagePaths): # 如果当前图像索引有效
  83. self.imagePath = self.imagePaths[self.currentImageIndex] # 获取当前图像路径
  84. pixmap = QPixmap(self.imagePath) # 加载图像为QPixmap对象
  85. self.labelOriginal.setPixmap(pixmap) # 在标签上显示图像
  86. self.labelDetected.clear() # 清除检测后的图像标签内容
  87. def showFrame(self): # 显示视频帧的方法
  88. if 0 <= self.frameIndex < len(self.frames): # 如果当前帧索引有效
  89. frame = self.frames[self.frameIndex] # 获取当前帧
  90. height, width, channel = frame.shape # 获取帧的高宽和通道数
  91. bytesPerLine = 3 * width # 计算每行的字节数
  92. qImg = QImage(frame.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped() # 将帧转换为QImage对象
  93. self.labelOriginal.setPixmap(QPixmap.fromImage(qImg)) # 在标签上显示原始视频帧
  94. self.labelDetected.clear() # 清除检测后的图像标签内容
  95. def showPrev(self): # 显示前一个图像或视频帧的方法
  96. if self.imagePaths: # 如果有图像路径
  97. if self.currentImageIndex > 0: # 如果当前图像索引大于0
  98. self.currentImageIndex -= 1 # 减少图像索引
  99. self.showImage() # 显示前一个图像
  100. elif self.frames: # 如果有视频帧
  101. if self.frameIndex > 0: # 如果当前帧索引大于0
  102. self.frameIndex -= 1 # 减少帧索引
  103. self.showFrame() # 显示前一个视频帧
  104. def showNext(self): # 显示下一个图像或视频帧的方法
  105. if self.imagePaths: # 如果有图像路径
  106. if self.currentImageIndex < len(self.imagePaths) - 1: # 如果当前图像索引小于图像数量-1
  107. self.currentImageIndex += 1 # 增加图像索引
  108. self.showImage() # 显示下一个图像
  109. elif self.frames: # 如果有视频帧
  110. if self.frameIndex < len(self.frames) - 1: # 如果当前帧索引小于视频帧数量-1
  111. self.frameIndex += 1 # 增加帧索引
  112. self.showFrame() # 显示下一个视频帧

前端实现过程

1. 导入必要的库
代码首先导入了实现图像和视频处理、界面开发所需的库,包括`PyQt5`、`cv2`、`numpy`等。

 2. 创建主窗口类
定义了一个名为`YOLOApp`的类,继承自`QWidget`,表示主窗口。

 3. 初始化类
在`__init__`方法中,初始化一些变量,如图像路径列表、当前图像索引、视频路径、视频捕获对象、帧列表等。然后调用`initUI`方法来设置用户界面。

4. 初始化用户界面
`initUI`方法中设置窗口标题,获取屏幕大小并将窗口设置为全屏。创建一个垂直布局`QVBoxLayout`来容纳其他控件,并创建一个水平布局`QHBoxLayout`来容纳按钮。

创建了“导入图片”、“导入视频”、“上一个”、“下一个”按钮,并将它们添加到按钮布局中。为每个按钮设置点击事件处理函数。

然后创建两个标签`QLabel`,一个用于显示原始图像,一个用于显示检测后的图像,并将它们添加到一个水平布局中。将该布局添加到主布局中,并设置主窗口的布局。

5. 加载图像功能
`loadImages`方法中,通过文件对话框选择图像文件夹,获取所有图像路径并排序。设置当前图像索引为0,清空视频路径和帧列表,然后调用`showImage`方法显示当前图像。

 6. 加载视频功能
`loadVideo`方法中,通过文件对话框选择视频文件,创建视频捕获对象,循环读取视频帧并存储在帧列表中。设置当前帧索引为0,清空图像路径列表并清除原始图像标签,然后调用`showFrame`方法显示当前视频帧。

7. 显示图像和视频帧
`showImage`方法中,如果当前图像索引有效,获取当前图像路径并加载图像,在标签上显示原始图像,并清除检测后的图像标签内容。

`showFrame`方法中,如果当前帧索引有效,获取当前帧并将其转换为`QImage`对象,在标签上显示原始视频帧,并清除检测后的图像标签内容。

 8. 导航功能
`showPrev`和`showNext`方法分别实现了显示前一个和下一个图像或视频帧的功能,通过修改当前图像或帧索引,调用`showImage`或`showFrame`方法来显示相应的图像或帧。

总结
本代码实现了一个图像和视频目标检测的前端应用程序,用户可以通过简洁的界面加载图像和视频,浏览和查看检测结果。PyQt5提供了良好的用户体验,OpenCV用于图像和视频处理,结合YOLOv8模型实现目标检测功能,为目标检测应用提供了一个完整的解决方案。

完整代码:

  1. import sys
  2. from PyQt5.QtWidgets import QApplication, QLabel, QWidget, QVBoxLayout, QPushButton, QFileDialog, QHBoxLayout, QDesktopWidget
  3. from PyQt5.QtGui import QPixmap, QImage
  4. import cv2
  5. from ultralytics import YOLO
  6. import os
  7. import numpy as np
  8. class YOLOApp(QWidget):
  9. def __init__(self):
  10. super().__init__()
  11. self.model = YOLO('best.pt')
  12. self.classnameList = self.model.names
  13. self.imagePaths = []
  14. self.currentImageIndex = -1
  15. self.videoPath = None
  16. self.cap = None
  17. self.timer = None
  18. self.frameIndex = 0
  19. self.frames = []
  20. self.detectedFrames = []
  21. self.initUI()
  22. def initUI(self):
  23. self.setWindowTitle('YOLOv8 目标检测')
  24. # 获取屏幕的大小
  25. screen = QDesktopWidget().screenGeometry()
  26. self.setGeometry(0, 0, screen.width(), screen.height()) # 设置窗口为全屏大小
  27. self.layout = QVBoxLayout()
  28. self.btnLayout = QHBoxLayout()
  29. self.btnLoadImage = QPushButton('导入图片', self)
  30. self.btnLoadImage.clicked.connect(self.loadImages)
  31. self.btnLayout.addWidget(self.btnLoadImage)
  32. self.btnLoadVideo = QPushButton('导入视频', self)
  33. self.btnLoadVideo.clicked.connect(self.loadVideo)
  34. self.btnLayout.addWidget(self.btnLoadVideo)
  35. self.btnPrev = QPushButton('上一个', self)
  36. self.btnPrev.clicked.connect(self.showPrev)
  37. self.btnLayout.addWidget(self.btnPrev)
  38. self.btnNext = QPushButton('下一个', self)
  39. self.btnNext.clicked.connect(self.showNext)
  40. self.btnLayout.addWidget(self.btnNext)
  41. self.btnDetect = QPushButton('开始检测', self)
  42. self.btnDetect.clicked.connect(self.detectObjects)
  43. self.btnLayout.addWidget(self.btnDetect)
  44. self.btnSave = QPushButton('保存', self)
  45. self.btnSave.clicked.connect(self.save)
  46. self.btnLayout.addWidget(self.btnSave)
  47. self.btnRemove = QPushButton('移除', self)
  48. self.btnRemove.clicked.connect(self.remove)
  49. self.btnLayout.addWidget(self.btnRemove)
  50. self.btnExit = QPushButton('退出', self)
  51. self.btnExit.clicked.connect(self.closeApp)
  52. self.btnLayout.addWidget(self.btnExit)
  53. self.layout.addLayout(self.btnLayout)
  54. self.imageLayout = QHBoxLayout()
  55. self.labelOriginal = QLabel(self)
  56. self.imageLayout.addWidget(self.labelOriginal)
  57. self.labelDetected = QLabel(self)
  58. self.imageLayout.addWidget(self.labelDetected)
  59. self.layout.addLayout(self.imageLayout)
  60. self.setLayout(self.layout)
  61. def loadImages(self):
  62. options = QFileDialog.Options()
  63. folder = QFileDialog.getExistingDirectory(self, "选择文件夹", "", options=options)
  64. if folder:
  65. self.imagePaths = [os.path.join(folder, file) for file in os.listdir(folder) if file.endswith(('.png', '.jpg', '.jpeg'))]
  66. self.imagePaths.sort()
  67. self.currentImageIndex = 0
  68. self.videoPath = None
  69. self.frames = []
  70. self.showImage()
  71. def loadVideo(self):
  72. options = QFileDialog.Options()
  73. fileName, _ = QFileDialog.getOpenFileName(self, "选择视频文件", "", "视频文件 (*.mp4 *.avi *.mkv)", options=options)
  74. if fileName:
  75. self.videoPath = fileName
  76. self.cap = cv2.VideoCapture(self.videoPath)
  77. self.frames = []
  78. while True:
  79. ret, frame = self.cap.read()
  80. if not ret:
  81. break
  82. self.frames.append(frame)
  83. self.frameIndex = 0
  84. self.cap.release()
  85. self.imagePaths = []
  86. self.labelOriginal.clear()
  87. self.showFrame()
  88. def showImage(self):
  89. if 0 <= self.currentImageIndex < len(self.imagePaths):
  90. self.imagePath = self.imagePaths[self.currentImageIndex]
  91. pixmap = QPixmap(self.imagePath)
  92. self.labelOriginal.setPixmap(pixmap) # 在左边显示原图
  93. self.labelDetected.clear() # 清除右边检测图的内容
  94. def showFrame(self):
  95. if 0 <= self.frameIndex < len(self.frames):
  96. frame = self.frames[self.frameIndex]
  97. height, width, channel = frame.shape
  98. bytesPerLine = 3 * width
  99. qImg = QImage(frame.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped()
  100. self.labelOriginal.setPixmap(QPixmap.fromImage(qImg)) # 在左边显示原图
  101. self.labelDetected.clear() # 清除右边检测图的内容
  102. def showPrev(self):
  103. if self.imagePaths:
  104. if self.currentImageIndex > 0:
  105. self.currentImageIndex -= 1
  106. self.showImage()
  107. elif self.frames:
  108. if self.frameIndex > 0:
  109. self.frameIndex -= 1
  110. self.showFrame()
  111. def showNext(self):
  112. if self.imagePaths:
  113. if self.currentImageIndex < len(self.imagePaths) - 1:
  114. self.currentImageIndex += 1
  115. self.showImage()
  116. elif self.frames:
  117. if self.frameIndex < len(self.frames) - 1:
  118. self.frameIndex += 1
  119. self.showFrame()
  120. def detectObjects(self):
  121. if self.videoPath:
  122. self.detectVideo()
  123. elif hasattr(self, 'imagePath'):
  124. img = cv2.imread(self.imagePath)
  125. self.detectFrame(img)
  126. def detectFrame(self, img):
  127. results = self.model.predict(img, stream=True)
  128. used_labels = []
  129. for result in results:
  130. boxes = result.boxes.cpu().numpy()
  131. for box in boxes:
  132. r = box.xyxy[0].astype(int)
  133. cv2.rectangle(img, (r[0], r[1]), (r[2], r[3]), (0, 0, 255), 2) # 改成红色
  134. classID = int(box.cls[0])
  135. label = self.classnameList[classID]
  136. # 检查是否重叠并调整标签位置
  137. label_pos = (r[0], r[1] - 10)
  138. while any(self.overlap(label_pos, used_label_pos) for used_label_pos in used_labels):
  139. label_pos = (label_pos[0], label_pos[1] - 15)
  140. used_labels.append(label_pos)
  141. cv2.putText(img, label, label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2) # 改成红色
  142. # 转换图像格式以便在PyQt中显示
  143. height, width, channel = img.shape
  144. bytesPerLine = 3 * width
  145. qImg = QImage(img.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped()
  146. self.labelDetected.setPixmap(QPixmap.fromImage(qImg)) # 在右边显示检测后的图片
  147. self.detectedImage = img # 保存检测后的图像
  148. def detectVideo(self):
  149. self.cap = cv2.VideoCapture(self.videoPath)
  150. self.detectedFrames = []
  151. while True:
  152. ret, frame = self.cap.read()
  153. if not ret:
  154. break
  155. detected_frame = self.detectFrameForVideo(frame)
  156. self.detectedFrames.append(detected_frame)
  157. self.cap.release()
  158. self.playDetectedVideo()
  159. def detectFrameForVideo(self, frame):
  160. results = self.model.predict(frame, stream=True)
  161. used_labels = []
  162. for result in results:
  163. boxes = result.boxes.cpu().numpy()
  164. for box in boxes:
  165. r = box.xyxy[0].astype(int)
  166. cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), (0, 0, 255), 2)
  167. classID = int(box.cls[0])
  168. label = self.classnameList[classID]
  169. label_pos = (r[0], r[1] - 10)
  170. while any(self.overlap(label_pos, used_label_pos) for used_label_pos in used_labels):
  171. label_pos = (label_pos[0], label_pos[1] - 15)
  172. used_labels.append(label_pos)
  173. cv2.putText(frame, label, label_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
  174. return frame
  175. def playDetectedVideo(self):
  176. self.frameIndex = 0
  177. self.labelOriginal.clear() # 清除左边的原视频
  178. self.timer = self.startTimer(30)
  179. def timerEvent(self, event):
  180. if self.frameIndex < len(self.detectedFrames):
  181. frame = self.detectedFrames[self.frameIndex]
  182. height, width, channel = frame.shape
  183. bytesPerLine = 3 * width
  184. qImg = QImage(frame.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped()
  185. self.labelDetected.setPixmap(QPixmap.fromImage(qImg))
  186. self.frameIndex += 1
  187. else:
  188. self.killTimer(self.timer)
  189. def overlap(self, pos1, pos2, threshold=10):
  190. return abs(pos1[0] - pos2[0]) < threshold and abs(pos1[1] - pos2[1]) < threshold
  191. def save(self):
  192. if hasattr(self, 'detectedImage'):
  193. options = QFileDialog.Options()
  194. filePath, _ = QFileDialog.getSaveFileName(self, "保存图片", "", "JPEG (*.jpg;*.jpeg);;PNG (*.png)", options=options)
  195. if filePath:
  196. cv2.imwrite(filePath, self.detectedImage)
  197. elif self.detectedFrames:
  198. options = QFileDialog.Options()
  199. filePath, _ = QFileDialog.getSaveFileName(self, "保存视频", "", "MP4 (*.mp4);;AVI (*.avi)", options=options)
  200. if filePath:
  201. height, width, layers = self.detectedFrames[0].shape
  202. fourcc = cv2.VideoWriter_fourcc(*'mp4v') if filePath.endswith('.mp4') else cv2.VideoWriter_fourcc(*'XVID')
  203. out = cv2.VideoWriter(filePath, fourcc, 20.0, (width, height))
  204. for frame in self.detectedFrames:
  205. out.write(frame)
  206. out.release()
  207. def remove(self):
  208. if self.imagePaths:
  209. del self.imagePaths[self.currentImageIndex]
  210. if self.currentImageIndex >= len(self.imagePaths):
  211. self.currentImageIndex = len(self.imagePaths) - 1
  212. self.showImage() if self.imagePaths else self.labelOriginal.clear()
  213. self.labelDetected.clear()
  214. elif self.frames:
  215. self.frames = []
  216. self.detectedFrames = []
  217. self.frameIndex = 0
  218. self.labelOriginal.clear()
  219. self.labelDetected.clear()
  220. def closeApp(self):
  221. self.close()
  222. def closeEvent(self, event):
  223. event.accept()
  224. if __name__ == '__main__':
  225. if not QApplication.instance():
  226. app = QApplication(sys.argv)
  227. else:
  228. app = QApplication.instance()
  229. ex = YOLOApp()
  230. ex.show()
  231. sys.exit(app.exec_())

运行前端界面:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/1015342
推荐阅读
相关标签
  

闽ICP备14008679号