当前位置:   article > 正文

手把手实战教学,语义分割从0到1:三、推理_批量推理

批量推理

本篇博客,是《手把手实战教学!语义分割从0到1》系列的第三篇实战教学,将重点介绍语义分割推理部分。

本系列总的介绍,以及其他章节的汇总,见:https://blog.csdn.net/oYeZhou/article/details/115123589。

目录

1、工程上进行推理的要素

2、构造推理工程

2.1、保留必要的内容

2.2、编写推理class

2.3、单图推理

2.4、批量推理

3、写在后面


1、工程上进行推理的要素

不管是语义分割还是目标检测或者其他的任务,我们要想在工程上进行模型推理,一般都是这么个流程(这里不做量化加速):

  • (1)抽离模型定义及权重文件;
  • (2)编写推理class;
  • (3)提供单张推理及数据集推理的方法;

因此,我们需要在推理的工程中,准备这样几种东西:模型定义脚本、权重文件、推理class、单张测试脚本、批量测试(评价)脚本

2、构造推理工程

2.1、保留必要的内容

在推理工程中,我们仅保留推理所必须的部分,其他的代码、文件统统去掉。

按照我们上篇博客的介绍,我们已经使用自己的VOC格式的语义分割数据集训练好了一个DeepLabv3+模型。此时,我们把其中对推理有必要的部分抽离出来,包括:

上图所示的部分是模型的定义。

然后,把训练好的模型单独放在一个文件夹中:

2.2、编写推理class

