当前位置:   article > 正文

使用OpenCV和MediaPipe实现姿态识别!

mediapipe

081b94d788bd594d88f1be089f3b635e.gif

大家好,我是小F~

MediaPipe是一款由Google开发并开源的数据流处理机器学习应用开发框架。

它是一个基于图的数据处理管线,用于构建使用了多种形式的数据源,如视频、音频、传感器数据以及任何时间序列数据

MediaPipe通过将各个感知模型抽象为模块并将其连接到可维护的图中来解决这些问题。

项目地址:

https://github.com/google/mediapipe

今天小F就给大家介绍一下,如何使用MediaPipe实现姿态识别

ff8827cab41419601df2c841eb13bf57.png

通过这项技术,我们可以结合摄像头,智能识别人的行为,然后做出一些处理。

比如控制电脑音量,俯卧撑计数,坐姿矫正等功能。

/ 01 /

依赖安装

使用的Python版本是3.9.7。

需要安装以下依赖。

  1. mediapipe==0.9.2.1
  2. numpy==1.23.5
  3. opencv-python==4.7.0.72

使用pip命令进行安装,环境配置好后,就可以来看姿态识别的情况了。

有三种,包含全身、脸部、手部的姿态估计。

/ 02 /

全身姿态估计

首先是人体姿态估计,一次只能跟踪一个人。

并且会在人的身体上显示33个对应的坐标点。

具体代码如下。

  1. import os
  2. import time
  3. import cv2 as cv
  4. import mediapipe as mp
  5. class BodyPoseDetect:
  6.     def __init__(self, static_image=False, complexity=1, smooth_lm=True, segmentation=False, smooth_sm=True, detect_conf=0.5, track_conf=0.5):
  7.         self.mp_body = mp.solutions.pose
  8.         self.mp_draw = mp.solutions.drawing_utils
  9.         self.body = self.mp_body.Pose(static_image, complexity, smooth_lm, segmentation, smooth_sm, detect_conf, track_conf)
  10.     def detect_landmarks(self, img, disp=True):
  11.         img_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
  12.         results = self.body.process(img_rgb)
  13.         detected_landmarks = results.pose_landmarks
  14.         if detected_landmarks:
  15.             if disp:
  16.                 self.mp_draw.draw_landmarks(img, detected_landmarks, self.mp_body.POSE_CONNECTIONS)
  17.         return detected_landmarks, img
  18.     def get_info(self, detected_landmarks, img_dims):
  19.         lm_list = []
  20.         if not detected_landmarks:
  21.             return lm_list
  22.         height, width = img_dims
  23.         for id, b_landmark in enumerate(detected_landmarks.landmark):
  24.             cord_x, cord_y = int(b_landmark.x * width), int(b_landmark.y * height)
  25.             lm_list.append([id, cord_x, cord_y])
  26.         return lm_list
  27. def main(path, is_image):
  28.     if is_image:
  29.         detector = BodyPoseDetect(static_image=True)
  30.         ori_img = cv.imread(path)
  31.         img = ori_img.copy()
  32.         landmarks, output_img = detector.detect_landmarks(img)
  33.         info_landmarks = detector.get_info(landmarks, img.shape[:2])
  34.         # print(info_landmarks[3])
  35.         cv.imshow("Original", ori_img)
  36.         cv.imshow("Detection", output_img)
  37.         cv.waitKey(0)
  38.     else:
  39.         detector = BodyPoseDetect()
  40.         cap = cv.VideoCapture(path)
  41.         prev_time = time.time()
  42.         cur_time = 0
  43.         frame_width = int(cap.get(3))
  44.         frame_height = int(cap.get(4))
  45.         out = cv.VideoWriter('output.avi', cv.VideoWriter_fourcc('M''J''P''G'), 10, (frame_width, frame_height))  # 保存视频
  46.         while True:
  47.             ret, frame = cap.read()
  48.             if not ret:
  49.                 print("Video Over")
  50.                 break
  51.             img = frame.copy()
  52.             landmarks, output_img = detector.detect_landmarks(img)
  53.             info_landmarks = detector.get_info(landmarks, img.shape[:2])
  54.             cur_time = time.time()
  55.             fps = 1/(cur_time - prev_time)
  56.             prev_time = cur_time
  57.             cv.putText(output_img, f'FPS: {str(int(fps))}', (1070), cv.FONT_HERSHEY_COMPLEX_SMALL, 2, (050170), 2)
  58.             cv.namedWindow('Original', cv.WINDOW_NORMAL)  # 窗口大小可设置
  59.             cv.resizeWindow('Original'580330)  # 重设大小
  60.             cv.namedWindow('Detection', cv.WINDOW_NORMAL)  # 窗口大小可设置
  61.             cv.resizeWindow('Detection'580330)  # 重设大小
  62.             out.write(output_img)
  63.             cv.imshow("Original", frame)
  64.             cv.imshow("Detection", output_img)
  65.             if cv.waitKey(1) & 0xFF == ord('q'):
  66.                 break
  67.         cap.release()
  68.     cv.destroyAllWindows()
  69. if __name__ == "__main__":
  70.     # is_image = True
  71.     # media_path = '.\\Data\\Images\\running.jpg'
  72.     is_image = False
  73.     media_path = '.\\Data\\Videos\\basketball.mp4'
  74.     if os.path.exists(os.path.join(os.getcwd(), media_path)):
  75.         main(media_path, is_image)
  76.     else:
  77.         print("Invalid Path")

