赞
踩
请参阅 YOLOv5 文档，了解有关训练、测试和部署的完整说明。快速入门示例见下文。
在 Python>=3.8.0 环境中克隆存储库，并安装 requirements.txt 中列出的依赖（其中包括 PyTorch>=1.8）。
git clone https://github.com/ultralytics/yolov5 # clone the YOLOv5 repository into ./yolov5
cd yolov5 # NOTE: the Python example below expects the checkout at /CSPLAT/TEST/utralytics_yolov5_master; move/rename it or adjust those paths
pip install -r requirements.txt # install the repository's Python dependencies (includes PyTorch>=1.8)
代码如下（示例）：
import os
import sys

cwd = os.getcwd()
# NOTE(review): hard-coded, machine-specific path; presumably added so that
# modules under /CSPLAT/TEST are importable -- adjust for your environment.
sys.path.insert(0, '/CSPLAT/TEST')
sys.path.append(cwd)

import numpy as np
import torch
import torch.nn as nn
import cv2


class Yolov5sDetector(nn.Module):
    """Person detector wrapping a YOLOv5 model loaded through ``torch.hub``."""

    def __init__(self,
                 repo_or_dir='CSPLAT/TEST/utralytics_yolov5_master',
                 path='utralytics_yolov5_master/pretrain_model/yolov5s.pt',
                 source='local'):
        """Load custom YOLOv5 weights via ``torch.hub.load(..., 'custom', ...)``.

        Args:
            repo_or_dir: YOLOv5 repository; a local directory when
                ``source='local'``.
            path: Path to the ``.pt`` weights file.
            source: Forwarded to ``torch.hub.load`` (``'local'`` or ``'github'``).
        """
        super().__init__()
        self.detector = torch.hub.load(repo_or_dir, 'custom', path=path, source=source)

    def forward(self, img):
        """Detect people in ``img``.

        Args:
            img: Image handed straight to the YOLOv5 model, e.g. the HxWx3
                RGB array returned by ``load_img``.

        Returns:
            (boxes, indices):
                boxes   -- list of ``[x, y, w, h]`` float boxes, one per
                           detection whose class id is 0 (``person`` for the
                           stock COCO-trained yolov5s weights).
                indices -- indices into ``boxes`` kept by an additional OpenCV
                           NMS pass (score threshold 0.5, IoU threshold 0.4).
        """
        with torch.no_grad():
            results = self.detector(img)

        # Detections for the first (only) image; columns: x1, y1, x2, y2, conf, cls.
        detections = results.xyxy[0]
        person_results = detections[detections[:, 5] == 0]

        boxes, confidences = [], []
        for x1, y1, x2, y2, confidence, _class_id in person_results.tolist():
            boxes.append([x1, y1, x2 - x1, y2 - y1])  # corner form -> [x, y, w, h]
            confidences.append(confidence)

        indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)
        return boxes, indices

    def load_img(self, path, order='RGB'):
        """Read an image from disk as an HxWx3 float32 array (values 0-255).

        Args:
            path: Image file path.
            order: ``'RGB'`` converts from OpenCV's BGR channel order; any
                other value keeps BGR.

        Returns:
            The image as a float32 numpy array.

        Raises:
            IOError: If OpenCV cannot read the file.
        """
        img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(img, np.ndarray):  # cv2.imread returns None on failure
            raise IOError("Fail to read %s" % path)
        if order == 'RGB':
            # Reverse the channel axis (BGR -> RGB); copy to get a contiguous array.
            img = img[:, :, ::-1].copy()
        img = img.astype(np.float32)
        return img

    def draw_bbox(self, draw_img, bbox, label='Human', label_color=(255, 0, 255),
                  save_path=None, box_color=(255, 0, 255)):
        """Draw one labelled bounding box onto ``draw_img``.

        Args:
            draw_img: HxWx3 image; the OpenCV drawing calls modify it in place.
            bbox: ``[x, y, w, h]`` box as produced by ``forward``.
            label: Text written in the label tag.
            label_color: Fill colour of the label tag.
            save_path: If truthy, the annotated image is written there with
                ``cv2.imwrite``. NOTE(review): imwrite expects BGR order, so an
                RGB image from ``load_img(order='RGB')`` is saved with red and
                blue swapped.
            box_color: Colour of the box outline (defaults to the magenta that
                was previously hard-coded).

        Returns:
            The annotated image.
        """
        # [x, y, w, h] -> integer corner coordinates for the cv2 drawing calls.
        x1 = round(bbox[0])
        y1 = round(bbox[1])
        x2 = round(bbox[0] + bbox[2])
        y2 = round(bbox[1] + bbox[3])
        draw_img = cv2.rectangle(draw_img, (x1, y1), (x2, y2), color=box_color, thickness=2)

        # getTextSize returns ((width, height), baseline); measuring label + '0'
        # makes the tag slightly wider than the text itself.
        text_w, text_h = cv2.getTextSize(label + '0', cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
        if y1 - text_h - 3 < 0:
            # The tag would extend past the top of the image: draw it just
            # inside the box's top edge instead of above it.
            draw_img = cv2.rectangle(draw_img, (x1, y1 + 2), (x1 + text_w, y1 + text_h + 3),
                                     color=label_color, thickness=-1)
            draw_img = cv2.putText(draw_img, label, (x1, y1 + text_h + 3),
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), thickness=1)
        else:
            draw_img = cv2.rectangle(draw_img, (x1, y1 - text_h - 3), (x1 + text_w, y1 - 3),
                                     color=label_color, thickness=-1)
            draw_img = cv2.putText(draw_img, label, (x1, y1 - 3),
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), thickness=1)
        if save_path:
            cv2.imwrite(save_path, img=draw_img)
        return draw_img


if __name__ == "__main__":
    yolo_detector = Yolov5sDetector(
        repo_or_dir='/CSPLAT/TEST/utralytics_yolov5_master',
        path='/CSPLAT/TEST/utralytics_yolov5_master/pretrain_model/yolov5s.pt')
    img_path = '/CSPLAT/TEST/utralytics_yolov5_master/data/images/zidane.jpg'
    img = yolo_detector.load_img(path=img_path, order='RGB')
    boxes, indices = yolo_detector(img)
    print(boxes)
    # Draw only the boxes kept by NMS. NMSBoxes may return a flat sequence or an
    # Nx1 array depending on the OpenCV version, so flatten before iterating;
    # when nothing is detected the loop simply draws nothing.
    draw_img = img
    for idx in np.array(indices).flatten():
        draw_img = yolo_detector.draw_bbox(
            draw_img=draw_img, bbox=boxes[int(idx)], save_path=None)
此处使用的数据通过 URL 网络请求获取。（注：上面的示例代码实际是从本地路径读取图片。）
**注：** 骨骼点检测不在本次任务范围内。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。