当前位置:   article > 正文

项目总结:YOLOv5 + ByteTrack 入侵检测

ByteTrack

YOLOv5 + ByteTrack 可以实现行人跟踪,每个人都会被分配一个 ID,这样就能方便地统计入侵人数,或者保存每个入侵者的图像。

思路是:

准备一个与检测图像同样大小的标识板 attribute_board,将指定的禁入区域像素值设为 255,安全区像素值设为 0。

每一帧都用行人目标框的中心点坐标去 attribute_board 上查询属性值,以判断该行人是否已进入禁入区。

如已进入禁入区,则将其 ID 加入入侵者列表中。

效果:

关键代码:

/*!
 * @Description : https://github.com/shaoshengsong/
 * @Author      : shaoshengsong
 * @Date        : 2022-09-23 02:52:22
 */
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
//#include <cstring>
#include <opencv2/opencv.hpp>
#include "YOLOv5Detector.h"
//#include "FeatureTensor.h"
#include "BYTETracker.h" //bytetrack
#include "tracker.h"//deepsort
  15. //Deep SORT parameter
  16. // https://cloud.tencent.com/developer/article/2099504
  17. using namespace std;
  18. using namespace cv;
  19. std::vector<cv::Point> pts;
  20. void on_Mouse(int event, int x, int y, int flags, void* param){
  21. if (event == EVENT_LBUTTONDOWN)// 左键按下
  22. {
  23. Point pt = Point(x, y);
  24. pts.push_back(pt);
  25. }
  26. }
  27. void invade_detect(Mat& img, vector<Point>& pts, vector<STrack>& bytrack_result){
  28. int invaded = 255; //安全区标识板颜色值为0
  29. vector<int> invade_IDs; // 入侵人员id容器
  30. Mat attribute_board = Mat::zeros(img.rows, img.cols, CV_8UC1); // 标识板
  31. if(pts.size() == 2){
  32. line(img,pts[0], pts[1], Scalar(0,0,180), 2, 8);
  33. }
  34. else if (pts.size() > 2) {
  35. fillPoly(attribute_board, pts, Scalar(invaded)); // 设置入侵区域标识板颜色值为255
  36. polylines(img, pts, true, Scalar(0,0,180), 2, 8); // 画出入侵区域
  37. }
  38. for (size_t i=0; i<bytrack_result.size(); i++){
  39. STrack temp_person = bytrack_result[i];
  40. // perosn bbox center point
  41. int left = temp_person.tlwh[0];
  42. int top = temp_person.tlwh[1];
  43. int width = temp_person.tlwh[2];
  44. int height = temp_person.tlwh[3];
  45. Point center = Point(left + width/2, top + height/2);
  46. // whether invade
  47. int whether_invade = attribute_board.at<uchar>(center.y, center.x);
  48. if (whether_invade == invaded){ // 闯入划定区域
  49. if(!std::count(invade_IDs.begin(), invade_IDs.end(), temp_person.track_id )) { // 新入侵者
  50. invade_IDs.push_back(temp_person.track_id ); // 加入入侵者列表
  51. rectangle(img, cv::Rect(temp_person.tlwh[0], temp_person.tlwh[1],
  52. temp_person.tlwh[2], temp_person.tlwh[3]), Scalar(0,0,255),2, 8); // 画bbox,红框
  53. string label = "Person ID:" + to_string(temp_person.track_id);
  54. putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,255), 2, 8);
  55. }
  56. }
  57. else{ // 为闯入划定区域
  58. rectangle(img, cv::Rect(temp_person.tlwh[0], temp_person.tlwh[1],
  59. temp_person.tlwh[2], temp_person.tlwh[3]), Scalar(255,255,255),2, 8); // 画bbox,白框
  60. string label = "Person ID:" + to_string(temp_person.track_id);
  61. putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(255,255,255), 2, 8);
  62. }
  63. }
  64. }
  65. std::vector<STrack> test_bytetrack(cv::Mat& frame, std::vector<detect_result>& results,
  66. BYTETracker& tracker, std::vector<std::string> & classes)
  67. {
  68. std::vector<detect_result> objects;
  69. for (detect_result dr : results)
  70. {
  71. if(dr.classId == 0 ) // person
  72. {
  73. objects.push_back(dr);
  74. }
  75. }
  76. // bytetrack主函数
  77. std::vector<STrack> output_stracks = tracker.update(objects);
  78. // // 对于track结果
  79. // for (unsigned long i = 0; i < output_stracks.size(); i++)
  80. // {
  81. // std::vector<float> tlwh = output_stracks[i].tlwh;
  82. // // 目标像素不能小于20个像素、宽高比需 < 1.6
  83. // // 其实这里加一个判定不是很合理,因为 bytetrack主函数中ID已经增加了,只不过没有显示该目标而已
  84. bool vertical = tlwh[2] / tlwh[3] > 1.6;
  85. if (tlwh[2] * tlwh[3] > 20 && !vertical)
  86. // if(1)
  87. // {
  88. // // 框出目标,并标注ID
  89. // cv::Scalar s = tracker.get_color(output_stracks[i].track_id);
  90. // cv::putText(frame, cv::format("%s--%d",classes[output_stracks[i].class_index].c_str(),output_stracks[i].track_id),
  91. // cv::Point(tlwh[0], tlwh[1] - 5),
  92. // 0, 0.6, cv::Scalar(0, 0, 255), 2, cv::LINE_AA);
  93. // cv::rectangle(frame, cv::Rect(tlwh[0], tlwh[1], tlwh[2], tlwh[3]), s, 2);
  94. // }
  95. // }
  96. return output_stracks;
  97. }
  98. int main(int argc, char *argv[])
  99. {
  100. // 加载类别名称
  101. std::vector<std::string> classes;
  102. std::string file="/home/jason/work/my-deploy/01-bytetrack-deepsort/coco_80_labels_list.txt";
  103. std::ifstream ifs(file);
  104. if (!ifs.is_open())
  105. CV_Error(cv::Error::StsError, "File " + file + " not found");
  106. std::string line;
  107. while (std::getline(ifs, line))
  108. {
  109. classes.push_back(line);
  110. }
  111. // 检测器
  112. std::cout<<"classes:"<<classes.size();
  113. std::shared_ptr<YOLOv5Detector> detector(new YOLOv5Detector());
  114. detector->init(k_detect_model_path);
  115. //bytetrack设置
  116. int fps=20;
  117. BYTETracker bytetracker(fps, 30); // 后面的30是30帧没有发现,
  118. // 读取视频
  119. std::cout<<"begin read video"<<std::endl;
  120. cv::VideoCapture capture("/home/jason/work/my-deploy/01-bytetrack-deepsort/1.mp4");
  121. if (!capture.isOpened()) {
  122. printf("could not read this video file...\n");
  123. return -1;
  124. }
  125. std::cout<<"end read video"<<std::endl;
  126. // yolo检测结果
  127. std::vector<detect_result> results;
  128. // 输出另存
  129. cv::VideoWriter video("/home/jason/work/my-deploy/01-bytetrack-deepsort/out.avi",cv::VideoWriter::fourcc('M','J','P','G'),10, cv::Size(1920,1080));
  130. //
  131. namedWindow("bytrack && invade");
  132. setMouseCallback("bytrack && invade", on_Mouse, 0);
  133. int num_frames = 0;
  134. cv::Mat frame;
  135. while (true)
  136. {
  137. if (!capture.read(frame)) // if not success, break loop
  138. {
  139. std::cout<<"\n Cannot read the video file. please check your video.\n";
  140. break;
  141. }
  142. num_frames ++;
  143. //Second/Millisecond/Microsecond 秒s/毫秒ms/微秒us
  144. auto start = std::chrono::system_clock::now();
  145. // 获得检测结果
  146. detector->detect(frame, results);
  147. // 计算检测耗时
  148. auto end = std::chrono::system_clock::now();
  149. auto detect_time =std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();//ms
  150. // std::cout<<classes.size()<<":"<<results.size()<<":"<<num_frames<<std::endl;
  151. printf("视频尺寸:%d宽 * %d高\n", frame.cols, frame.rows );
  152. printf("帧数:%d 检测器耗时:%dms ",num_frames, (int)detect_time);
  153. // 进行跟踪
  154. std::vector<STrack> temp_tracks;
  155. auto start2 = std::chrono::system_clock::now();
  156. temp_tracks = test_bytetrack(frame, results,bytetracker,classes);
  157. auto end2 = std::chrono::system_clock::now();
  158. // 计算跟踪器耗时
  159. auto detect_time2 =std::chrono::duration_cast<std::chrono::milliseconds>(end2 - start2).count();//ms
  160. printf("跟踪器耗时:%dms \n", (int)detect_time2);
  161. // 入侵检测
  162. invade_detect(frame, pts, temp_tracks);
  163. cv::imshow("bytrack && invade", frame);
  164. // 保存结果至指定路径
  165. video.write(frame);
  166. if(cv::waitKey(1) == 27) // Wait for 'esc' key press to exit
  167. {
  168. break;
  169. }
  170. results.clear();
  171. }
  172. capture.release();
  173. video.release();
  174. cv::destroyAllWindows();
  175. }

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号