当前位置:   article > 正文

RT-DETR模型导出与推理

detr模型导出

1.准备工作

RT-DETR模型训练可参考:http://t.csdnimg.cn/Fsph5

模型导出与模型推理需要安装onnx库和onnxruntime库;若导出时使用--simplify选项,还需要安装onnxsim库。

可通过以下命令安装:

pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple

2.onnx模型导出

首先找到export_onnx.py文件,该文件位于RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/export_onnx.py

然后修改config与resume参数的默认值,使其指向你自己的配置文件与权重文件路径(也可以在运行时通过-c与-r参数传入)。

其中resume参数为训练生成的pth权重文件路径。

  1. """by lyuwenyu
  2. """
  3. import os
  4. import sys
  5. sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
  6. import argparse
  7. import numpy as np
  8. from src.core import YAMLConfig
  9. import torch
  10. import torch.nn as nn
  11. def main(args, ):
  12. """main
  13. """
  14. cfg = YAMLConfig(args.config, resume=args.resume)
  15. if args.resume:
  16. checkpoint = torch.load(args.resume, map_location='cpu')
  17. if 'ema' in checkpoint:
  18. state = checkpoint['ema']['module']
  19. else:
  20. state = checkpoint['model']
  21. else:
  22. raise AttributeError('only support resume to load model.state_dict by now.')
  23. # NOTE load train mode state -> convert to deploy mode
  24. cfg.model.load_state_dict(state)
  25. class Model(nn.Module):
  26. def __init__(self, ) -> None:
  27. super().__init__()
  28. self.model = cfg.model.deploy()
  29. self.postprocessor = cfg.postprocessor.deploy()
  30. print(self.postprocessor.deploy_mode)
  31. def forward(self, images, orig_target_sizes):
  32. outputs = self.model(images)
  33. return self.postprocessor(outputs, orig_target_sizes)
  34. model = Model()
  35. dynamic_axes = {
  36. 'images': {0: 'N', },
  37. 'orig_target_sizes': {0: 'N'}
  38. }
  39. data = torch.rand(1, 3, 640, 640)
  40. size = torch.tensor([[640, 640]])
  41. torch.onnx.export(
  42. model,
  43. (data, size),
  44. args.file_name,
  45. input_names=['images', 'orig_target_sizes'],
  46. output_names=['labels', 'boxes', 'scores'],
  47. dynamic_axes=dynamic_axes,
  48. opset_version=16,
  49. verbose=False
  50. )
  51. if args.check:
  52. import onnx
  53. onnx_model = onnx.load(args.file_name)
  54. onnx.checker.check_model(onnx_model)
  55. print('Check export onnx model done...')
  56. if args.simplify:
  57. import onnxsim
  58. dynamic = True
  59. input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
  60. onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
  61. onnx.save(onnx_model_simplify, args.file_name)
  62. print(f'Simplify onnx model {check}...')
  63. # import onnxruntime as ort
  64. # from PIL import Image, ImageDraw
  65. # from torchvision.transforms import ToTensor
  66. # # print(onnx.helper.printable_graph(mm.graph))
  67. # im = Image.open('./000000014439.jpg').convert('RGB')
  68. # im = im.resize((640, 640))
  69. # im_data = ToTensor()(im)[None]
  70. # print(im_data.shape)
  71. # sess = ort.InferenceSession(args.file_name)
  72. # output = sess.run(
  73. # # output_names=['labels', 'boxes', 'scores'],
  74. # output_names=None,
  75. # input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
  76. # )
  77. # # print(type(output))
  78. # # print([out.shape for out in output])
  79. # labels, boxes, scores = output
  80. # draw = ImageDraw.Draw(im)
  81. # thrh = 0.6
  82. # for i in range(im_data.shape[0]):
  83. # scr = scores[i]
  84. # lab = labels[i][scr > thrh]
  85. # box = boxes[i][scr > thrh]
  86. # print(i, sum(scr > thrh))
  87. # for b in box:
  88. # draw.rectangle(list(b), outline='red',)
  89. # draw.text((b[0], b[1]), text=str(lab[i]), fill='blue', )
  90. # im.save('test.jpg')
  91. if __name__ == '__main__':
  92. parser = argparse.ArgumentParser()
  93. parser.add_argument('--config', '-c', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/configs/rtdetr/rtdetr_r18vd_6x_coco.yml" )
  94. parser.add_argument('--resume', '-r', type=str,default = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/output/rtdetr_r18vd_6x_coco/checkpoint0012.pth" )
  95. parser.add_argument('--file-name', '-f', type=str, default='model.onnx')
  96. parser.add_argument('--check', action='store_true', default=False,)
  97. parser.add_argument('--simplify', action='store_true', default=False,)
  98. args = parser.parse_args()
  99. main(args)

修改完毕后,即可运行export_onnx.py。生成的onnx文件默认名为model.onnx(可通过-f参数修改),保存在运行命令时的当前工作目录下;若在tools目录下运行,则位于该py文件的同级目录。

3.推理

在tools文件夹下创建mypredict.py

mypredict.py的代码如下:

你需要修改img_path ,使其为你推理所需的图像路径。

将结果图像的保存路径(im.save()所使用的路径)修改为你的推理结果保存的路径。

-------2023.12.21更新--------

按照你的数据集中的类别修改classes列表,列表下标即为模型输出的类别标签id。

  1. import torch
  2. import onnxruntime as ort
  3. from PIL import Image, ImageDraw
  4. from torchvision.transforms import ToTensor
  5. if __name__ == "__main__":
  6. ##################
  7. classes = ['','LicensePlate']
  8. ##################
  9. # print(onnx.helper.printable_graph(mm.graph))
  10. #############
  11. img_path = "/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/input/IMG_8669.jpg"
  12. #############
  13. im = Image.open(img_path).convert('RGB')
  14. im = im.resize((640, 640))
  15. im_data = ToTensor()(im)[None]
  16. print(im_data.shape)
  17. size = torch.tensor([[640, 640]])
  18. sess = ort.InferenceSession("model.onnx")
  19. output = sess.run(
  20. # output_names=['labels', 'boxes', 'scores'],
  21. output_names=None,
  22. input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
  23. )
  24. # print(type(output))
  25. # print([out.shape for out in output])
  26. labels, boxes, scores = output
  27. draw = ImageDraw.Draw(im)
  28. thrh = 0.6
  29. for i in range(im_data.shape[0]):
  30. scr = scores[i]
  31. lab = labels[i][scr > thrh]
  32. box = boxes[i][scr > thrh]
  33. print(i, sum(scr > thrh))
  34. #print(lab)
  35. print(f'box:{box}')
  36. for l, b in zip(lab, box):
  37. draw.rectangle(list(b), outline='red',)
  38. print(l.item())
  39. draw.text((b[0], b[1] - 10), text=str(classes[l.item()]), fill='blue', )
  40. #############
  41. im.save('/home/guan/RT-DETR/RT-DETR-main/rtdetr_pytorch/tools/output/predict/res.jpg')
  42. #############

在model.onnx所在目录下运行mypredict.py(脚本通过相对路径"model.onnx"加载模型),即可得到推理结果。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/850298
推荐阅读
相关标签
  

闽ICP备14008679号