Convert the trained model to ONNX. Run the following command in a terminal and put the exported model into the model folder:
python export.py --weights yolov5s.pt --include onnx
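Before going further, it can be worth checking that the exported file actually loads with OpenCV's dnn module, since that is how it will be used below. A minimal sketch; the path model/yolov5s.onnx is just an assumed location for the exported file:

import cv2

# assumed path: wherever export.py wrote the ONNX file after you moved it to the model folder
net = cv2.dnn.readNetFromONNX("model/yolov5s.onnx")
print("ONNX model loaded, layers:", len(net.getLayerNames()))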
Parsing the YOLO model: I use OpenCV's dnn module to load the YOLO model and run detection. Part of the code is shown below:
import os
import cv2
import numpy as np
import torch


class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()


class yolov5():
    def __init__(self, onnx_path, confThreshold=0.25, nmsThreshold=0.45):
        # self.classes = ["yawn", "sleep"]
        self.classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                        'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                        'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
                        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
                        'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
                        'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
                        'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                        'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
                        'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
        self.colors = [np.random.randint(0, 255, size=3).tolist() for _ in range(len(self.classes))]
        num_classes = len(self.classes)
        self.anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
        self.nl = len(self.anchors)          # number of detection layers
        self.na = len(self.anchors[0]) // 2  # number of anchors per layer
        self.no = num_classes + 5            # outputs per anchor: x, y, w, h, objectness + class scores
        self.stride = np.array([8., 16., 32.])
        self.inpWidth = 640
        self.inpHeight = 640
        self.net = cv2.dnn.readNetFromONNX(onnx_path)

        self.confThreshold = confThreshold
        self.nmsThreshold = nmsThreshold
Decide whether the incoming file is a video or an image, or run real-time detection from the camera:
def mult_test(onnx_path, img_dir, save_root_path, video=False):
    model = yolov5(onnx_path)
    if video:
        # real-time detection from the webcam
        cap = cv2.VideoCapture(0)
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)  # average frame rate of the stream
        size = (frame_width, frame_height)  # VideoWriter expects (width, height); same size and fps as the source
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter('zi.mp4', fourcc, fps, size)

        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            frame, rr = model.detect(frame)
            out.write(frame)
            cv2.imshow('result', frame)
            c = cv2.waitKey(1) & 0xFF
            if c == 27 or c == ord('q'):  # Esc or q quits
                break
        cap.release()
        out.release()
        cv2.destroyAllWindows()
    else:
        if not os.path.exists(save_root_path):
            os.mkdir(save_root_path)
        for root, dirs, files in os.walk(img_dir):
            for file in files:
                image_path = os.path.join(root, file)
                save_path = os.path.join(save_root_path, file)
                if 'mp4' in file or 'avi' in file:
                    cap = cv2.VideoCapture(image_path)
                    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    fps = cap.get(cv2.CAP_PROP_FPS)
                    size = (frame_width, frame_height)
                    fourcc = cv2.VideoWriter_fourcc(*'XVID')
                    out = cv2.VideoWriter(save_path, fourcc, fps, size)
                    while cap.isOpened():
                        ok, frame = cap.read()
                        if not ok:
                            break
                        frame, rr = model.detect(frame)
                        out.write(frame)
                    cap.release()
                    out.release()
                    print(" finish: ", file)
                elif file.endswith(('.jpg', '.png')):  # the original `elif 'jpg' or 'png' in file` was always True
                    srcimg = cv2.imread(image_path)
                    srcimg, rr = model.detect(srcimg)
                    print(" finish: ", file)
                    cv2.imwrite(save_path, srcimg)
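A typical call looks like the following, assuming the ONNX file sits in the model folder and the images or videos to process are placed in the input folder (both paths are just example locations):

if __name__ == "__main__":
    # video=False: walk the input folder and process images/videos, writing results to output;
    # video=True: open the webcam (index 0) for real-time detection
    mult_test("model/yolov5s.onnx", "input", "output", video=False)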
For object tracking I use the DeepSORT algorithm. The detect method below produces the per-frame detection results that are later handed to the tracker:
def detect(self, srcimg):
    results = []
    im = srcimg.copy()
    # letterbox and non_max_suppression are helper methods of the class (not shown here)
    im, ratio, wh = self.letterbox(srcimg, self.inpWidth, stride=self.stride, auto=False)
    # set the input to the network
    blob = cv2.dnn.blobFromImage(im, 1 / 255.0, swapRB=True, crop=False)
    self.net.setInput(blob)
    outs = self.net.forward(self.net.getUnconnectedOutLayersNames())[0]
    # NMS
    pred = self.non_max_suppression(outs, self.confThreshold, agnostic=False)

    result_list = []
    # draw the boxes
    for i in pred[0]:
        # i = [x1, y1, x2, y2, conf, class_id] on the 640x640 letterboxed image;
        # subtract the padding wh and divide by ratio to map the box back to the original image
        left = int((i[0] - wh[0]) / ratio[0])
        top = int((i[1] - wh[1]) / ratio[1])
        width = int((i[2] - wh[0]) / ratio[0])    # actually the right x coordinate
        height = int((i[3] - wh[1]) / ratio[1])   # actually the bottom y coordinate

        conf = i[4]     # confidence
        classId = i[5]  # class id
        cv2.rectangle(srcimg, (int(left), int(top)), (int(width), int(height)), colors(classId, True), 2,
                      lineType=cv2.LINE_AA)
        label = '%.2f' % conf
        label = '%s:%s' % (self.classes[int(classId)], label)
        # display the label at the top of the bounding box
        labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        top = max(top, labelSize[1])
        cv2.putText(srcimg, label, (int(left - 20), int(top - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255),
                    thickness=1, lineType=cv2.LINE_AA)

        result = torch.Tensor([left, top, width, height, conf, classId]).unsqueeze(0)  # to tensor, add a batch dim
        result_list.append(result)

    if len(result_list) == 0:
        # no detections: append an empty tensor so the concatenation below still works
        result_list.append(torch.Tensor())

    results.append(torch.cat(result_list, dim=0).cpu().detach())

    # return the annotated image together with the detection tensors (the callers unpack both)
    return srcimg, results
The bounding-box information that YOLO detects in each frame is passed to DeepSORT, which decides whether a detection belongs to an object that is already being tracked; if it is a new object, the ID is incremented by 1.
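As a sketch of this hand-off, the snippet below converts the detect() output into the center-x, center-y, width, height format DeepSORT trackers typically take and feeds it to a tracker object. The tracker and its update() signature follow the style of the common deep_sort_pytorch implementation and are an assumption here; the exact constructor and arguments depend on the DeepSORT package actually used.

import numpy as np

def track_frame(tracker, detections, frame):
    # detections: an (N, 6) array of [x1, y1, x2, y2, conf, class_id] rows, e.g. results[0] from detect()
    if len(detections) == 0:
        return []
    det = np.asarray(detections, dtype=np.float32)
    # convert corner boxes to (center_x, center_y, width, height)
    xywh = np.stack([(det[:, 0] + det[:, 2]) / 2,
                     (det[:, 1] + det[:, 3]) / 2,
                     det[:, 2] - det[:, 0],
                     det[:, 3] - det[:, 1]], axis=1)
    confs = det[:, 4]
    # assumed deep_sort_pytorch-style call; returns [x1, y1, x2, y2, track_id] per confirmed track,
    # where a new object is assigned a new id
    return tracker.update(xywh, confs, frame)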
In each frame, the tracking module takes the bounding boxes produced by the detector as input and associates them with the targets from the previous frame, which makes continuous tracking possible. DeepSORT uses a Kalman filter to predict and update target positions: the filter estimates each target's trajectory with a motion model and then corrects the estimate using the new observations.
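To make the predict/correct cycle concrete, here is a tiny constant-velocity Kalman filter for a single coordinate. This is purely illustrative; DeepSORT's own filter tracks the full box state (center, aspect ratio, height and their velocities), but the prediction and correction steps have the same shape.

import numpy as np

# illustrative 1-D constant-velocity Kalman filter: state = [position, velocity]
F = np.array([[1.0, 1.0], [0.0, 1.0]])   # motion model: position' = position + velocity
H = np.array([[1.0, 0.0]])               # we only observe the position
Q = np.eye(2) * 1e-2                     # process noise
R = np.array([[1.0]])                    # measurement noise

def kf_step(x, P, z):
    # predict with the motion model
    x = F @ x
    P = F @ P @ F.T + Q
    # correct with the observed position z
    y = np.array([[z]]) - H @ x          # innovation
    S = H @ P @ H.T + R
    K = P @ H.T @ np.linalg.inv(S)       # Kalman gain
    x = x + K @ y
    P = (np.eye(2) - K @ H) @ P
    return x, P

x, P = np.array([[0.0], [0.0]]), np.eye(2)   # initial state and covariance
x, P = kf_step(x, P, z=1.2)                  # one predict/correct cycle with an observation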
ID association: DeepSORT combines a target's appearance features with its motion information to associate detections in the current frame with targets that are already being tracked. It uses a deep neural network to extract appearance features and computes the similarity between targets; based on this similarity, DeepSORT decides whether a detection in the current frame matches a previously tracked target, which yields the ID association.
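The appearance part of that association essentially compares feature embeddings. A minimal cosine-similarity sketch is shown below; the real DeepSORT matching cost also includes a Mahalanobis gate from the Kalman filter's motion prediction.

import numpy as np

def cosine_similarity(a, b):
    # a, b: appearance embeddings extracted by the re-identification network
    a = a / (np.linalg.norm(a) + 1e-12)
    b = b / (np.linalg.norm(b) + 1e-12)
    return float(np.dot(a, b))

# a detection is matched to the tracked target with the highest similarity,
# provided it exceeds a matching threshold; otherwise a new track ID is created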
Part of the Qt interface code:
# assuming PyQt5; the imports are not part of the original snippet (adjust for PySide if needed)
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (QWidget, QLabel, QPushButton, QPlainTextEdit,
                             QVBoxLayout, QHBoxLayout)
from PyQt5.QtMultimedia import QMediaPlayer
from PyQt5.QtMultimediaWidgets import QVideoWidget


class MainWindow(QWidget):

    def __init__(self):
        super().__init__()
        self.num = 0
        self.setWindowTitle("课堂行为检测系统V1")
        self.setup_ui()
        self.image = None
        self.media_player = QMediaPlayer(self)
        self.resize(800, 600)

    def setup_ui(self):
        """Set up the main window layout."""
        layout = QVBoxLayout(self)

        # title label
        label = QLabel("<html><head/><body><p><span style=\" font-size:20pt; color:#00007f;\">课堂行为检测系统V1</span></p></body></html>")
        label.setAlignment(Qt.AlignHCenter)
        layout.addWidget(label)

        # video player widget and text edit widget
        video_layout = QHBoxLayout()
        self.video_widget = QVideoWidget(self)
        video_layout.addWidget(self.video_widget, 3)

        self.label2 = QLabel(self)
        self.label2.hide()
        self.label2.setFixedSize(640, 640)
        video_layout.addWidget(self.label2, 2)

        self.text_edit = QPlainTextEdit(self)
        video_layout.addWidget(self.text_edit, 2)
        self.text_edit.setPlaceholderText("识别到的数量")
        layout.addLayout(video_layout)

        # buttons
        button_layout = QHBoxLayout()
        layout.addLayout(button_layout)

        self.open_file_button = QPushButton("传入文件")
        self.open_file_button.clicked.connect(self.open_file_dialog)
        button_layout.addWidget(self.open_file_button)

        self.start_detection_button = QPushButton("开始检测")
        self.start_detection_button.clicked.connect(self.start_detection)
        button_layout.addWidget(self.start_detection_button)

        self.start_realtime_detection_button = QPushButton("实时检测")
        self.start_realtime_detection_button.clicked.connect(self.start_realtime_detection)
        button_layout.addWidget(self.start_realtime_detection_button)

        self.stop_detection_button = QPushButton("停止检测")
        self.stop_detection_button.clicked.connect(self.stop_detection)
        button_layout.addWidget(self.stop_detection_button)

        self.stop_start_playing = QPushButton("播放")
        self.stop_start_playing.clicked.connect(self.play)
        button_layout.addWidget(self.stop_start_playing)

        # the slot methods (open_file_dialog, start_detection, start_realtime_detection,
        # stop_detection, play) are part of the full source and are omitted here
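For completeness, a minimal launcher for this window could look like the following (a sketch, assuming PyQt5 and the MainWindow class defined above):

import sys
from PyQt5.QtWidgets import QApplication

if __name__ == "__main__":
    app = QApplication(sys.argv)   # create the Qt application
    window = MainWindow()          # MainWindow as defined above
    window.show()
    sys.exit(app.exec_())          # start the event loop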
Clicking "传入文件" copies the file into the input folder; clicking detect runs recognition on the video, stores the result in the output folder, and plays it back. Object tracking returns the number of targets, which is used to deduct class points: here a total of 3 people were found sleeping or yawning, so 3 points are deducted from the class. The interface also includes a real-time detection mode, so a classroom camera can be connected for live recognition.
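The deduction rule described above can be expressed as a small helper. This is only a sketch: the label names "sleep" and "yawn" and the one-point-per-person rule follow the example in this paragraph, not a fixed API.

def class_deduction(tracked_objects, penalized_labels=("sleep", "yawn")):
    # tracked_objects: {track_id: label} accumulated from the tracker over the whole video
    violators = [tid for tid, label in tracked_objects.items() if label in penalized_labels]
    return len(violators)  # one point deducted per violating person

# e.g. 3 people found sleeping or yawning -> 3 points deducted from the class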
Classroom violation behavior recognition system

This article took a lot of effort to put together; please bear with the parts that are not written well.
If you think it is decent, please give the author a little star ⭐️; every star motivates the author to keep going.