当前位置:   article > 正文

【Pytorch神经网络实战案例】27 MaskR-CNN内置模型实现语义分割_pytorch 预训练,图像语义分割代码

pytorch 预训练,图像语义分割代码

1 PyTorch中语义分割的内置模型

torchvision库下的models\segmentation目录中,找到segmentation.Py文件。该文件中存放着PyTorch内置的语义分割模型。

2 MaskR-CNN内置模型实现语义分割

2.1 代码逻辑简述

将COCO 2017数据集上的预训练模型dceplabv3_resnet101_coco加载到内存,并使用该模型对图片进行语义分割。

2.2 代码实现:MaskR-CNN内置模型实现语义分割

Maskrcnn_resent_Semantic_Segmentation.py

  1. import torch
  2. import matplotlib.pyplot as plt
  3. from PIL import Image
  4. import numpy as np
  5. from torchvision import models
  6. from torchvision import transforms
  7. import os
  8. os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
  9. # 获取模型,如果本地没有缓存,则下载
  10. model = models.segmentation.deeplabv3_resnet101(pretrained=True) # 调用内置模型,并使用预训练权重进行初始化。
  11. model.eval() # 不然报错 Expected more than 1 value per channel when training, got input size torch.Size
  12. # 在图片的数据输入网络之前,对图片进行预处理
  13. transform = transforms.Compose([
  14. transforms.Resize(256), # 将图片尺寸调整为256×256
  15. transforms.CenterCrop(224), # 中心裁剪成224×224
  16. transforms.ToTensor(), # 转换成张量归一化到[0,1]
  17. transforms.Normalize( # 使用均值,方差标准化
  18. mean=[0.485, 0.456, 0.406],
  19. std=[0.229, 0.224, 0.225]
  20. )
  21. ])
  22. def preimg(img): # 定义图片预处理函数
  23. if img.mode == 'RGBA': # 兼容RGBA图片
  24. ch = 4
  25. print('ch', ch)
  26. a = np.asarray(img)[:, :, :3]
  27. img = Image.fromarray(a)
  28. return img
  29. # 加载要预测的图片
  30. img = Image.open('./models_2/mask.jpg') # 将图片输入模型,进行预测。
  31. # 模型预测的输出是一个OrderedDict结构。deeplabv3_resnet101模型的图片输入尺寸是[224,224],输出形状是[1,21,224,224],代表20+1(背景)个类别。
  32. plt.imshow(img)
  33. plt.axis('off')
  34. plt.show() # 显示加载图片
  35. im = preimg(img)
  36. # 对输入数据进行维度扩展,成为NCHW
  37. inputimg = transform(im).unsqueeze(0)
  38. # 显示用transform转化后的图片
  39. tt = np.transpose(inputimg.detach().numpy()[0],(1,2,0))
  40. plt.imshow(tt.astype('uint8')) # 不然报错:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers)
  41. plt.show()
  42. output = model(inputimg) # 将图片输入模型
  43. print("输出结果的形状:",output['out'].shape)
  44. # 去掉批次维度,提取结果。使用argmax函数在每个像素点的21个分类中选出概率值最大的索引,为预测结果。
  45. output = torch.argmax(output['out'].squeeze(), dim=0).detach().cpu().numpy()
  46. resultclass = set(list(output.flat))
  47. print("所发现的分类:",resultclass)
  48. # 所发现的分类.{0,13,15}
  49. # 模型从图中识别出了两个类别的内容。索引值13和15分别对应分类名称“马”和“人”。
  50. def decode_segmap(image,nc=21): # 对图片中的每个像素点根据其所属类别进行染色。不同的类别显示不同的颜色。
  51. label_colors = np.array([(0, 0, 0), # 定义每个分类对应的颜色
  52. (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
  53. (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
  54. (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
  55. (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
  56. r = np.zeros_like(image).astype(np.uint8) # 初始化RGB
  57. g = np.zeros_like(image).astype(np.uint8)
  58. b = np.zeros_like(image).astype(np.uint8)
  59. for l in range(0, nc): # 根据预测结果进行染色
  60. idx = image == l
  61. print("idx:",idx)
  62. r[idx] = label_colors[l, 0]
  63. g[idx] = label_colors[l, 1]
  64. b[idx] = label_colors[l, 2]
  65. return np.stack([r, g, b], axis=2) # 返回结果
  66. rgb = decode_segmap(output)
  67. img = Image.fromarray(rgb)
  68. plt.axis('off') # 显示模型的可视化结果
  69. print("快完了")
  70. plt.imshow(img)
  71. plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/234578
推荐阅读
相关标签
  

闽ICP备14008679号