运行代码后,结果如下。

8cf7319c94fe06eba1118ada2643bfe2.png

左侧是原图,右侧是检测结果。

其中代码里的is_image参数表示是否为图片或视频

media_path参数则表示的是源文件的地址。

我们还可以看视频的检测效果,具体如下。

f4b36413d5d00f8b75ee6f3a330acc3c.gif

效果还不错。

/ 03 /

脸部识别跟踪

第二个是脸部,MediaPipe可以在脸部周围画一个网格来进行检测和跟踪。

具体代码如下。

  1. import os
  2. import time
  3. import argparse
  4. import cv2 as cv
  5. import mediapipe as mp
  6. class FaceDetect:
  7.     def __init__(self, static_image=False, max_faces=1, refine=False, detect_conf=0.5, track_conf=0.5):
  8.         self.draw_utils = mp.solutions.drawing_utils
  9.         self.draw_spec = self.draw_utils.DrawingSpec(color=[02550], thickness=1, circle_radius=2)
  10.         self.mp_face_track = mp.solutions.face_mesh
  11.         self.face_track = self.mp_face_track.FaceMesh(static_image, max_faces, refine, detect_conf, track_conf)
  12.     def detect_mesh(self, img, disp=True):
  13.         results = self.face_track.process(img)
  14.         detected_landmarks = results.multi_face_landmarks
  15.         if detected_landmarks:
  16.             if disp:
  17.                 for f_landmarks in detected_landmarks:
  18.                     self.draw_utils.draw_landmarks(img, f_landmarks, self.mp_face_track.FACEMESH_CONTOURS, self.draw_spec, self.draw_spec)
  19.         return detected_landmarks, img
  20.     def get_info(self, detected_landmarks, img_dims):
  21.         landmarks_info = []
  22.         img_height, img_width = img_dims
  23.         for _, face in enumerate(detected_landmarks):
  24.             mesh_info = []
  25.             for id, landmarks in enumerate(face.landmark):
  26.                 x, y = int(landmarks.x * img_width), int(landmarks.y * img_height)
  27.                 mesh_info.append((id, x, y))
  28.             landmarks_info.append(mesh_info)
  29.         return landmarks_info
  30. def main(path, is_image=True):
  31.     print(path)
  32.     if is_image:
  33.         detector = FaceDetect()
  34.         ori_img = cv.imread(path)
  35.         img = ori_img.copy()
  36.         landmarks, output = detector.detect_mesh(img)
  37.         if landmarks:
  38.             mesh_info = detector.get_info(landmarks, img.shape[:2])
  39.             # print(mesh_info)
  40.         cv.imshow("Result", output)
  41.         cv.waitKey(0)
  42.     else:
  43.         detector = FaceDetect(static_image=False)
  44.         cap = cv.VideoCapture(path)
  45.         curr_time = 0
  46.         prev_time = time.time()
  47.         frame_width = int(cap.get(3))
  48.         frame_height = int(cap.get(4))
  49.         out = cv.VideoWriter('output.avi', cv.VideoWriter_fourcc('M''J''P''G'), 10, (frame_width, frame_height))  # 保存视频
  50.         while True:
  51.             ret, frame = cap.read()
  52.             if not ret:
  53.                 print("Video Over")
  54.                 break
  55.             img = frame.copy()
  56.             landmarks, output = detector.detect_mesh(img)
  57.             if landmarks:
  58.                 mesh_info = detector.get_info(landmarks, img.shape[:2])
  59.                 # print(len(mesh_info))
  60.             curr_time = time.time()
  61.             fps = 1/(curr_time - prev_time)
  62.             prev_time = curr_time
  63.             cv.putText(output, f'FPS: {str(int(fps))}', (1070), cv.FONT_HERSHEY_COMPLEX_SMALL, 2, (050170), 2)
  64.             cv.namedWindow('Result', cv.WINDOW_NORMAL)  # 窗口大小可设置
  65.             cv.resizeWindow('Result'580330)  # 重设大小
  66.             out.write(output)
  67.             cv.imshow("Result", output)
  68.             if cv.waitKey(20) & 0xFF == ord('q'):
  69.                 break
  70.         cap.release()
  71.     cv.destroyAllWindows()
  72. if __name__ == "__main__":
  73.     # is_image = True
  74.     # media_path = '.\\Data\\Images\\human_2.jpg'
  75.     is_image = False
  76.     media_path = '.\\Data\\Videos\\humans_3.mp4'
  77.     if os.path.exists(os.path.join(os.getcwd(), media_path)):
  78.         main(media_path, is_image)
  79.     else:
  80.         print("Invalid Path")

效果如下。

7b3071cacaeb7706eee2e69d27ec2d81.png

/ 04 /

手部跟踪识别

