
UNet-RKNN Fundus Vessel Segmentation: UNet PyTorch / RKNN Inference


Preface

I recently came across a fairly fun UNet segmentation project. UNet was designed for medical image segmentation (cells, blood vessels, and the like), and here it is used to segment fundus (retinal) blood vessels. The backbone is VGG16, and the structure is shown in the figure below (the figure is taken from the project; the source is credited out of respect for the original author's intellectual property). The model is fairly small, yet the results look quite good.

I won't go over the algorithm itself here. Instead, let's look at the results from three angles: PyTorch, ONNX, and RKNN.

Full code: https://pan.baidu.com/s/1QkOz5tvRSF-UkJhmpI__lA (extraction code: 8twv)

Original test image

1. PyTorch Inference Code

The gpu_test folder:

├── predict.py: inference script
├── test_result_cuda.png: detection result
├── save_weights: model weights folder
├── images: test image folder
├── src: model source code folder
└── mask: mask image folder

import os
import time

import torch
from torchvision import transforms
import numpy as np
from PIL import Image

from src import UNet


def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()


def main():
    classes = 1  # exclude background
    # path to the trained weights
    weights_path = "./save_weights/best_model.pth"
    # path to the test image
    img_path = "./images/01_test.tif"
    # path to the ROI mask image
    roi_mask_path = "./mask/01_test_mask.gif"
    assert os.path.exists(weights_path), f"weights {weights_path} not found."
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # to force CPU inference instead:
    # device = "cpu"
    print("using {} device.".format(device))

    # create model
    model = UNet(in_channels=3, num_classes=classes + 1, base_c=32)

    # load weights
    model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
    model.to(device)
    # dummy_input = torch.randn(1, 3, 584, 565)
    # torch.onnx.export(model, dummy_input, 'eyes_unet.onnx', verbose=True, opset_version=11)

    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)

    # load image
    original_img = Image.open(img_path).convert('RGB')

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=mean, std=std)])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # init model (warm-up pass so the timed run is not the first one)
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        output = model(img.to(device))
        print(output["out"].shape)
        t_end = time_synchronized()
        print("inference time: {}".format(t_end - t_start))

        prediction = output['out'].argmax(1).squeeze(0)
        prediction = prediction.to("cpu").numpy().astype(np.uint8)
        # np.save("cuda_unet.npy", prediction)
        print(prediction.shape)
        # set foreground pixels to 255 (white)
        prediction[prediction == 1] = 255
        # set pixels outside the region of interest to 0 (black)
        prediction[roi_img == 0] = 0
        mask = Image.fromarray(prediction)
        mask.save("test_result_cuda.png")


if __name__ == '__main__':
    main()

Detection result
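The eyes_unet-sim.onnx file used in the next section is not produced by this script as shown (the export lines are commented out above). As a reference only, here is a minimal sketch, assuming the onnx and onnx-simplifier packages are installed, of how the simplified ONNX model could be exported; the exact settings used in the project may differ.

import torch
import onnx
from onnxsim import simplify
from src import UNet

# rebuild the network and load the trained weights (same as in predict.py)
model = UNet(in_channels=3, num_classes=2, base_c=32)
model.load_state_dict(torch.load("./save_weights/best_model.pth", map_location='cpu')['model'])
model.eval()

# export with a fixed input size matching the test image (584 x 565)
dummy_input = torch.randn(1, 3, 584, 565)
torch.onnx.export(model, dummy_input, 'eyes_unet.onnx', verbose=False, opset_version=11)

# fold constants and remove redundant nodes, then save the simplified graph
model_sim, ok = simplify(onnx.load('eyes_unet.onnx'))
assert ok, "onnx-simplifier check failed"
onnx.save(model_sim, 'eyes_unet-sim.onnx')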

2. ONNX Inference Code

The onnx_test folder:

├── images: test image folder
├── test_result_onnx.png: detection result
├── predict_onnx.py: inference script
├── mask: mask image folder
└── eyes_unet-sim.onnx: model file

import os
import time

from torchvision import transforms
import numpy as np
from PIL import Image
import onnxruntime as rt


def main():
    # classes = 1  # exclude background
    img_path = "./images/01_test.tif"
    roi_mask_path = "./mask/01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    mean = (0.709, 0.381, 0.224)
    std = (0.127, 0.079, 0.043)

    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)

    # load image
    original_img = Image.open(img_path).convert('RGB')

    # from pil image to tensor and normalize
    data_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize(mean=mean, std=std)])
    img = data_transform(original_img)
    # expand batch dimension
    img = img.numpy()
    img = img[np.newaxis, :]

    t_start = time.time()
    sess = rt.InferenceSession('./eyes_unet-sim.onnx')
    # input and output node names of the model; they can be inspected with netron
    input_name = 'input.1'
    outputs_name = ['437']
    # run inference: output node names first, then a dict mapping the input node
    # name to the input data (mind the exact node-name strings!)
    output = sess.run(outputs_name, {input_name: img})
    output = np.array(output).reshape(1, 2, 584, 565)
    t_end = time.time()
    # note: this timing also includes session creation
    print("inference time: {}".format(t_end - t_start))

    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)
    prediction = prediction.astype(np.uint8)
    # set foreground pixels to 255 (white)
    prediction[prediction == 1] = 255
    # set pixels outside the region of interest to 0 (black)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_onnx.png")


if __name__ == '__main__':
    main()

Detection result
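A small aside: instead of hard-coding 'input.1' and '437' after opening the model in netron, the node names can also be read directly from the session with the standard onnxruntime API, for example:

import onnxruntime as rt

sess = rt.InferenceSession('./eyes_unet-sim.onnx')
input_name = sess.get_inputs()[0].name    # 'input.1' for this model
output_name = sess.get_outputs()[0].name  # '437' for this model
print(input_name, output_name)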

3. RKNN Model Conversion

The rknn_trans_1808_3588 folder:

├── dataset.txt: list of image paths used for quantization
├── images: quantization dataset
├── trans_1808.py: conversion script for the RK1808 RKNN model
├── trans_3588.py: conversion script for the RK3588 RKNN model
├── mask: not used
└── eyes_unet-sim.onnx: original ONNX model
There isn't much to say here: once the toolkit environment is set up, just run the corresponding script in that environment and the model converts straight through. Most readers will already know how to do this (if not, bookmark the post and leave a comment, and I'll see about writing a tutorial).
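For reference, here is a minimal sketch of what trans_3588.py might look like with rknn-toolkit2. The mean/std values below are my assumption: the PyTorch normalization constants scaled to the 0-255 range the toolkit works with. Check the actual script in the project for the exact configuration.

from rknn.api import RKNN

rknn = RKNN(verbose=True)

# mean=(0.709, 0.381, 0.224), std=(0.127, 0.079, 0.043) on a 0-1 scale,
# multiplied by 255 so the toolkit can normalize the raw 0-255 input itself
rknn.config(mean_values=[[180.8, 97.2, 57.1]],
            std_values=[[32.4, 20.1, 11.0]],
            target_platform='rk3588')

rknn.load_onnx(model='./eyes_unet-sim.onnx')
# int8 quantization using the images listed in dataset.txt
rknn.build(do_quantization=True, dataset='./dataset.txt')
rknn.export_rknn('./eyes_unet-sim-3588.rknn')
rknn.release()

trans_1808.py would follow the same flow with the older rknn-toolkit (1.x), whose config arguments differ slightly; check the documentation for the toolkit version you have installed.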

4. RKNN Model Inference

4.1 The rk1808_test folder

├── 01_test_mask.gif: mask image
├── eyes_unet-sim-1808.rknn: model for the RK1808
├── predict_rknn_1808.py: inference script
├── test_result_1808.png: detection result
└── 01_test.tif: test image

import os
import time

import numpy as np
from PIL import Image
from rknn.api import RKNN


def main():
    # classes = 1  # exclude background
    RKNN_MODEL = "./eyes_unet-sim-1808.rknn"
    img_path = "./01_test.tif"
    roi_mask_path = "./01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)

    # load image
    original_img = Image.open(img_path).convert('RGB')
    # from pil image to tensor and normalize
    # data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    # img = data_transform(original_img)
    # expand batch dimension (the raw RGB image is fed directly, without torchvision normalization)
    img = np.array(original_img)
    img = img[np.newaxis, :]

    # Create RKNN object
    rknn = RKNN(verbose=False)
    ret = rknn.load_rknn(RKNN_MODEL)

    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime(target='rk1808')
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    t_start = time.time()
    output = rknn.inference(inputs=[img])
    t_end = time.time()
    print("inference time: {}".format(t_end - t_start))

    output = np.array(output).reshape(1, 2, 584, 565)
    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)
    prediction = prediction.astype(np.uint8)
    # set foreground pixels to 255 (white)
    prediction[prediction == 1] = 255
    # set pixels outside the region of interest to 0 (black)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_1808.png")
    rknn.release()


if __name__ == '__main__':
    main()

Detection result

4.2 The rk3588_test folder

├── 01_test_mask.gif: mask image
├── test_result_3588.png: detection result
├── eyes_unet-sim-3588.rknn: model for the RK3588
├── 01_test.tif: test image
└── predict_3588.py: inference script

import os
import time

import numpy as np
from PIL import Image
from rknnlite.api import RKNNLite


def main():
    # classes = 1  # exclude background
    RKNN_MODEL = "./eyes_unet-sim-3588.rknn"
    img_path = "./01_test.tif"
    roi_mask_path = "./01_test_mask.gif"
    assert os.path.exists(img_path), f"image {img_path} not found."
    assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."

    # load roi mask
    roi_img = Image.open(roi_mask_path).convert('L')
    roi_img = np.array(roi_img)

    # load image
    original_img = Image.open(img_path).convert('RGB')
    # from pil image to tensor and normalize
    # data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    # img = data_transform(original_img)
    # expand batch dimension (the raw RGB image is fed directly, without torchvision normalization)
    img = np.array(original_img)
    img = img[np.newaxis, :]

    # Create RKNN object
    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(RKNN_MODEL)

    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    t_start = time.time()
    output = rknn_lite.inference(inputs=[img])
    t_end = time.time()
    print("inference time: {}".format(t_end - t_start))

    output = np.array(output).reshape(1, 2, 584, 565)
    prediction = np.squeeze(np.argmax(output, axis=1))
    print(prediction.shape)
    prediction = prediction.astype(np.uint8)
    np.save("int8_unet.npy", prediction)
    # set foreground pixels to 255 (white)
    prediction[prediction == 1] = 255
    # set pixels outside the region of interest to 0 (black)
    prediction[roi_img == 0] = 0
    mask = Image.fromarray(prediction)
    mask.save("test_result_3588.png")
    rknn_lite.release()


if __name__ == '__main__':
    main()

Detection result
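One note on core_mask: the RK3588 NPU has three cores, and RKNNLite.NPU_CORE_AUTO lets the runtime pick an idle one. If you want to pin the model to a specific core or let a single inference use all three, rknn-toolkit-lite2 also provides fixed core masks (treat the exact constant names as an assumption to verify against your installed version), e.g.:

# pin to NPU core 0 only
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
# or spread a single inference across all three cores
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)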

5. Comparison of All Results

Original image

GPU/ONNX

RK1808

RK3588

Comparing the outputs, the quantization actually holds up well: accuracy is around 99.5%, which is pretty good!
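The ~99.5% figure can be reproduced by comparing the raw class maps saved by the scripts above: uncomment the np.save("cuda_unet.npy", prediction) line in predict.py, while predict_3588.py already saves int8_unet.npy. A quick sketch of a pixel-wise agreement check (my own illustration, not part of the project code):

import numpy as np

gpu = np.load("cuda_unet.npy")    # float (GPU) prediction, values in {0, 1}
int8 = np.load("int8_unet.npy")   # int8 quantized RK3588 prediction
agreement = (gpu == int8).mean()  # fraction of pixels where both models agree
print("pixel agreement: {:.2%}".format(agreement))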
