
Exporting DETR to an ONNX Model and Running Inference (CPU Environment)


        First, create an onnx folder under the detr project directory to hold the DETR pth checkpoint; the exported onnx file will also be saved there.
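        If you do not already have the checkpoint locally, the folder and weight file can be prepared with a few lines of Python. This is a minimal sketch, not part of the original post; it assumes the official DETR-R50 checkpoint URL published by the DETR repository (detr-r50-e632da11.pth), which matches the default --model_dir used in the export script below.

import os
import torch

# Create the folder that will hold both the pth checkpoint and the exported onnx file.
os.makedirs("onnx", exist_ok=True)

# Assumed: the official DETR-R50 checkpoint released with the DETR repository.
url = "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth"
torch.hub.download_url_to_file(url, "onnx/detr-r50-e632da11.pth")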

        Next, create a file named export_onnx.py in the detr project directory, copy in the code below, and run it to export detr.onnx; the model is written to the onnx folder.

import io
import argparse

import onnx
import onnxruntime
import torch

from hubconf import detr_resnet50


class ONNXExporter:

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(123)

    def run_model(self, model, onnx_path, inputs_list, tolerate_small_mismatch=False,
                  do_constant_folding=True, output_names=None, input_names=None):
        model.eval()
        onnx_io = io.BytesIO()
        # export once to an in-memory buffer (used for validation) and once to disk
        torch.onnx.export(model, inputs_list[0], onnx_io,
                          input_names=input_names, output_names=output_names,
                          export_params=True, training=False,
                          opset_version=12, do_constant_folding=do_constant_folding)
        torch.onnx.export(model, inputs_list[0], onnx_path,
                          input_names=input_names, output_names=output_names,
                          export_params=True, training=False,
                          opset_version=12, do_constant_folding=do_constant_folding)
        print(f"[INFO] ONNX model export success! save path: {onnx_path}")

        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list):
                    # test_inputs = (nested_tensor_from_tensor_list(test_inputs),)
                    test_inputs = (test_inputs,)
                test_ouputs = model(*test_inputs)
                if isinstance(test_ouputs, torch.Tensor):
                    test_ouputs = (test_ouputs,)
            self.ort_validate(onnx_io, test_inputs, test_ouputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, inputs))
        outputs = list(map(to_numpy, outputs))

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction
        ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
        ort_outs = ort_session.run(None, ort_inputs)
        for i in range(0, len(outputs)):
            try:
                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    print(error)
                else:
                    raise

    @staticmethod
    def check_onnx(onnx_path):
        model = onnx.load(onnx_path)
        onnx.checker.check_model(model)
        print(f"[INFO] ONNX model: {onnx_path} check success!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DETR Model to ONNX Model')
    # path to the DETR pth checkpoint
    parser.add_argument('--model_dir', type=str, default='onnx/detr-r50-e632da11.pth',
                        help='DETR Pytorch Model Saved Dir')
    parser.add_argument('--check', default=True, action="store_true", help='Check Your ONNX Model')
    # path where the converted onnx model is saved
    parser.add_argument('--onnx_dir', type=str, default="onnx/detr.onnx", help="Check ONNX Model's dir")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch Size")
    args = parser.parse_args()

    # load torch model
    detr = detr_resnet50(pretrained=False, num_classes=90 + 1).eval()  # max label index add 1
    # state_dict = torch.load(args.model_dir, map_location='cuda')  # model path
    state_dict = torch.load(args.model_dir, map_location='cpu')  # model path
    detr.load_state_dict(state_dict["model"])

    # dummy input
    dummy_image = [torch.ones(args.batch_size, 3, 800, 800)]

    # to onnx
    onnx_export = ONNXExporter()
    onnx_export.run_model(detr, args.onnx_dir, dummy_image, input_names=['inputs'],
                          output_names=["pred_logits", "pred_boxes"], tolerate_small_mismatch=True)

    # check onnx model
    if args.check:
        ONNXExporter.check_onnx(args.onnx_dir)

        Some warnings may be printed during export. They can be ignored; after a minute or two the ONNX export completes.
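        Once the export finishes, it is worth opening the file with onnxruntime and checking the graph's input/output signature before writing any inference code. The snippet below is a small sanity-check sketch (not part of the original scripts); with the export settings above, the input should be named inputs and the outputs pred_logits (roughly batch x 100 x 92, one row per object query) and pred_boxes (roughly batch x 100 x 4).

import onnxruntime

# Load the exported graph and print its input/output names and shapes.
sess = onnxruntime.InferenceSession("onnx/detr.onnx")
print([(i.name, i.shape) for i in sess.get_inputs()])
print([(o.name, o.shape) for o in sess.get_outputs()])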

        After exporting, create an inference_onnx.py file in the same directory and use the newly exported onnx model for prediction.

import os
import random

import cv2
from PIL import Image
import numpy as np

try:
    import onnxruntime
except ImportError:
    onnxruntime = None

import torch
import torchvision.transforms as T

torch.set_grad_enabled(False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = T.Compose([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def box_cxcywh_to_xyxy(x):
    x = torch.from_numpy(x)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b.cpu().numpy()
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b


def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
    cv2Image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{} {}%'.format(CLASSES[cl], round(p[cl] * 100, 2))
        plot_one_box((xmin, ymin, xmax, ymax), cv2Image, label=label_text)
    if imshow:
        cv2.imshow('detect', cv2Image)
        cv2.waitKey(0)
    if imwrite:
        if not os.path.exists("onnx/result"):
            os.makedirs('onnx/result')
        cv2.imwrite('onnx/result/{}'.format(save_name), cv2Image)


def detect_onnx(ort_session, im, prob_threshold=0.7):
    img = transform(im).unsqueeze(0).cpu().numpy()
    ort_inputs = {"inputs": img}
    scores, boxs = ort_session.run(None, ort_inputs)
    probas = torch.from_numpy(np.array(scores)).softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold
    probas = probas.cpu().detach().numpy()
    keep = keep.cpu().detach().numpy()
    bboxes_scaled = rescale_bboxes(boxs[0, keep], im.size)
    return probas[keep], bboxes_scaled


if __name__ == "__main__":
    onnx_path = "onnx/detr.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path)
    files = os.listdir("onnx/images")
    for file in files:
        img_path = os.path.join("onnx/images", file)
        im = Image.open(img_path)
        scores, boxes = detect_onnx(ort_session, im)
        plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)
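Since everything here runs on the CPU, a single forward pass can take a noticeable amount of time depending on your machine. The sketch below is an optional, rough latency check using the same 1x3x800x800 input shape as the export; it is not part of the original script, and the measured time will vary with your hardware.

import time
import numpy as np
import onnxruntime

# Rough CPU latency check on a dummy 800x800 input; timings depend on your machine.
sess = onnxruntime.InferenceSession("onnx/detr.onnx", providers=["CPUExecutionProvider"])
dummy = np.ones((1, 3, 800, 800), dtype=np.float32)
start = time.time()
sess.run(None, {"inputs": dummy})
print(f"one forward pass took {time.time() - start:.2f}s")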

Prediction results: the annotated images are written to the onnx/result folder.

For running inference directly with the pth model, see: DETR推理代码_athrunsunny的博客-CSDN博客
