当前位置:   article > 正文

基于华为atlas的unet分割模型探索

基于华为atlas的unet分割模型探索

Unet模型使用官方提供的、基于kaggle Carvana Image Masking Challenge数据集训练的预训练模型。

模型输入为1*3*572*572(NCHW),输出为1*2*572*572。分割目标分别为,0:背景,1:汽车。

Pytorch的pth模型转化onnx模型:

  1. import torch
  2. from unet import UNet
  3. model = UNet(n_channels=3, n_classes=2, bilinear=False)
  4. model = model.to(memory_format=torch.channels_last)
  5. state_dict = torch.load("unet_carvana_scale1.0_epoch2.pth", map_location="cpu")
  6. #del state_dict['mask_values']
  7. model.load_state_dict(state_dict)
  8. dummy_input = torch.randn(1, 3, 572, 572)
  9. torch.onnx.export(model, dummy_input, "unet.onnx", verbose=True)

模型输入输出节点分析:

使用工具Netron查看模型结构,确定模型输入节点名称为input.1,输出节点名称为/outc/conv/Conv

onnx模型转化atlas模型:

atc --model=./unet.onnx --framework=5 --output=unet --soc_version=Ascend310P3  --input_shape="input.1:1,3,572,572" --output_type="/outc/conv/Conv:0:FP32" --out_nodes="/outc/conv/Conv:0"

