当前位置:   article > 正文

ONNX Runtime介绍_onnxruntime

onnxruntime

      ONNX Runtime:由微软推出,用于优化和加速机器学习推理和训练,适用于ONNX模型,是一个跨平台推理和训练机器学习加速器(ONNX Runtime is a cross-platform inference and training machine-learning accelerator),源码地址:https://github.com/microsoft/onnxruntime,最新发布版本为v1.11.1,License为MIT:

      1.ONNX Runtime Inferencing:高性能推理引擎

      (1).可在不同的操作系统上运行,包括Windows、Linux、Mac、Android、iOS等;

      (2).可利用硬件增加性能,包括CUDA、TensorRT、DirectML、OpenVINO等;

      (3).支持PyTorch、TensorFlow等深度学习框架的模型,需先调用相应接口转换为ONNX模型;

      (4).在Python中训练,却可部署到C++/Java等应用程序中。

      2.ONNX Runtime Training:于2021年4月发布,可加快PyTorch模型的训练,可通过CUDA加速,目前多用于Linux平台。

      通过conda命令安装,执行:

conda install -c conda-forge onnxruntime

      以下为测试代码:通过ResNet-50对图像进行分类

  1. import numpy as np
  2. import onnxruntime
  3. import onnx
  4. from onnx import numpy_helper
  5. import urllib.request
  6. import os
  7. import tarfile
  8. import json
  9. import cv2
  10. # reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
  11. def download_onnx_model():
  12. labels_file_name = "imagenet-simple-labels.json"
  13. model_tar_name = "resnet50v2.tar.gz"
  14. model_directory_name = "resnet50v2"
  15. if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):
  16. print("files exist, don't need to download")
  17. else:
  18. print("files don't exist, need to download ...")
  19. onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"
  20. imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
  21. # retrieve our model from the ONNX model zoo
  22. urllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)
  23. urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)
  24. print("download completed, start decompress ...")
  25. file = tarfile.open(model_tar_name)
  26. file.extractall("./")
  27. file.close()
  28. return model_directory_name, labels_file_name
  29. def load_labels(path):
  30. with open(path) as f:
  31. data = json.load(f)
  32. return np.asarray(data)
  33. def images_preprocess(images_path, images_name):
  34. input_data = []
  35. for name in images_name:
  36. img = cv2.imread(images_path + name)
  37. img = cv2.resize(img, (224, 224))
  38. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  39. data = np.array(img).transpose(2, 0, 1)
  40. #print(f"name: {name}, opencv image shape(h,w,c): {img.shape}, transpose shape(c,h,w): {data.shape}")
  41. # convert the input data into the float32 input
  42. data = data.astype('float32')
  43. # normalize
  44. mean_vec = np.array([0.485, 0.456, 0.406])
  45. stddev_vec = np.array([0.229, 0.224, 0.225])
  46. norm_data = np.zeros(data.shape).astype('float32')
  47. for i in range(data.shape[0]):
  48. norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
  49. # add batch channel
  50. norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
  51. input_data.append(norm_data)
  52. return input_data
  53. def softmax(x):
  54. x = x.reshape(-1)
  55. e_x = np.exp(x - np.max(x))
  56. return e_x / e_x.sum(axis=0)
  57. def postprocess(result):
  58. return softmax(np.array(result)).tolist()
  59. def inference(onnx_model, labels, input_data, images_name, images_label):
  60. session = onnxruntime.InferenceSession(onnx_model, None)
  61. # get the name of the first input of the model
  62. input_name = session.get_inputs()[0].name
  63. count = 0
  64. for data in input_data:
  65. print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")
  66. count += 1
  67. raw_result = session.run([], {input_name: data})
  68. res = postprocess(raw_result)
  69. idx = np.argmax(res)
  70. print(f" result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")
  71. sort_idx = np.flip(np.squeeze(np.argsort(res)))
  72. print(" top 5 labels are:", labels[sort_idx[:5]])
  73. def main():
  74. model_directory_name, labels_file_name = download_onnx_model()
  75. labels = load_labels(labels_file_name)
  76. print("the number of categories is:", len(labels)) # 1000
  77. images_path = "../../data/image/"
  78. images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
  79. images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]
  80. if len(images_name) != len(images_label):
  81. print("Error: images count and labes'length don't match")
  82. return
  83. input_data = images_preprocess(images_path, images_name)
  84. onnx_model = model_directory_name + "/resnet50v2.onnx"
  85. inference(onnx_model, labels, input_data, images_name, images_label)
  86. print("test finish")
  87. if __name__ == "__main__":
  88. main()

      测试图像如下所示:

      执行结果如下所示:

 

      GitHub: https://github.com/fengbingchun/PyTorch_Test

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/850027
推荐阅读
相关标签
  

闽ICP备14008679号