最后一个是手部,可以同时跟踪2只手并且在手部显示相应的坐标点。

具体代码如下。

  1. import os
  2. import time
  3. import argparse
  4. import cv2 as cv
  5. import mediapipe as mp
  6. class HandPoseDetect:
  7.     def __init__(self, static_image=False, max_hands=2, complexity=1, detect_conf=0.5, track_conf=0.5):
  8.         self.mp_hands = mp.solutions.hands
  9.         self.mp_draw = mp.solutions.drawing_utils
  10.         self.hands = self.mp_hands.Hands(static_image, max_hands, complexity, detect_conf, track_conf)
  11.     def detect_landmarks(self, img, disp=True):
  12.         img_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
  13.         results = self.hands.process(img_rgb)
  14.         detected_landmarks = results.multi_hand_landmarks
  15.         if detected_landmarks:
  16.             if disp:
  17.                 for h_landmark in detected_landmarks:
  18.                     self.mp_draw.draw_landmarks(img, h_landmark, self.mp_hands.HAND_CONNECTIONS)
  19.         return detected_landmarks, img
  20.     def get_info(self, detected_landmarks, img_dims, hand_no=1):
  21.         lm_list = []
  22.         if not detected_landmarks:
  23.             return lm_list
  24.         if hand_no > 2:
  25.             print('[WARNING] Provided hand number is greater than max number 2')
  26.             print('[WARNING] Calculating information for hand 2')
  27.             hand_no = 2
  28.         elif hand_no < 1:
  29.             print('[WARNING] Provided hand number is less than min number 1')
  30.             print('[WARNING] Calculating information for hand 1')
  31.         if len(detected_landmarks) < 2:
  32.             hand_no = 0
  33.         else:
  34.             hand_no -= 1
  35.         height, width = img_dims
  36.         for id, h_landmarks in enumerate(detected_landmarks[hand_no].landmark):
  37.             cord_x, cord_y = int(h_landmarks.x * width), int(h_landmarks.y * height)
  38.             lm_list.append([id, cord_x, cord_y])
  39.         return lm_list
  40. def main(path, is_image=True):
  41.     if is_image:
  42.         detector = HandPoseDetect(static_image=True)
  43.         ori_img = cv.imread(path)
  44.         img = ori_img.copy()
  45.         landmarks, output_img = detector.detect_landmarks(img)
  46.         info_landmarks = detector.get_info(landmarks, img.shape[:2], 2)
  47.         # print(info_landmarks)
  48.         cv.imshow("Landmarks", output_img)
  49.         cv.waitKey(0)
  50.     else:
  51.         detector = HandPoseDetect()
  52.         cap = cv.VideoCapture(path)
  53.         prev_time = time.time()
  54.         cur_time = 0
  55.         frame_width = int(cap.get(3))
  56.         frame_height = int(cap.get(4))
  57.         out = cv.VideoWriter('output.avi', cv.VideoWriter_fourcc('M''J''P''G'), 10, (frame_width, frame_height))  # 保存视频
  58.         while True:
  59.             ret, frame = cap.read()
  60.             if not ret:
  61.                 print("Video Over")
  62.                 break
  63.             img = frame.copy()
  64.             landmarks, output_img = detector.detect_landmarks(img)
  65.             info_landmarks = detector.get_info(landmarks, img.shape[:2], 2)
  66.             # print(info_landmarks)
  67.             cur_time = time.time()
  68.             fps = 1/(cur_time - prev_time)
  69.             prev_time = cur_time
  70.             cv.putText(output_img, f'FPS: {str(int(fps))}', (1070), cv.FONT_HERSHEY_COMPLEX_SMALL, 2, (050170), 2)
  71.             cv.namedWindow('Original', cv.WINDOW_NORMAL)  # 窗口大小可设置
  72.             cv.resizeWindow('Original'580330)  # 重设大小
  73.             cv.namedWindow('Detection', cv.WINDOW_NORMAL)  # 窗口大小可设置
  74.             cv.resizeWindow('Detection'580330)  # 重设大小
  75.             out.write(output_img)
  76.             cv.imshow("Detection", output_img)
  77.             cv.imshow("Original", frame)
  78.             if cv.waitKey(1) & 0xFF == ord('q'):
  79.                 break
  80.         cap.release()
  81.     cv.destroyAllWindows()
  82. if __name__ == "__main__":
  83.     is_image = False
  84.     media_path = '.\\Data\\Videos\\piano_playing.mp4'
  85.     if os.path.exists(os.path.join(os.getcwd(), media_path)):
  86.         main(media_path, is_image)
  87.     else:
  88.         print("Invalid Path")

结果如下所示。

a40c6a3287dee062506b3dbe8579479c.gif

/ 05 /

总结

以上操作,就是MediaPipe姿态识别的部分内容。

当然我们还可以通过MediaPipe其它的识别功能,来做出有趣的事情。

比如结合摄像头,识别手势动作,控制电脑音量。这个大家都可以自行去学习。

相关文件及代码都已上传,公众号回复【姿态识别】即可获取。

万水千山总是情,点个 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/360935
推荐阅读
相关标签