当前位置:   article > 正文

YOLOv5-seg中推理部分predict.py代码功能介绍与删减封装_retina_masks

retina_masks

 一、任务介绍

        需要将yolo-seg和ros2结合起来,所以要将yolov5-seg中的predict.py预测部分封装成类,并且只保留主体功能部分,把不需要的部分都删除。

二、关键代码介绍

        1.推理准备

  1. device = select_device(self.device)
  2. model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn, data=self.data, fp16=self.half)
  3. stride, names, pt = model.stride, model.names, model.pt
  4. imgsz = check_img_size(self.imgsz, s=stride)
  5. bs = 1 # batch_size
  6. model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup

         select_device:选择使用的设备,有CUDA用CUDA,没有就用cpu.

        DetectMultibackend:pytorch模型类,用于在不同的推理后端上进行目标检测,对于不同的推理后端有不同的加载和初始化逻辑,还有对前向推理的方法。这里的用法是调用其构造函数创建对象。

        check_img_size用于检查图像大小是否为步幅s的倍数,如果不是则调整成s的倍数。

        warmup函数为模型推理前进行的预热操作,它会调用forward函数,根据imgsz传入一个空的张量来触发预处理过程,这个过程中模型会对计算流程进行预处理和缓存,以便在后面实际推理时更快的相应请求。

        总的来说,这段是选择设备、初始化模型,并进行模型推理准备工作。

        2.图像预处理

  1. img = letterbox(image_raw, self.imgsz, stride=stride)[0]
  2. img = img.transpose((2, 0, 1))[::-1]
  3. im = np.ascontiguousarray(img)
  4. im = torch.from_numpy(im).to(model.device)
  5. im = im.half() if model.fp16 else im.float()
  6. im /= 255 # 0 - 255 to 0.0 - 1.0
  7. if len(im.shape) == 3:
  8. im = im[None] # expand for batch dim

         letterbox函数将原始图像调整到指定尺寸imgsz,并且保持长宽比不变,函数的返回是处理后的图片、缩放比例、填充的宽度和高度。

        transpose函数是把图像维度顺序从(高度H,宽度W,通道数C)转换为(通道数C,高度H,宽度W)这样是为了符合pytorch对输入数据的要求,后面的[::-1]是将图像的通道顺序从RGB转为BGR,也是为了符合机器学习框架的需求。

        ascontiguousarray函数是为了保证图像数组是连续的(某些图像格式在存储时可能采取非连续的方式来组织数据,如BMP格式等),返回一个连续数组的拷贝,防止因为图像数组不连续导致的后续计算函数报错或者效率低下。

        后面则是:

                转换为pytorch张量移动到指定设备(cpu或CUDA);

                数据类型转换成全精度浮点数;

                像素值归一化(将像素值归一化到0.0-1.0之间);

                维度调整(如果张量的维度是3(例如形状为(H,W,C)的图像),为了适应对输入张量的要求,会在第0维额外加一个批次维度,变成(1,H,W,C))

        总的来说,这段是将原始图像进行预处理,以符合模型的输入要求。

         3.模型推理

  1. pred, proto = model(im, augment=self.augment, visualize=self.visualize)[:2]
  2. pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=self.max_det, nm=32)

        model通过调用前面加载的model对输入图像进行推理。其中:

                augment:推理时是否使用数据增强(对数据进行随机裁剪、翻转、旋转、缩放等来生成更多样化的样本)

                visualize:是否返回可视化结果。

        最后返回值是预测框(pred)和原型向量(proto)的元组。其中原型向量(proto)可以理解为用来代表训练集中不同类别的样本的特征向量,每个原型向量代表一个类别。原型向量可以用于计算损失函数,优化模型参数等。

        non_max_suppression函数将预测框应用非最大抑制算法(NMS),根据置信度阈值(conf_thres)和重叠IOU阈值(iou_thres)对预测框进行过滤,其中:

                classes:指定推理类别

                agnostic_nms:是否使用类别不可知的非最大抑制。

                max_det:每个图像中保留的最大检测框数量

                nm:非最大抑制的候选框数量

        最后得到经过非最大抑制处理后的预测框。

        总的来说,这段是通过模型进行推理,并使用非最大抑制算法对预测框进行过滤。

         4.绘制结果

  1. im0 = image_raw
  2. annotator = Annotator(im0, line_width=self.line_thickness, example=str(names))

        创建一个annotator注释器对象,传入原始图像im0,线条宽度line_width,类别名称str(names)

  1. if self.retina_masks:
  2. det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
  3. masks = process_mask_native(proto[i], det[:, 6:], det[:, :4], im0.shape[:2]) # HWC
  4. else:
  5. masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC
  6. det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size

         retina_masks为真:先把图像调整成原图像尺寸,再对掩码进行处理;为假则先对掩码进行处理,再调整到原图像尺寸。

        scale_boxes函数作用是把边界框按照给定图片形状和缩放比例进行缩放和填充。

        process_mask_native和process_mask函数都是对掩码进行裁剪、缩放、上采样等操作,区别在于裁剪顺序、裁剪方式不同,前者先上采样再裁剪、直接使用原始bboxes裁剪;而后者是先裁剪再上采样、根据输入图像大小和原始掩码大小的比例来调整边界框的坐标,裁剪时使用调整后的边界框。

        这两种方式最大的区别在于后面segments的结果不同,retina_masks为真时得到的是基于原始图像计算得到的,而为假时则是基于当前图像。

  1. segments = [
  2. scale_segments(im0.shape if self.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
  3. for x in reversed(masks2segments(masks))]

        masks2segments函数把掩码转为对应的分割区域 

        scale_segments函数根据retina_masks的值把分割区域缩放至指定尺寸(原图尺寸或当前图像尺寸),后进行归一化存入segments列表。

  1. for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
  2. seg = segments[j].reshape(-1) # (n,2) to (n*2)
  3. line = (cls, *seg) # label format
  4. c = int(cls) # integer class
  5. label = f'{names[c]} {conf:.2f}'
  6. annotator.box_label(xyxy, label, color=colors(c, True))

        seg为掩码数据的一维数组形式,以便处理。box_label函数将边界框和类别、置信度绘制到原图。 

im0 = annotator.result()

        获取绘制了边界框和标签后的图像.因为前面annotator对象内部维护了一个内存中的缓冲图像,用于绘制标注,所以最后需要获取一下。

        至此,关键代码基本介绍完毕。

三、完整代码 

  1. import argparse
  2. import os
  3. import platform
  4. import sys
  5. from pathlib import Path
  6. import numpy as np
  7. import torch
  8. from ultralytics.utils.plotting import Annotator, colors, save_one_box
  9. from models.common import DetectMultiBackend
  10. from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
  11. from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
  12. increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
  13. strip_optimizer)
  14. from utils.segment.general import masks2segments, process_mask, process_mask_native
  15. from utils.torch_utils import select_device, time_sync
  16. from utils.augmentations import letterbox
  17. class yolov5_test():
  18. def __init__(self,
  19. weights=None,
  20. data=None,
  21. imgsz=(640, 640),
  22. conf_thres=0.25,
  23. iou_thres=0.4,
  24. max_det=1000,
  25. device='cpu',
  26. classes=None,
  27. agnostic_nms=False,
  28. augment=False,
  29. visualize=False,
  30. line_thickness=3,
  31. half=False,
  32. dnn=False,
  33. vid_stride=1,
  34. retina_masks=False):
  35. self.weights = weights
  36. self.data = data
  37. self.imgsz = imgsz
  38. self.conf_thres = conf_thres
  39. self.iou_thres = iou_thres
  40. self.max_det = max_det
  41. self.device = device
  42. self.classes = classes
  43. self.agnostic_nms = agnostic_nms
  44. self.augment = augment
  45. self.visualize = visualize
  46. self.line_thickness = line_thickness
  47. self.half = half
  48. self.dnn = dnn
  49. self.vid_stride = vid_stride
  50. self.retina_masks = retina_masks
  51. def image_callback(self, image_raw):
  52. save_path = "/home/nvidia/test.jpg" # 测试用,使用修改成自己的路径
  53. # Load model
  54. device = select_device(self.device)
  55. model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn, data=self.data, fp16=self.half)
  56. stride, names, pt = model.stride, model.names, model.pt
  57. imgsz = check_img_size(self.imgsz, s=stride) # check image size
  58. model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
  59. dt = [0.0, 0.0, 0.0]
  60. bs = 1 # batch_size
  61. img = letterbox(image_raw, self.imgsz, stride=stride)[0]
  62. img = img.transpose((2, 0, 1))[::-1]
  63. im = np.ascontiguousarray(img)
  64. t1 = time_sync()
  65. im = torch.from_numpy(im).to(model.device)
  66. # im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
  67. im = im.half() if model.fp16 else im.float()
  68. im /= 255 # 0 - 255 to 0.0 - 1.0
  69. if len(im.shape) == 3:
  70. im = im[None] # expand for batch dim
  71. t2 = time_sync()
  72. dt[0] += t2 - t1
  73. # visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
  74. pred, proto = model(im, augment=self.augment, visualize=self.visualize)[:2]
  75. t3 = time_sync()
  76. dt[1] += t3 - t2
  77. pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=self.max_det, nm=32)
  78. dt[2] += time_sync() - t3
  79. for i, det in enumerate(pred):
  80. im0 = image_raw
  81. annotator = Annotator(im0, line_width=self.line_thickness, example=str(names))
  82. if len(det):
  83. if self.retina_masks:
  84. det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
  85. masks = process_mask_native(proto[i], det[:, 6:], det[:, :4], im0.shape[:2]) # HWC
  86. else:
  87. masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC
  88. det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
  89. segments = [
  90. scale_segments(im0.shape if self.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
  91. for x in reversed(masks2segments(masks))]
  92. annotator.masks(
  93. masks,
  94. colors=[colors(x, True) for x in det[:, 5]],
  95. im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(device).permute(2, 0, 1).flip(0).contiguous() /
  96. 255 if self.retina_masks else im[i])
  97. for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
  98. seg = segments[j].reshape(-1) # (n,2) to (n*2)
  99. line = (cls, *seg) # label format
  100. c = int(cls) # integer class
  101. label = f'{names[c]} {conf:.2f}'
  102. annotator.box_label(xyxy, label, color=colors(c, True))
  103. im0 = annotator.result()
  104. # Save results (image with detections)
  105. cv2.imwrite(save_path, im0)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/275794
推荐阅读
相关标签
  

闽ICP备14008679号