推理代码实现:

  1. import base64
  2. import json
  3. import os
  4. import time
  5. import numpy as np
  6. import cv2
  7. import MxpiDataType_pb2 as mxpi_data
  8. from StreamManagerApi import InProtobufVector
  9. from StreamManagerApi import MxProtobufIn
  10. from StreamManagerApi import StreamManagerApi
  11. def check_dir(dir):
  12. if not os.path.exists(dir):
  13. os.makedirs(dir, exist_ok=True)
  14. class SDKInferWrapper:
  15. def __init__(self): # 完成初始化
  16. self._stream_name = None
  17. self._stream_mgr_api = StreamManagerApi()
  18. if self._stream_mgr_api.InitManager() != 0:
  19. raise RuntimeError("Failed to init stream manager.")
  20. pipeline_name = './nested_unet.pipeline'
  21. self.load_pipeline(pipeline_name)
  22. self.width = 572
  23. self.height = 572
  24. def load_pipeline(self, pipeline_path):
  25. with open(pipeline_path, 'r') as f:
  26. pipeline = json.load(f)
  27. self._stream_name = list(pipeline.keys())[0].encode() # 'unet_pytorch'
  28. if self._stream_mgr_api.CreateMultipleStreams(
  29. json.dumps(pipeline).encode()) != 0:
  30. raise RuntimeError("Failed to create stream.")
  31. def do_infer(self, img_bgr):
  32. # preprocess
  33. image = cv2.resize(img_bgr, (self.width, self.height))
  34. image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
  35. image = image.astype('float32') / 255.0
  36. image = image.transpose(2, 0, 1)
  37. tensor_pkg_list = mxpi_data.MxpiTensorPackageList()
  38. tensor_pkg = tensor_pkg_list.tensorPackageVec.add()
  39. tensor_vec = tensor_pkg.tensorVec.add()
  40. tensor_vec.deviceId = 0
  41. tensor_vec.memType = 0
  42. for dim in [1, *image.shape]:
  43. tensor_vec.tensorShape.append(dim) # tensorshape属性为[1,3,572,572]
  44. input_data = image.tobytes()
  45. tensor_vec.dataStr = input_data
  46. tensor_vec.tensorDataSize = len(input_data)
  47. protobuf_vec = InProtobufVector()
  48. protobuf = MxProtobufIn()
  49. protobuf.key = b'appsrc0'
  50. protobuf.type = b'MxTools.MxpiTensorPackageList'
  51. protobuf.protobuf = tensor_pkg_list.SerializeToString()
  52. protobuf_vec.push_back(protobuf)
  53. unique_id = self._stream_mgr_api.SendProtobuf(
  54. self._stream_name, 0, protobuf_vec)
  55. if unique_id < 0:
  56. raise RuntimeError("Failed to send data to stream.")
  57. infer_result = self._stream_mgr_api.GetResult(
  58. self._stream_name, unique_id)
  59. if infer_result.errorCode != 0:
  60. raise RuntimeError(
  61. f"GetResult error. errorCode={infer_result.errorCode}, "
  62. f"errorMsg={infer_result.data.decode()}")
  63. output_tensor = self._parse_output_data(infer_result)
  64. output_tensor = np.squeeze(output_tensor)
  65. output_tensor = softmax(output_tensor)
  66. mask = np.argmax(output_tensor, axis =0)
  67. score = np.max(output_tensor, axis = 0)
  68. mask = cv2.resize(mask, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)
  69. score = cv2.resize(score, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)
  70. return mask, score
  71. def _parse_output_data(self, output_data):
  72. infer_result_data = json.loads(output_data.data.decode())
  73. content = json.loads(infer_result_data['metaData'][0]['content'])
  74. tensor_vec = content['tensorPackageVec'][0]['tensorVec'][0]
  75. data_str = tensor_vec['dataStr']
  76. tensor_shape = tensor_vec['tensorShape']
  77. infer_array = np.frombuffer(base64.b64decode(data_str), dtype=np.float32)
  78. return infer_array.reshape(tensor_shape)
  79. def draw(self, mask):
  80. color_lists = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
  81. drawed_img = np.stack([mask, mask, mask], axis = 2)
  82. for i in np.unique(mask):
  83. drawed_img[:,:,0][drawed_img[:,:,0]==i] = color_lists[i][0]
  84. drawed_img[:,:,1][drawed_img[:,:,1]==i] = color_lists[i][1]
  85. drawed_img[:,:,2][drawed_img[:,:,2]==i] = color_lists[i][2]
  86. return drawed_img
  87. def softmax(x):
  88. exps = np.exp(x - np.max(x))
  89. return exps/np.sum(exps)
  90. def sigmoid(x):
  91. y = x.copy()
  92. y[x >= 0] = 1.0 / (1 + np.exp(-x[x >= 0]))
  93. y[x < 0] = np.exp(x[x < 0]) / (1 + np.exp(x[x < 0]))
  94. return y
  95. def check_dir(dir):
  96. if not os.path.exists(dir):
  97. os.makedirs(dir, exist_ok=True)
  98. def test():
  99. dataset_dir = './sample_data'
  100. output_folder = "./infer_result"
  101. os.makedirs(output_folder, exist_ok=True)
  102. sdk_infer = SDKInferWrapper()
  103. # read img
  104. image_name = "./sample_data/images/111.jpg"
  105. img_bgr = cv2.imread(image_name)
  106. # infer
  107. t1 = time.time()
  108. mask, score = sdk_infer.do_infer(img_bgr)
  109. t2 = time.time()
  110. print(t2-t1, mask, score)
  111. drawed_img = sdk_infer.draw(mask)
  112. cv2.imwrite("infer_result/draw.png", drawed_img)
  113. if __name__ == "__main__":
  114. test()

运行代码:

# Launcher for the UNet MindX SDK inference demo: sets up the Ascend/MindX
# environment, then runs the Python inference script.
set -e
# Load the Ascend toolkit environment variables.
. /usr/local/Ascend/ascend-toolkit/set_env.sh
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
# MX_SDK_HOME must point at the mxVision install; uncomment if not already
# set in the surrounding environment (the exports below depend on it).
#export MX_SDK_HOME=/home/work/mxVision
# Shared libraries: MindX SDK, its bundled open-source deps, and ACL.
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
# GStreamer plugin discovery for the SDK pipeline elements.
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3 unet.py
exit 0

运行效果:

个人思考:

华为atlas的参考案例细节不到位,步骤缺失较多,摸索困难,代码写法较差,信创化道路任重而道远。

参考资料:

GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

https://gitee.com/ascend/samples/tree/master/python/level2_simple_inference/3_segmentation/unet++

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/196919
推荐阅读
相关标签
  

闽ICP备14008679号