First, create an onnx folder under the detr project directory to hold the DETR pth checkpoint; the exported onnx file will be stored there as well.
Then create an export_onnx.py file in the detr project directory, copy in the code below, and run it to export the detr.onnx model into the onnx folder.
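For reference, the directory layout that the two scripts below assume looks roughly like this (images/ and result/ are only needed for the inference step later):

detr/                            # DETR project root
├── export_onnx.py               # export script (created below)
├── inference_onnx.py            # inference script (created later)
└── onnx/
    ├── detr-r50-e632da11.pth    # pretrained DETR-R50 checkpoint
    ├── detr.onnx                # produced by export_onnx.py
    ├── images/                  # test images read by inference_onnx.py
    └── result/                  # annotated outputs (created automatically)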
import io
import argparse

import onnx
import onnxruntime
import torch

from hubconf import detr_resnet50


class ONNXExporter:

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(123)

    def run_model(self, model, onnx_path, inputs_list, tolerate_small_mismatch=False,
                  do_constant_folding=True, output_names=None, input_names=None):
        model.eval()

        onnx_io = io.BytesIO()

        # export once to an in-memory buffer (used for validation below) and once to disk
        torch.onnx.export(model, inputs_list[0], onnx_io,
                          input_names=input_names, output_names=output_names, export_params=True,
                          training=False, opset_version=12, do_constant_folding=do_constant_folding)
        torch.onnx.export(model, inputs_list[0], onnx_path,
                          input_names=input_names, output_names=output_names, export_params=True,
                          training=False, opset_version=12, do_constant_folding=do_constant_folding)

        print(f"[INFO] ONNX model export success! save path: {onnx_path}")

        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    # test_inputs = (nested_tensor_from_tensor_list(test_inputs),)
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            else:
                return tensor.cpu().numpy()

        inputs = list(map(to_numpy, inputs))
        outputs = list(map(to_numpy, outputs))

        # compute onnxruntime predictions and compare them with the pytorch outputs
        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
        ort_outs = ort_session.run(None, ort_inputs)
        for i in range(0, len(outputs)):
            try:
                torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    print(error)
                else:
                    raise

    @staticmethod
    def check_onnx(onnx_path):
        model = onnx.load(onnx_path)
        onnx.checker.check_model(model)
        print(f"[INFO] ONNX model: {onnx_path} check success!")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='DETR Model to ONNX Model')
    # path of the DETR pth checkpoint
    parser.add_argument('--model_dir', type=str, default='onnx/detr-r50-e632da11.pth',
                        help='DETR Pytorch Model Saved Dir')
    parser.add_argument('--check', default=True, action="store_true", help='Check Your ONNX Model')
    # path where the converted onnx model is saved
    parser.add_argument('--onnx_dir', type=str, default="onnx/detr.onnx", help="Check ONNX Model's dir")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch Size")

    args = parser.parse_args()

    # load torch model (num_classes = max COCO label index 90 + 1)
    detr = detr_resnet50(pretrained=False, num_classes=90 + 1).eval()
    # state_dict = torch.load(args.model_dir, map_location='cuda')  # model path
    state_dict = torch.load(args.model_dir, map_location='cpu')  # model path
    detr.load_state_dict(state_dict["model"])

    # dummy input
    dummy_image = [torch.ones(args.batch_size, 3, 800, 800)]

    # to onnx
    onnx_export = ONNXExporter()
    onnx_export.run_model(detr, args.onnx_dir, dummy_image, input_names=['inputs'],
                          output_names=["pred_logits", "pred_boxes"], tolerate_small_mismatch=True)

    # check onnx model
    if args.check:
        ONNXExporter.check_onnx(args.onnx_dir)
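With the default arguments above, the script can be run from the detr project root simply as python export_onnx.py (assuming the checkpoint has already been placed at onnx/detr-r50-e632da11.pth); use --model_dir and --onnx_dir to point at other paths.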
Some warnings may be printed during export; they can safely be ignored, and the export finishes after a minute or two.
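Once the export finishes, a quick sanity check is to open detr.onnx with onnxruntime and print its input/output signature; the names should match the inputs / pred_logits / pred_boxes names passed to torch.onnx.export above. A minimal sketch:

import onnxruntime

# load the exported model and list its inputs and outputs
sess = onnxruntime.InferenceSession("onnx/detr.onnx")
for inp in sess.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)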
After the export is done, create an inference_onnx.py file in the same directory and use the onnx model you just exported to run inference.
import os
import random

import cv2
import numpy as np
from PIL import Image

try:
    import onnxruntime
except ImportError:
    onnxruntime = None

import torch
import torchvision.transforms as T

torch.set_grad_enabled(False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# same preprocessing as the dummy export input: resize to 800x800, then normalize
transform = T.Compose([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def box_cxcywh_to_xyxy(x):
    # convert (center_x, center_y, w, h) boxes to (x_min, y_min, x_max, y_max)
    x = torch.from_numpy(x)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # scale normalized boxes back to absolute pixel coordinates of the original image
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b.cpu().numpy()
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b


def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    # draw one bounding box (and optional label) on the image in place
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


# COCO class names; 'N/A' entries keep the original (non-contiguous) COCO label indices
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
    cv2Image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{} {}%'.format(CLASSES[cl], round(p[cl] * 100, 2))
        plot_one_box((xmin, ymin, xmax, ymax), cv2Image, label=label_text)

    if imshow:
        cv2.imshow('detect', cv2Image)
        cv2.waitKey(0)

    if imwrite:
        if not os.path.exists("onnx/result"):
            os.makedirs('onnx/result')
        cv2.imwrite('onnx/result/{}'.format(save_name), cv2Image)


def detect_onnx(ort_session, im, prob_threshold=0.7):
    # preprocess the PIL image and add the batch dimension
    img = transform(im).unsqueeze(0).cpu().numpy()
    ort_inputs = {"inputs": img}
    scores, boxs = ort_session.run(None, ort_inputs)
    # drop the background class (last logit) and keep queries above the confidence threshold
    probas = torch.from_numpy(np.array(scores)).softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold
    probas = probas.cpu().detach().numpy()
    keep = keep.cpu().detach().numpy()
    bboxes_scaled = rescale_bboxes(boxs[0, keep], im.size)
    return probas[keep], bboxes_scaled


if __name__ == "__main__":
    onnx_path = "onnx/detr.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path)
    files = os.listdir("onnx/images")

    for file in files:
        img_path = os.path.join("onnx/images", file)
        im = Image.open(img_path).convert("RGB")  # ensure a 3-channel input for the normalize transform
        scores, boxes = detect_onnx(ort_session, im)
        plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)
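Note that the device variable defined at the top of inference_onnx.py is never actually used: the plain onnxruntime package runs the model on the CPU. If the onnxruntime-gpu build is installed instead, the session can be created with an explicit provider list, e.g. (a small sketch, not part of the original script):

# assumes the onnxruntime-gpu package is installed; otherwise onnxruntime falls back to the CPU provider
ort_session = onnxruntime.InferenceSession(
    "onnx/detr.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)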
Inference results: the annotated images are written to the onnx/result folder.
For inference directly with the pth model, see: DETR推理代码_athrunsunny的博客-CSDN博客