当前位置:   article > 正文

[AI智能摄像头]RV1126部署yolov5并加速

[AI智能摄像头]RV1126部署yolov5并加速

导出onnx模型

yolov5 官方仓库地址:https://github.com/ultralytics/yolov5 ,获取源码:`git clone https://github.com/ultralytics/yolov5`

利用官方脚本导出:`python export.py --weights yolov5n.pt --include onnx`

利用代码导出

  1. import os
  2. import sys
  3. os.chdir(sys.path[0])
  4. import onnx
  5. import torch
  6. sys.path.append('..')
  7. from models.common import DetectMultiBackend
  8. from models.experimental import attempt_load
  9. DEVICE='cuda' if torch.cuda.is_available else 'cpu'
  10. def main():
  11.     """create model """
  12.     input = torch.randn(1, 3, 640, 640, requires_grad=False).float().to(torch.device(DEVICE))
  13.     model = attempt_load('./model/yolov5n.pt', device=DEVICE, inplace=True, fuse=True)  # load FP32 model
  14.     #model = DetectMultiBackend('./model/yolov5n.pt', data=input)
  15.     model.to(DEVICE)
  16.     torch.onnx.export(model,
  17.             input,
  18.             'yolov5n_self.onnx', # name of the exported onnx model
  19.             export_params=True,
  20.             opset_version=12,
  21.             do_constant_folding=False,
  22.             input_names=["images"])
  23. if __name__=="__main__":
  24.     main()

