
YOLOv7 Algorithm (Part 5): Notes on Converting pth/pt to ONNX


Input command

python export.py --weights /kaxier01/projects/FAS/yolov7/weights/yolov7.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640
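Once the command finishes, the exported yolov7.onnx can be sanity-checked right away. The sketch below is only a quick check and assumes onnxruntime is installed; with the flags above (no --dynamic / --dynamic-batch) the input should be a fixed [1, 3, 640, 640] tensor named 'images', and the onnxruntime end2end path exports a single output named 'output'.

import onnxruntime as ort

# Quick check of the exported graph's input/output signature (assumes onnxruntime is installed).
sess = ort.InferenceSession('/kaxier01/projects/FAS/yolov7/weights/yolov7.onnx',
                            providers=['CPUExecutionProvider'])
print([(i.name, i.shape) for i in sess.get_inputs()])   # e.g. [('images', [1, 3, 640, 640])]
print([(o.name, o.shape) for o in sess.get_outputs()])  # e.g. a single 'output' for the ORT end2end export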

Reading through export.py

import argparse
import sys
import time
import warnings

sys.path.append('./')  # to run '$ python *.py' files in subdirectories

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

import models
from models.experimental import attempt_load, End2End
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.torch_utils import select_device
from utils.add_nms import RegisterNMS

import sys
import warnings
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='/kaxier01/projects/FAS/yolov7/weights/yolov7.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
    parser.add_argument('--dynamic-batch', action='store_true', help='dynamic batch onnx for tensorrt and onnx-runtime')
    parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
    parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
    parser.add_argument('--max-wh', type=int, default=None, help='None for tensorrt nms, int value for onnx-runtime nms')
    parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
    parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
    parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export')
    parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # opt.img_size=[640, 640]
    opt.dynamic = opt.dynamic and not opt.end2end  # False
    opt.dynamic = False if opt.dynamic_batch else opt.dynamic  # False
    print(opt)
    set_logging()
    t = time.time()

    # Load PyTorch model
    device = select_device(opt.device)  # device='cpu'
    model = attempt_load(opt.weights, map_location=device)  # load FP32 model
    labels = model.names
    '''
    labels=
    ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
     'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
     'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
     'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
     'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
     'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
     'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
     'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
     'scissors', 'teddy bear', 'hair drier', 'toothbrush']
    '''

    # Checks
    gs = int(max(model.stride))  # grid size (max stride), gs=32
    opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples, opt.img_size=[640, 640]

    # Input
    img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device)  # image size(1,3,320,192) iDetection, img.shape=torch.Size([1, 3, 640, 640])

    # Update model
    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, models.common.Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
    model.model[-1].export = not opt.grid  # set Detect() layer grid export, model.model[-1].export=False
    y = model(img)  # dry run
    if opt.include_nms:
        model.model[-1].include_nms = True
        y = None

    # TorchScript export
    try:
        print('\nStarting TorchScript export with torch %s...' % torch.__version__)
        f = opt.weights.replace('.pt', '.torchscript.pt')  # f='/kaxier01/projects/FAS/yolov7/weights/yolov7.torchscript.pt'
        ts = torch.jit.trace(model, img, strict=False)
        ts.save(f)  # a .torchscript.pt model can run directly in C++ and other environments without depending on Python
        print('TorchScript export success, saved as %s' % f)
    except Exception as e:
        print('TorchScript export failure: %s' % e)

    # CoreML export
    try:
        import coremltools as ct

        print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
        # convert model from torchscript and apply pixel scaling as per detect.py
        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
        bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None)
        if bits < 32:
            if sys.platform.lower() == 'darwin':  # quantization only supported on macOS
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                    ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
            else:
                print('quantization only supported on macOS, skipping...')
        f = opt.weights.replace('.pt', '.mlmodel')  # f='/kaxier01/projects/FAS/yolov7/weights/yolov7.mlmodel'
        ct_model.save(f)  # a .mlmodel can be deployed on iOS
        print('CoreML export success, saved as %s' % f)
    except Exception as e:
        print('CoreML export failure: %s' % e)

    # TorchScript-Lite export
    try:
        print('\nStarting TorchScript-Lite export with torch %s...' % torch.__version__)
        f = opt.weights.replace('.pt', '.torchscript.ptl')  # f='/kaxier01/projects/FAS/yolov7/weights/yolov7.torchscript.ptl'
        tsl = torch.jit.trace(model, img, strict=False)
        tsl = optimize_for_mobile(tsl)
        tsl._save_for_lite_interpreter(f)  # a .torchscript.ptl model can be deployed on Android
        print('TorchScript-Lite export success, saved as %s' % f)
    except Exception as e:
        print('TorchScript-Lite export failure: %s' % e)

    # ONNX export
    try:
        import onnx

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = opt.weights.replace('.pt', '.onnx')  # f='/kaxier01/projects/FAS/yolov7/weights/yolov7.onnx'
        model.eval()
        output_names = ['classes', 'boxes'] if y is None else ['output']  # output_names=['output']
        dynamic_axes = None
        if opt.dynamic:
            dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
                            'output': {0: 'batch', 2: 'y', 3: 'x'}}
        if opt.dynamic_batch:
            opt.batch_size = 'batch'
            dynamic_axes = {
                'images': {
                    0: 'batch',
                }, }
            if opt.end2end and opt.max_wh is None:
                output_axes = {
                    'num_dets': {0: 'batch'},
                    'det_boxes': {0: 'batch'},
                    'det_scores': {0: 'batch'},
                    'det_classes': {0: 'batch'},
                }
            else:
                output_axes = {
                    'output': {0: 'batch'},
                }
            dynamic_axes.update(output_axes)
        if opt.grid:
            if opt.end2end:
                print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
                model = End2End(model, opt.topk_all, opt.iou_thres, opt.conf_thres, opt.max_wh, device, len(labels))
                if opt.end2end and opt.max_wh is None:
                    output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
                    shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,
                              opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all]
                else:
                    output_names = ['output']
            else:
                model.model[-1].concat = True

        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        if opt.end2end and opt.max_wh is None:
            for i in onnx_model.graph.output:
                for j in i.type.tensor_type.shape.dim:
                    j.dim_param = str(shapes.pop(0))

        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model

        if opt.simplify:
            try:
                import onnxsim

                print('\nStarting to simplify ONNX...')
                onnx_model, check = onnxsim.simplify(onnx_model)  # simplify the model
                assert check, 'assert check failed'
            except Exception as e:
                print(f'Simplifier failure: {e}')

        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        onnx.save(onnx_model, f)
        print('ONNX export success, saved as %s' % f)

        if opt.include_nms:
            print('Registering NMS plugin for ONNX...')
            mo = RegisterNMS(f)
            mo.register_nms()
            mo.save(f)

    except Exception as e:
        print('ONNX export failure: %s' % e)

    # Finish
    print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
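With the command at the top of this post (--end2end plus an integer --max-wh), NMS is baked into the graph for onnxruntime and the model returns a single 'output' tensor. Below is a minimal inference sketch under my assumptions: each output row is laid out as [batch_id, x0, y0, x1, y1, class_id, score] as produced by the onnxruntime branch of End2End, the image path is a placeholder, and a plain resize is used instead of the repo's aspect-ratio-preserving letterbox, so boxes land in the resized 640x640 coordinate space.

import cv2
import numpy as np
import onnxruntime as ort

# Minimal post-export inference sketch for the ORT end2end model (assumptions noted above).
sess = ort.InferenceSession('/kaxier01/projects/FAS/yolov7/weights/yolov7.onnx',
                            providers=['CPUExecutionProvider'])

img = cv2.imread('test.jpg')                                    # placeholder image path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))                               # plain resize; detect.py uses letterbox instead
blob = img.transpose(2, 0, 1)[None].astype(np.float32) / 255.0  # HWC -> NCHW, scale to 0~1

(dets,) = sess.run(None, {'images': blob})                      # assumed shape: (num_dets, 7)
for batch_id, x0, y0, x1, y1, cls_id, score in dets:
    print(int(cls_id), round(float(score), 3), (float(x0), float(y0), float(x1), float(y1)))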

If you encounter

CoreML export failure: Core ML only supports tensors with rank <= 5. Layer "model.105.anchor_grid", with type "const", outputs a rank 6 tensor.

then change the command to the following (the --grid flag is dropped):

python export.py --weights /kaxier01/projects/FAS/yolov7/weights/yolov7.pt --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640

yolov7.onnx network structure diagram
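The last log line of export.py already points to Netron for visualization. Assuming the netron pip package is installed, the exported graph can be opened either with the netron command-line tool or directly from Python:

import netron

# Starts a local Netron server and opens the exported graph in the browser.
netron.start('/kaxier01/projects/FAS/yolov7/weights/yolov7.onnx')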
