
TensorRT acceleration: converting a PyTorch .pth weight file to an engine

1. Saving the weights in PyTorch

There are two ways to do this:

    # Method 1
    '''Save the weights together with the optimizer state and epoch'''
    state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
    torch.save(state, dir)
    '''How to load it back'''
    checkpoint = torch.load(dir)
    model.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1

 

    # Method 2
    '''Save only the weights'''
    torch.save(net.state_dict(), Path)
    '''How to load it back'''
    model.load_state_dict(torch.load(Path))
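
If the weights were saved with the first method (a dict holding 'net', 'optimizer' and 'epoch'), only the 'net' entry is needed for the ONNX export in the next step. A minimal sketch of extracting it (the checkpoint path here is a hypothetical example):

    import torch
    from model import MobileNetV2  # project-specific model definition

    model = MobileNetV2(num_classes=4)
    checkpoint = torch.load('./checkpoint.pth')   # hypothetical path of a method-1 checkpoint
    model.load_state_dict(checkpoint['net'])      # keep only the network weights
    # Optionally re-save just the state_dict so it matches the second method
    torch.save(model.state_dict(), './MobileNetV2.pth')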

 

2. Converting the .pth file to the common .onnx format

The code is shown below. The main flow is: define the dummy input, create the model and load its weights, then call PyTorch's ONNX export interface to convert the .pth file to .onnx.

    import torch
    from model import MobileNetV2  # project-specific model definition

    # Define parameters
    input_name = ['input']
    output_name = ['output']
    '''input must have the same size as the images fed to the model'''
    input = torch.randn(1, 3, 32, 32).cuda()
    # Create the model and load the weights
    model = MobileNetV2(num_classes=4)
    model_weight_path = "./MobileNetV2.pth"
    model.load_state_dict(torch.load(model_weight_path))
    model.cuda()
    model.eval()  # switch to inference mode before exporting
    # Export to ONNX
    torch.onnx.export(model, input, 'mobileNetV2.onnx', input_names=input_name, output_names=output_name, verbose=True)
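
After the export it is worth sanity-checking the generated file. A minimal sketch using the onnx package (the file name follows the export call above):

    import onnx

    # Load the exported model and run ONNX's structural checker
    onnx_model = onnx.load('mobileNetV2.onnx')
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))  # print the network graph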

 

3. Converting the ONNX file into an engine

 

(1) Download the TensorRT library.

 

(2) Go to ~/samples/trtexec (where ~ is the TensorRT root directory) and run

make 

Two new files, trtexec and trtexec_debug, will appear under ~/bin.

 

(3) From the TensorRT root directory, run

    # Half-precision (FP16) export
    ./bin/trtexec --onnx=~.onnx --fp16 --saveEngine=~.engine
    # Full-precision (FP32) export
    ./bin/trtexec --onnx=~.onnx --saveEngine=~.engine

Here './bin/trtexec' is the path to the trtexec binary built in the previous step, ~.onnx is the path to the ONNX file, and ~.engine is the path where the engine will be written.

 

4. Measuring the engine's inference time with TensorRT's bundled trtexec

trtexec --loadEngine=~.engine --exportOutput=~.trt

Here ~.engine is the path to the engine file and ~.trt is the output file path. (Measured on a GTX 1660: ResNet34 inference takes 6.66 ms in PyTorch, 2.5 ms with a TensorRT FP32 engine, and 1.28 ms with a TensorRT FP16 engine.)
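
For reference, the PyTorch baseline figure above can be reproduced with a simple timing loop. A minimal sketch (the torchvision ResNet34, the 224x224 input size and the iteration counts are assumptions; torch.cuda.synchronize() is added so the GPU work is fully counted):

    import time
    import torch
    import torchvision

    # Assumed baseline: torchvision's ResNet34 with a random 224x224 input
    model = torchvision.models.resnet34().cuda().eval()
    img = torch.randn(1, 3, 224, 224).cuda()

    with torch.no_grad():
        for _ in range(10):        # warm-up iterations
            model(img)
        torch.cuda.synchronize()
        t1 = time.time()
        for _ in range(100):       # timed iterations
            model(img)
        torch.cuda.synchronize()
        t2 = time.time()

    print('PyTorch inference time: {:.2f} ms'.format((t2 - t1) * 1000 / 100))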

 

 

5. Running inference with the .engine file

The code is as follows:

    import torchvision
    import torch
    from torch.autograd import Variable
    import onnx
    from model import MobileNetV2
    from torchvision import transforms
    import pycuda.autoinit
    import numpy as np
    import pycuda.driver as cuda
    import tensorrt as trt
    import os
    import time
    from PIL import Image
    import cv2

    # Data transforms
    data_transform = transforms.Compose(
        [transforms.Resize(40),
         transforms.CenterCrop(32),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Create the model and load the weights
    model = MobileNetV2(num_classes=4)
    model_weight_path = "./MobileNetV2.pth"
    model.load_state_dict(torch.load(model_weight_path))
    model.cuda()

    filename = './0.jpg'
    # Optional PyTorch baseline timing on the same image
    # img = Image.open(filename)
    # img = data_transform(img)
    # img = img.unsqueeze(0).cuda()
    # t1 = time.time()
    # for i in range(100):
    #     model(img)
    # t2 = time.time()
    # print('inference time : ', (t2 - t1) * 1000 / 100, 'ms')

    max_batch_size = 1
    onnx_model_path = 'mobilenetV2.onnx'
    TRT_LOGGER = trt.Logger()  # This logger is required to build an engine

    def get_img_np_nchw(filename):
        img = Image.open(filename)
        img = data_transform(img)
        img = torch.unsqueeze(img, dim=0)
        img = img.numpy()
        return img

    class HostDeviceMem(object):
        def __init__(self, host_mem, device_mem):
            """Within this context, host means the CPU memory and device means the GPU memory."""
            self.host = host_mem
            self.device = device_mem

        def __str__(self):
            return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

        def __repr__(self):
            return self.__str__()

    def allocate_buffers(engine):
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in engine:
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate host and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

    def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="",
                   fp16_mode=False, int8_mode=False, save_engine=False):
        """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
        def build_engine(max_batch_size, save_engine):
            """Takes an ONNX file and creates a TensorRT engine to run inference with."""
            with trt.Builder(TRT_LOGGER) as builder, \
                    builder.create_network() as network, \
                    trt.OnnxParser(network, TRT_LOGGER) as parser:
                builder.max_workspace_size = 1 << 30  # Your workspace size
                builder.max_batch_size = max_batch_size
                builder.fp16_mode = fp16_mode  # Default: False
                builder.int8_mode = int8_mode  # Default: False
                if int8_mode:
                    # To be updated
                    raise NotImplementedError
                # Parse model file
                if not os.path.exists(onnx_file_path):
                    quit('ONNX file {} not found'.format(onnx_file_path))
                print('Loading ONNX file from path {}...'.format(onnx_file_path))
                with open(onnx_file_path, 'rb') as model:
                    print('Beginning ONNX file parsing')
                    parser.parse(model.read())
                print('Completed parsing of ONNX file')
                print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
                engine = builder.build_cuda_engine(network)
                print("Completed creating Engine")
                if save_engine:
                    with open(engine_file_path, "wb") as f:
                        f.write(engine.serialize())
                return engine

        if os.path.exists(engine_file_path):
            # If a serialized engine exists, load it instead of building a new one.
            print("Reading engine from file {}".format(engine_file_path))
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                return runtime.deserialize_cuda_engine(f.read())
        else:
            return build_engine(max_batch_size, save_engine)

    def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
        # Transfer data from CPU to the GPU.
        [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
        # Run inference.
        context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
        # Transfer predictions back from the GPU.
        [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
        # Synchronize the stream
        stream.synchronize()
        # Return only the host outputs.
        return [out.host for out in outputs]

    def postprocess_the_outputs(h_outputs, shape_of_output):
        h_outputs = h_outputs.reshape(*shape_of_output)
        return h_outputs

    # Image preprocessing
    img_np_nchw = get_img_np_nchw(filename)
    img_np_nchw = img_np_nchw.astype(dtype=np.float32)

    # FP16 settings
    fp16_mode = True
    int8_mode = False
    engine_file_path = './resnet34-32.engine'  # loaded directly if it already exists, otherwise built from the ONNX file

    # Create (or load) the engine
    engine = get_engine(max_batch_size, onnx_model_path, engine_file_path, fp16_mode, int8_mode)
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = allocate_buffers(engine)  # input, output: host buffers; bindings: device pointers
    shape_of_output = (max_batch_size, 4)
    inputs[0].host = img_np_nchw.reshape(-1)
    t1 = time.time()
    for i in range(100):
        trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)  # numpy data
    t2 = time.time()
    feat = postprocess_the_outputs(trt_outputs[0], shape_of_output)
    print('feat 16:', feat)
    print("Inference time with the TensorRT FP16 engine: {} ms".format((t2 - t1) * 1000 / 100))