onnx模型测试

  1. import os
  2. import sys
  3. os.chdir(sys.path[0])
  4. import onnxruntime
  5. import torch
  6. import torchvision
  7. import numpy as np
  8. import time
  9. import cv2
  10. sys.path.append('..')
  11. from ultralytics.utils.plotting import Annotator, colors
  12. ONNX_MODEL="./yolov5n.onnx"
  13. DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
  14. def xywh2xyxy(x):
  15.     """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
  16.     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  17.     y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
  18.     y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
  19.     y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
  20.     y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
  21.     return y
  22. def box_iou(box1, box2, eps=1e-7):
  23.     # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  24.     """
  25.     Return intersection-over-union (Jaccard index) of boxes.
  26.     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  27.     Arguments:
  28.         box1 (Tensor[N, 4])
  29.         box2 (Tensor[M, 4])
  30.     Returns:
  31.         iou (Tensor[N, M]): the NxM matrix containing the pairwise
  32.             IoU values for every element in boxes1 and boxes2
  33.     """
  34.     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  35.     (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
  36.     inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
  37.     # IoU = inter / (area1 + area2 - inter)
  38.     return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
  39. def non_max_suppression(
  40.     prediction,
  41.     conf_thres=0.25,
  42.     iou_thres=0.45,
  43.     classes=None,
  44.     agnostic=False,
  45.     multi_label=False,
  46.     labels=(),
  47.     max_det=300,
  48.     nm=0,  # number of masks
  49. ):
  50.     """
  51.     Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.
  52.     Returns:
  53.          list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  54.     """
  55.     # Checks
  56.     assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
  57.     assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
  58.     device = prediction.device
  59.     mps = "mps" in device.type  # Apple MPS
  60.     if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
  61.         prediction = prediction.cpu()
  62.     bs = prediction.shape[0]  # batch size
  63.     nc = prediction.shape[2] - nm - 5  # number of classes
  64.     xc = prediction[..., 4] > conf_thres  # candidates
  65.     # Settings
  66.     # min_wh = 2  # (pixels) minimum box width and height
  67.     max_wh = 7680  # (pixels) maximum box width and height
  68.     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
  69.     time_limit = 0.5 + 0.05 * bs  # seconds to quit after
  70.     redundant = True  # require redundant detections
  71.     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
  72.     merge = False  # use merge-NMS
  73.     t = time.time()
  74.     mi = 5 + nc  # mask start index
  75.     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
  76.     for xi, x in enumerate(prediction):  # image index, image inference
  77.         # Apply constraints
  78.         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
  79.         x = x[xc[xi]]  # confidence
  80.         # Cat apriori labels if autolabelling
  81.         if labels and len(labels[xi]):
  82.             lb = labels[xi]
  83.             v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
  84.             v[:, :4] = lb[:, 1:5]  # box
  85.             v[:, 4] = 1.0  # conf
  86.             v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
  87.             x = torch.cat((x, v), 0)
  88.         # If none remain process next image
  89.         if not x.shape[0]:
  90.             continue
  91.         # Compute conf
  92.         x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
  93.         # Box/Mask
  94.         box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
  95.         mask = x[:, mi:]  # zero columns if no masks
  96.         # Detections matrix nx6 (xyxy, conf, cls)
  97.         if multi_label:
  98.             i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
  99.             x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
  100.         else:  # best class only
  101.             conf, j = x[:, 5:mi].max(1, keepdim=True)
  102.             x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
  103.         # Filter by class
  104.         if classes is not None:
  105.             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  106.         # Apply finite constraint
  107.         # if not torch.isfinite(x).all():
  108.         #     x = x[torch.isfinite(x).all(1)]
  109.         # Check shape
  110.         n = x.shape[0]  # number of boxes
  111.         if not n:  # no boxes
  112.             continue
  113.         x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
  114.         # Batched NMS
  115.         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
  116.         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
  117.         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
  118.         i = i[:max_det]  # limit detections
  119.         if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
  120.             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  121.             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
  122.             weights = iou * scores[None]  # box weights
  123.             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
  124.             if redundant:
  125.                 i = i[iou.sum(1) > 1]  # require redundancy
  126.         output[xi] = x[i]
  127.         if mps:
  128.             output[xi] = output[xi].to(device)
  129.         if (time.time() - t) > time_limit:
  130.             break  # time limit exceeded
  131.     return output
  132. def draw_bbox(image, result, color=(0, 0, 255), thickness=2):
  133.     # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
  134.     image = image.copy()
  135.     for point in result:
  136.         x1,y1,x2,y2=point
  137.         cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
  138.     return image
  139. def main():
  140.     input=torch.load("input.pt").to('cpu')
  141.     input_array=np.array(input)
  142.     onnx_model = onnxruntime.InferenceSession(ONNX_MODEL)
  143.     input_name = onnx_model.get_inputs()[0].name
  144.     out = onnx_model.run(None, {input_name:input_array})
  145.     out_tensor = torch.tensor(out).to(DEVICE)
  146.     pred = non_max_suppression(out_tensor,0.25,0.45,classes=None,agnostic=False,max_det=1000)
  147.     # Process predictions
  148.     for i, det in enumerate(pred):  # per image
  149.         im0_=cv2.imread('../data/images/bus.jpg')
  150.         im0=im0_.reshape(1,3,640,640)
  151.         names=torch.load('name.pt')
  152.         annotator = Annotator(im0, line_width=3, example=str(names))
  153.         coord=[]
  154.         image=im0.reshape(640,640,3)
  155.         if len(det):
  156.             # Rescale boxes from img_size to im0 size
  157.             #det[:, :4] = scale_boxes(im0.shape[2:], det[:, :4], im0.shape).round()
  158.             # Write results
  159.             for *xyxy, conf, cls in reversed(det):
  160.                 # Add bbox to image
  161.                 c = int(cls)  # integer class
  162.                 label = f"{names[c]} {conf:.2f}"
  163.                 # 创建两个顶点坐标子数组,并将它们组合成一个列表``
  164.                 coord.append([int(xyxy[0].item()), int(xyxy[1].item()),int(xyxy[2].item()), int(xyxy[3].item())])
  165.         image=draw_bbox(image,coord)
  166.         # Stream results
  167.         save_success =cv2.imwrite('result.jpg', image)
  168.         print(f"save image end {save_success}")
  169.                
  170. if __name__=="__main__":
  171.     main()

测试结果

板端部署

环境准备

搭建好 RKNN-Toolkit 以及 RKNPU 环境

大致流程

模型转换

新建export_rknn.py用于将onnx模型转化为rknn模型

  1. import os
  2. import sys
  3. os.chdir(sys.path[0])
  4. import numpy as np
  5. import cv2
  6. from rknn.api import RKNN
  7. import torchvision
  8. import torch
  9. import time
  10. ONNX_MODEL = './model/yolov5n.onnx'
  11. RKNN_MODEL = './model/yolov5n.rknn'
  12. def main():
  13.     """Create RKNN object"""
  14.     rknn = RKNN()
  15.     if not os.path.exists(ONNX_MODEL):
  16.         print('model not exist')
  17.         exit(-1)
  18.        
  19.     """pre-process config"""
  20.     print('--> Config model')
  21.     rknn.config(reorder_channel='0 1 2',
  22.                 mean_values=[[0, 0, 0]],
  23.                 std_values=[[255, 255, 255]],
  24.                 optimization_level=0,
  25.                 target_platform = ['rv1126'],
  26.                 output_optimize=1,
  27.                 quantize_input_node=True)
  28.     print('done')
  29.    
  30.     """Load ONNX model"""
  31.     print('--> Loading model')
  32.     ret = rknn.load_onnx(model=ONNX_MODEL,
  33.                         inputs=['images'],
  34.                         input_size_list = [[3, 640, 640]],
  35.                         outputs=['output0'])
  36.     if ret != 0:
  37.         print('Load yolov5 failed!')
  38.         exit(ret)
  39.     print('done')
  40.     """Build model"""
  41.     print('--> Building model')
  42.     #ret = rknn.build(do_quantization=True,dataset='./data/data.txt')
  43.     ret = rknn.build(do_quantization=False,pre_compile=True)
  44.     if ret != 0:
  45.         print('Build yolov5 failed!')
  46.         exit(ret)
  47.     print('done')
  48.     """Export RKNN model"""
  49.     print('--> Export RKNN model')
  50.     ret = rknn.export_rknn(RKNN_MODEL)
  51.     if ret != 0:
  52.         print('Export yolov5rknn failed!')
  53.         exit(ret)
  54.     print('done')
  55.    
  56. if __name__=="__main__":
  57.     main()

新建test_rknn.py用于测试rknn模型

  1. import os
  2. import sys
  3. os.chdir(sys.path[0])
  4. import numpy as np
  5. import cv2
  6. from rknn.api import RKNN
  7. import torchvision
  8. import torch
  9. import time
  10. RKNN_MODEL = './model/yolov5n.rknn'
  11. DATA='./data/bus.jpg'
  12. def xywh2xyxy(x):
  13.     coord=[]
  14.     for x_ in x:
  15.         xl=x_[0]-x_[2]/2
  16.         yl=x_[1]-x_[3]/2
  17.         xr=x_[0]+x_[2]/2
  18.         yr=x_[1]+x_[3]/2
  19.         coord.append([xl,yl,xr,yr])
  20.     coord=torch.tensor(coord).to(x.device)
  21.     return coord
  22. def box_iou(box1, box2, eps=1e-7):
  23.     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  24.     (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
  25.     inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
  26.     # IoU = inter / (area1 + area2 - inter)
  27.     return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
  28. def non_max_suppression(
  29.     prediction,
  30.     conf_thres=0.25,
  31.     iou_thres=0.45,
  32.     classes=None,
  33.     agnostic=False,
  34.     multi_label=False,
  35.     labels=(),
  36.     max_det=300,
  37.     nm=0,  # number of masks
  38. ):
  39.     # Checks
  40.     assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
  41.     assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
  42.     device = prediction.device
  43.     mps = "mps" in device.type  # Apple MPS
  44.     if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
  45.         prediction = prediction.cpu()
  46.     bs = prediction.shape[0]  # batch size
  47.     nc = prediction.shape[2] - nm - 5  # number of classes
  48.     xc = prediction[..., 4] > conf_thres  # candidates
  49.     count_true = torch.sum(xc.type(torch.int))
  50.     # Settings
  51.     # min_wh = 2  # (pixels) minimum box width and height
  52.     max_wh = 7680  # (pixels) maximum box width and height
  53.     max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
  54.     time_limit = 0.5 + 0.05 * bs  # seconds to quit after
  55.     redundant = True  # require redundant detections
  56.     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
  57.     merge = False  # use merge-NMS
  58.     t = time.time()
  59.     mi = 5 + nc  # mask start index
  60.     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
  61.     for xi, x in enumerate(prediction):  # image index, image inference
  62.         # Apply constraints
  63.         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
  64.         x = x[xc[xi]]  # confidence
  65.         # Cat apriori labels if autolabelling
  66.         if labels and len(labels[xi]):
  67.             lb = labels[xi]
  68.             v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
  69.             v[:, :4] = lb[:, 1:5]  # box
  70.             v[:, 4] = 1.0  # conf
  71.             v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
  72.             x = torch.cat((x, v), 0)
  73.         # If none remain process next image
  74.         if not x.shape[0]:
  75.             continue
  76.         # Compute conf
  77.         x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
  78.         # Box/Mask
  79.         box = xywh2xyxy(x[:, :4])  # center_x, center_y, width, height) to (x1, y1, x2, y2)
  80.         mask = x[:, mi:]  # zero columns if no masks
  81.         # Detections matrix nx6 (xyxy, conf, cls)
  82.         if multi_label:
  83.             i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
  84.             x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
  85.         else:  # best class only
  86.             conf, j = x[:, 5:mi].max(1, keepdim=True)
  87.             x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
  88.         # Filter by class
  89.         if classes is not None:
  90.             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  91.         # Apply finite constraint
  92.         # if not torch.isfinite(x).all():
  93.         #     x = x[torch.isfinite(x).all(1)]
  94.         # Check shape
  95.         n = x.shape[0]  # number of boxes
  96.         if not n:  # no boxes
  97.             continue
  98.         x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
  99.         # Batched NMS
  100.         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
  101.         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
  102.         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
  103.         i = i[:max_det]  # limit detections
  104.         if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
  105.             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  106.             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
  107.             weights = iou * scores[None]  # box weights
  108.             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
  109.             if redundant:
  110.                 i = i[iou.sum(1) > 1]  # require redundancy
  111.         output[xi] = x[i]
  112.         if mps:
  113.             output[xi] = output[xi].to(device)
  114.         if (time.time() - t) > time_limit:
  115.             break  # time limit exceeded
  116.     return output
  117. def draw_bbox(image, result, color=(0, 0, 255), thickness=2):
  118.     # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
  119.     image = image.copy()
  120.     for point in result:
  121.         x1,y1,x2,y2=point
  122.         cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
  123.     return image
  124. def main():
  125.     # Create RKNN object
  126.     rknn = RKNN()
  127.     rknn.list_devices()
  128.     #load rknn model
  129.     ret = rknn.load_rknn(path=RKNN_MODEL)
  130.     if ret != 0:
  131.         print('load rknn failed')
  132.         exit(ret)
  133.     # init runtime environment
  134.     print('--> Init runtime environment')
  135.     ret = rknn.init_runtime(target='rv1126', device_id='86d4fdeb7f3af5b1',perf_debug=True,eval_mem=True)
  136.     if ret != 0:
  137.         print('Init runtime environment failed')
  138.         exit(ret)
  139.     print('done')
  140.     # Set inputs
  141.     image=cv2.imread('./data/bus.jpg')
  142.     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  143.     # Inference
  144.     print('--> Running model')
  145.     outputs = rknn.inference(inputs=[image])
  146.     #post process
  147.     out_tensor = torch.tensor(outputs)
  148.     pred = non_max_suppression(out_tensor,0.25,0.45,classes=None,agnostic=False,max_det=1000)
  149.     # Process predictions
  150.     for i, det in enumerate(pred):  # per image
  151.         im0_=cv2.imread(DATA)
  152.         im0=im0_.reshape(1,3,640,640)
  153.         coord=[]
  154.         image=im0.reshape(640,640,3)
  155.         if len(det):
  156.             """Write results"""
  157.             for *xyxy, conf, cls in reversed(det):
  158.                 c = int(cls)  # integer class
  159.                 coord.append([int(xyxy[0].item()), int(xyxy[1].item()),int(xyxy[2].item()), int(xyxy[3].item())])
  160.                 print(f"[{coord[0][0]},{coord[0][1]},{coord[0][2]},{coord[0][3]}]:ID is {c}")
  161.         image=draw_bbox(image,coord)
  162.         # Stream results
  163.         save_success =cv2.imwrite('result.jpg', image)
  164.         print(f"save image end {save_success}")
  165.    
  166.     rknn.release()
  167. if __name__=="__main__":
  168.     main()

板端cpp推理代码编写

拷贝一份template改名为yolov5,目录结构如下

前处理代码

  1. void PreProcess(cv::Mat *image)
  2. {
  3.     cv::cvtColor(*image, *image, cv::COLOR_BGR2RGB);
  4. }

后处理代码

1. 输出维度为 [1, 25200, 85]:每行 85 个数中,前 4 个为框中心点坐标 x、y 以及框的宽 w、高 h,第 5 个为目标置信度(objectness),其余 80 个为各类别的置信度(共 80 个类别);最终得分 = 目标置信度 × 类别置信度;

2. 25200 来自三个检测层:stride 分别为 8、16、32,对应特征图尺寸 $640/8=80$、$640/16=40$、$640/32=20$,每个网格位置有 3 个 anchor,因此 $25200 = (80\times80 + 40\times40 + 20\times20)\times 3$;

3:NMS删除冗余候选框

1. IoU(交并比):衡量两个框的重叠程度,$\mathrm{IoU} = \dfrac{\text{交集面积}}{\text{并集面积}}$

2:主要步骤:

  1. 首先筛选出置信度大于阈值的所有候选框(例如 > 0.4)
  2. 接着按类别对候选框进行分组
  3. 取第 n 个类别进行以下循环操作
  4. 找到置信度最大的框,放到保留区
  5. 计算它与候选区中其他框的交并比(IoU),若大于 IoU 阈值则删除
  6. 再从候选区找到置信度次大的框,放到保留区
  7. 重复步骤 5~6,直至候选区没有框
  8. 回到步骤 3 处理下一个类别,直至所有类别处理完毕
  1. float iou(Bbox box1, Bbox box2) {
  2.     /*  
  3.     iou=交并比
  4.     */
  5.     int x1 = max(box1.x, box2.x);
  6.     int y1 = max(box1.y, box2.y);
  7.     int x2 = min(box1.x + box1.w, box2.x + box2.w);
  8.     int y2 = min(box1.y + box1.h, box2.y + box2.h);
  9.     int w = max(0, x2 - x1);
  10.     int h = max(0, y2 - y1);
  11.     float over_area = w * h;
  12.     return over_area / (box1.w * box1.h + box2.w * box2.h - over_area);
  13. }
  14. bool judge_in_lst(int index, vector<int> index_lst) {
  15.     //若index在列表index_lst中则返回true,否则返回false
  16.     if (index_lst.size() > 0) {
  17.         for (int i = 0; i < int(index_lst.size()); i++) {
  18.             if (index == index_lst.at(i)) {
  19.                 return true;
  20.             }
  21.         }
  22.     }
  23.     return false;
  24. }
  25. int get_max_index(vector<Detection> pre_detection) {
  26.     //返回最大置信度值对应的索引值
  27.     int index;
  28.     float conf;
  29.     if (pre_detection.size() > 0) {
  30.         index = 0;
  31.         conf = pre_detection.at(0).conf;
  32.         for (int i = 0; i < int(pre_detection.size()); i++) {
  33.             if (conf < pre_detection.at(i).conf) {
  34.                 index = i;
  35.                 conf = pre_detection.at(i).conf;
  36.             }
  37.         }
  38.         return index;
  39.     }
  40.     else {
  41.         return -1;
  42.     }
  43. }
  44. vector<int> nms(vector<Detection> pre_detection, float iou_thr)
  45. {
  46.     /*
  47.     返回需保存box的pre_detection对应位置索引值
  48.     */
  49.     int index;
  50.     vector<Detection> pre_detection_new;
  51.     //Detection det_best;
  52.     Bbox box_best, box;
  53.     float iou_value;
  54.     vector<int> keep_index;
  55.     vector<int> del_index;
  56.     bool keep_bool;
  57.     bool del_bool;
  58.     if (pre_detection.size() > 0) {
  59.         pre_detection_new.clear();
  60.         // 循环将预测结果建立索引
  61.         for (int i = 0; i < int(pre_detection.size()); i++) {
  62.             pre_detection.at(i).index = i;
  63.             pre_detection_new.push_back(pre_detection.at(i));
  64.         }
  65.         //循环遍历获得保留box位置索引-相对输入pre_detection位置
  66.         while (pre_detection_new.size() > 0) {
  67.             index = get_max_index(pre_detection_new);
  68.             if (index >= 0) {
  69.                 keep_index.push_back(pre_detection_new.at(index).index); //保留索引位置
  70.                 // 更新最佳保留box
  71.                 box_best.x = pre_detection_new.at(index).bbox[0];
  72.                 box_best.y = pre_detection_new.at(index).bbox[1];
  73.                 box_best.w = pre_detection_new.at(index).bbox[2];
  74.                 box_best.h = pre_detection_new.at(index).bbox[3];
  75.                 for (int j = 0; j < int(pre_detection.size()); j++) {
  76.                     keep_bool = judge_in_lst(pre_detection.at(j).index, keep_index);
  77.                     del_bool = judge_in_lst(pre_detection.at(j).index, del_index);
  78.                     if ((!keep_bool) && (!del_bool)) { //不在keep_index与del_index才计算iou
  79.                         box.x = pre_detection.at(j).bbox[0];
  80.                         box.y = pre_detection.at(j).bbox[1];
  81.                         box.w = pre_detection.at(j).bbox[2];
  82.                         box.h = pre_detection.at(j).bbox[3];
  83.                         iou_value = iou(box_best, box);
  84.                         if (iou_value > iou_thr) {
  85.                             del_index.push_back(j); //记录大于阈值将删除对应的位置
  86.                         }
  87.                     }
  88.                 }
  89.                 //更新pre_detection_new
  90.                 pre_detection_new.clear();
  91.                 for (int j = 0; j < int(pre_detection.size()); j++) {
  92.                     keep_bool = judge_in_lst(pre_detection.at(j).index, keep_index);
  93.                     del_bool = judge_in_lst(pre_detection.at(j).index, del_index);
  94.                     if ((!keep_bool) && (!del_bool)) {
  95.                         pre_detection_new.push_back(pre_detection.at(j));
  96.                     }
  97.                 }
  98.             }
  99.         }
  100.     }
  101.     del_index.clear();
  102.     del_index.shrink_to_fit();
  103.     pre_detection_new.clear();
  104.     pre_detection_new.shrink_to_fit();
  105.     return  keep_index;
  106. }
  107. vector<Detection> PostProcess(float* prob,float conf_thr=0.3,float nms_thr=0.5)
  108. {
  109.     vector<Detection> pre_results;
  110.     vector<int> nms_keep_index;
  111.     vector<Detection> results;
  112.     bool keep_bool;
  113.     Detection pre_res;
  114.     float conf;
  115.     int tmp_idx;
  116.     float tmp_cls_score;
  117.     for (int i = 0; i < 25200; i++) {
  118.         tmp_idx = i * (CLSNUM + 5);
  119.         pre_res.bbox[0] = prob[tmp_idx + 0];  //cx
  120.         pre_res.bbox[1] = prob[tmp_idx + 1];  //cy
  121.         pre_res.bbox[2] = prob[tmp_idx + 2];  //w
  122.         pre_res.bbox[3] = prob[tmp_idx + 3];  //h
  123.         conf = prob[tmp_idx + 4];  // 是为目标的置信度
  124.         tmp_cls_score = prob[tmp_idx + 5] * conf; //conf_thr*nms_thr
  125.         pre_res.class_id = 0;
  126.         pre_res.conf = 0;
  127.         // 这个过程相当于从除了前面5列,在后面的cla_num个数据中找出score最大的值作为pre_res.conf,对应的列作为类id
  128.         for (int j = 1; j < CLSNUM; j++) {    
  129.             tmp_idx = i * (CLSNUM + 5) + 5 + j; //获得对应类别索引
  130.             if (tmp_cls_score < prob[tmp_idx] * conf){
  131.                 tmp_cls_score = prob[tmp_idx] * conf;
  132.                 pre_res.class_id = j;
  133.                 pre_res.conf = tmp_cls_score;
  134.             }
  135.         }
  136.         if (conf >= conf_thr) {
  137.             pre_results.push_back(pre_res);
  138.         }
  139.     }
  140.     //使用nms,返回对应结果的索引
  141.     nms_keep_index=nms(pre_results,nms_thr);
  142.     // 茛据nms找到的索引,将结果取出来作为最终结果
  143.     for (int i = 0; i < int(pre_results.size()); i++) {
  144.         keep_bool = judge_in_lst(i, nms_keep_index);
  145.         if (keep_bool) {
  146.             results.push_back(pre_results.at(i));
  147.         }
  148.     }
  149.     pre_results.clear();
  150.     pre_results.shrink_to_fit();
  151.     nms_keep_index.clear();
  152.     nms_keep_index.shrink_to_fit();
  153.     return results;
  154. }

结果展示

至此板端部署结束,接下来进行优化;

优化加速

可以看到模型推理时间接近 2 s,对于实时处理来说远远不够,因此需要对模型进行加速(注意上文转换 RKNN 模型时使用的是 `do_quantization=False`,即模型未做量化)。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/584917
推荐阅读
相关标签
  

闽ICP备14008679号