接下来,就是编写推理class了。在该class中,需要包含模型定义、模型权重加载、单张推理、结果可视化等内容。这里贴出我所编写的class内容:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on 2020.01.12
  4. @author: LWS
  5. Inference of Edge Fence Segmentation.
  6. """
  7. import numpy as np
  8. from PIL import Image
  9. from collections import OrderedDict
  10. import torch
  11. import torch.nn.functional as F
  12. from torchvision import transforms
  13. from models.deeplabv3_plus import DeepLab
  14. from utils.helpers import colorize_mask
  15. from utils import palette
  16. class EdgeFenceSeg(object):
  17. def __init__(self,
  18. model_path="ckpt/best_model.pth",
  19. area_thres=0.1,
  20. cuda_id=0,
  21. get_mask=True):
  22. torch.set_num_threads(8)
  23. self.area_thres = area_thres
  24. self.num_classes = 2 # background + fence
  25. self.get_mask = get_mask
  26. # get palette of VOC
  27. self.my_palette = palette.get_voc_palette(self.num_classes)
  28. # data setting
  29. self._MEAN = [0.48311856, 0.49071315, 0.45774156]
  30. self._STD = [0.21628413, 0.22036915, 0.22477823]
  31. self.to_tensor = transforms.ToTensor()
  32. self.normalize = transforms.Normalize(self._MEAN, self._STD)
  33. # get Model
  34. self.model = DeepLab(num_classes=self.num_classes, backbone='resnet101', pretrained=False)
  35. availble_gpus = list(range(torch.cuda.device_count()))
  36. self.device = torch.device('cuda:{}'.format(cuda_id) if len(availble_gpus) > 0 else 'cpu')
  37. # Load checkpoint
  38. checkpoint = torch.load(model_path, map_location=self.device)
  39. if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
  40. checkpoint = checkpoint['state_dict']
  41. # If during training, we used data parallel
  42. if ('module' in list(checkpoint.keys())[0] and
  43. not isinstance(self.model, torch.nn.DataParallel)):
  44. # for gpu inference, use data parallel
  45. if "cuda" in self.device.type:
  46. self.model = torch.nn.DataParallel(self.model)
  47. else:
  48. # for cpu inference, remove module
  49. new_state_dict = OrderedDict()
  50. for k, v in checkpoint.items():
  51. name = k[7:]
  52. new_state_dict[name] = v
  53. checkpoint = new_state_dict
  54. # load
  55. self.model.load_state_dict(checkpoint)
  56. self.model.to(self.device)
  57. self.model.eval()
  58. def predict(self, img):
  59. """
  60. :param img: image for predict, np.ndarray.
  61. :return: mask_img, prediction, flag;
  62. if all None, means image type error; if mask_img is None, means don't extract mask.
  63. """
  64. if str(type(img)) == "<class 'NoneType'>":
  65. return None, None, None
  66. flag = False
  67. if isinstance(img, np.ndarray):
  68. img = Image.fromarray(img)
  69. with torch.no_grad():
  70. input = self.normalize(self.to_tensor(img)).unsqueeze(0)
  71. prediction = self.model(input.to(self.device))
  72. prediction = prediction.squeeze(0).cpu().numpy()
  73. prediction = F.softmax(torch.from_numpy(prediction),
  74. dim=0).argmax(0).cpu().numpy()
  75. area_ratio = sum(prediction[prediction == 1])/(img.size[0]*img.size[1])
  76. if area_ratio >= self.area_thres:
  77. flag = True
  78. if self.get_mask:
  79. mask_img = self.colored_mask_img(img, prediction)
  80. return mask_img, prediction, flag
  81. else:
  82. return None, prediction, flag
  83. def colored_mask_img(self, image, mask):
  84. colorized_mask = colorize_mask(mask, self.my_palette)
  85. # PIL type
  86. mask_img = Image.blend(image.convert('RGBA'), colorized_mask.convert('RGBA'), 0.7)
  87. return mask_img

2.3、单图推理

可以定义一个main.py文件,利用上述class进行单张图片的推理:

  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. from EdgeFenceSeg import EdgeFenceSeg
  6. if __name__ == "__main__":
  7. img_file = "test_imgs/V10108-115508_frame_232.jpg"
  8. output_path = "output_cv"
  9. if not os.path.exists(output_path):
  10. os.makedirs(output_path)
  11. edg = EdgeFenceSeg(area_thres=0.1, cuda_id=0, get_mask=True)
  12. img = cv2.imread(img_file)
  13. for i in range(4):
  14. # inference
  15. t1 = time.time()
  16. mask_img, prediction, flag = edg.predict(img)
  17. t2 = time.time()
  18. print("time: {}, is edge_fence: {}".format(round(t2 - t1, 4), flag))
  19. # save masked img
  20. if mask_img is not None:
  21. image_file = os.path.basename(img_file).split('.')[0]
  22. # mask_img_cv = cv2.cvtColor(np.asarray(mask_img), cv2.COLOR_RGB2BGR)
  23. mask_img_cv = np.asarray(mask_img)
  24. cv2.imwrite(os.path.join(output_path, image_file + '.png'), mask_img_cv)

2.4、批量推理

如果想对一批图片进行推理,由于我们的网络是全卷积的,所以输入图片的尺寸可以是任意的,这里我们使用原图大小以获取更好的性能,因此批量推理也是加个for循环进行的逐张推理:

  1. import os
  2. import time
  3. from glob import glob
  4. import cv2
  5. import numpy as np
  6. from EdgeFenceSeg import EdgeFenceSeg
  7. if __name__ == "__main__":
  8. imgs_path = "test_imgs"
  9. output_path = "output_cv"
  10. if not os.path.exists(output_path):
  11. os.makedirs(output_path)
  12. edg = EdgeFenceSeg()
  13. image_files = sorted(glob(os.path.join(imgs_path, f'*.{"jpg"}')))
  14. for img_file in image_files:
  15. t0 = time.time()
  16. img = cv2.imread(img_file)
  17. # inference
  18. t1 = time.time()
  19. mask_img, prediction, flag = edg.predict(img)
  20. t2 = time.time()
  21. print("{0:50}: Inference time: {1}, Is edge_fence: {2}".format(img_file, round(t2 - t1, 4), flag))
  22. # save masked img
  23. if mask_img is not None:
  24. image_file = os.path.basename(img_file).split('.')[0]
  25. # mask_img_cv = cv2.cvtColor(np.asarray(mask_img), cv2.COLOR_RGB2BGR)
  26. mask_img_cv = np.asarray(mask_img)
  27. cv2.imwrite(os.path.join(output_path, image_file + '.png'), mask_img_cv)

3、写在后面

本篇博客是本系列实战教程的最后一篇了,至此,如果你按照这几篇博客的描述一步步走过来,应该已经掌握了语义分割的入坑基本流程了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/901127
推荐阅读
相关标签
  

闽ICP备14008679号