当前位置:   article > 正文

可视化特征图:python读取pth模型,并可视化特征图。亲测有效。_pth模型可视化

pth模型可视化

一、前言

我们有时候需要可视化特征图,尤其是发paper,或者对比算法等情况。而且通过可视化特征图,也可以让我们对这个整个cnn模型更加熟悉,废话不多说了。

二、效果图

下面我会给出代码,效果图分为单channel绘图和1:1通道特征图融合图。

我生成了很多特征图,我就简单的放两张吧,意思意思。

                                                             单通道特征图

                                                              叠加后的特征图

三、代码

我再次描述清楚我的需求以及我现有的东西,我有网络的结构和网络的预训练权重,我想通过输入图片,得到图片在网络特定层的特征图。

从main()开始看代码,我会说得详细一点,尽量让大家看懂, 这样你修改起来会方便很多。

图片保存和读取的路径相关的问题,我就不说了,这里大家应该懂。

1.首先我们看导入的包,DepthCompletionFrontNet 这是我的网络结构,首先你要搭建起的你的网络(这个得有)。

2.看main()函数,定位到get_feature()函数

3.get_feature做了下面得几个事儿,第一,读取图片,也就是要输入网络得图片(我得网络是双分支,所以是读取两个图,这里你读取一个图就行,就 img_rgb 就行,把 img_pc 相关内容注释);第二,定义网络,实例化,载入预训练权重模型;第三,定义我们要提取出得特定层,这里必须和你网络结构定义得层一模一样,一模一样,一模一样。

4.已经定义的网络结构需要进行修改,假设你网络定义的代码如下:

  1. # 仅仅举例子,我懒得补全了,直接csdn手打的
  2. class Net(nn.Module):
  3. super(Net,self).__init__()
  4. self.conv1 = nn.conv1
  5. self.conv2 = nn.conv2
  6. self.conv3 = nn.conv3
  7. forward(self,x):
  8. x = conv1(x)
  9. x = conv2(x)
  10. x = conv3(x)
  11. return x

网络的定义不需要修改,我们需要修改下网络的 forward,加入字典 all_dict去存储每层的tensor,forward修改如下:

  1. forward(self,x):
  2. all_dict = {}
  3. x = conv1(x)
  4. all_dict['conv1'] = x
  5. x = conv2(x)
  6. all_dict['conv2'] = x
  7. x = conv3(x)
  8. all_dict['conv3'] = x
  9. return x,all_dict

这样子就修改完成了

总结一下:首先读入模型和图片,图片在前向传播的过程中,我们通过字典保存每层的tensor,需要提取哪层,就从哪层去获取tensor,进而可视化。

大家有问题可以留言,我看到一定会回复。如可以运行,麻烦点赞下,谢谢!希望帮到大家。

 

 

完整代码如下(网络结构我的很复杂,就不放了, 网络结构修改就像上面我说的一样,你可以直接读取img_rgb,在模型的前向传播输入img_rgb,我的网络是双分支,所以我输入两个图组合的字典):

  1. import torch
  2. import torchvision.transforms as transforms
  3. import skimage.data
  4. import skimage.io
  5. import skimage.transform
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from completion_segmentation_model import DepthCompletionFrontNet
  9. # from completion_segmentation_model_v3_eca_attention import DepthCompletionFrontNet
  10. import math
  11. #https://blog.csdn.net/missyougoon/article/details/85645195
  12. # https://blog.csdn.net/grayondream/article/details/99090247
  13. # 定义是否使用GPU
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. # 定义数据预处理方式(将输入的类似numpy中arrary形式的数据转化为pytorch中的张量(tensor))
  16. transform = transforms.ToTensor()
  17. def get_picture(picture_dir, transform):
  18. '''
  19. 该算法实现了读取图片,并将其类型转化为Tensor
  20. '''
  21. img = skimage.io.imread(picture_dir)
  22. img256 = skimage.transform.resize(img, (128, 256))
  23. img256 = np.asarray(img256)
  24. img256 = img256.astype(np.float32)
  25. return transform(img256)
  26. def get_picture_rgb(picture_dir):
  27. '''
  28. 该函数实现了显示图片的RGB三通道颜色
  29. '''
  30. img = skimage.io.imread(picture_dir)
  31. img256 = skimage.transform.resize(img, (256, 256))
  32. skimage.io.imsave('4.jpg', img256)
  33. # 取单一通道值显示
  34. # for i in range(3):
  35. # img = img256[:,:,i]
  36. # ax = plt.subplot(1, 3, i + 1)
  37. # ax.set_title('Feature {}'.format(i))
  38. # ax.axis('off')
  39. # plt.imshow(img)
  40. # r = img256.copy()
  41. # r[:,:,0:2]=0
  42. # ax = plt.subplot(1, 4, 1)
  43. # ax.set_title('B Channel')
  44. # # ax.axis('off')
  45. # plt.imshow(r)
  46. # g = img256.copy()
  47. # g[:,:,0]=0
  48. # g[:,:,2]=0
  49. # ax = plt.subplot(1, 4, 2)
  50. # ax.set_title('G Channel')
  51. # # ax.axis('off')
  52. # plt.imshow(g)
  53. # b = img256.copy()
  54. # b[:,:,1:3]=0
  55. # ax = plt.subplot(1, 4, 3)
  56. # ax.set_title('R Channel')
  57. # # ax.axis('off')
  58. # plt.imshow(b)
  59. # img = img256.copy()
  60. # ax = plt.subplot(1, 4, 4)
  61. # ax.set_title('image')
  62. # # ax.axis('off')
  63. # plt.imshow(img)
  64. img = img256.copy()
  65. ax = plt.subplot()
  66. ax.set_title('image')
  67. # ax.axis('off')
  68. plt.imshow(img)
  69. plt.show()
  70. def visualize_feature_map_sum(item,name):
  71. '''
  72. 将每张子图进行相加
  73. :param feature_batch:
  74. :return:
  75. '''
  76. feature_map = item.squeeze(0)
  77. c = item.shape[1]
  78. print(feature_map.shape)
  79. feature_map_combination=[]
  80. for i in range(0,c):
  81. feature_map_split = feature_map.data.cpu().numpy()[i, :, :]
  82. feature_map_combination.append(feature_map_split)
  83. feature_map_sum = sum(one for one in feature_map_combination)
  84. # feature_map = np.squeeze(feature_batch,axis=0)
  85. plt.figure()
  86. plt.title("combine figure")
  87. plt.imshow(feature_map_sum)
  88. plt.savefig('E:/Dataset/qhms/feature_map/feature_map_sum_'+name+'.png') # 保存图像到本地
  89. plt.show()
  90. def get_feature():
  91. # 输入数据
  92. root_path = 'E:/Dataset/qhms/data/small_data/'
  93. pic_dir = 'test_umm_000067.png'
  94. pc_path = root_path+'knn_pc_crop_0.6/'+pic_dir
  95. rgb_path = root_path+'train_image_2_lane_crop_0.6/'+pic_dir
  96. img_rgb = get_picture(rgb_path, transform)
  97. # 插入维度
  98. img_rgb = img_rgb.unsqueeze(0)
  99. img_rgb = img_rgb.to(device)
  100. img_pc = get_picture(pc_path, transform)
  101. # 插入维度
  102. img_pc = img_pc.unsqueeze(0)
  103. img_pc = img_pc.to(device)
  104. # 加载模型
  105. checkpoint = torch.load('E:/Dataset/qhms/all_result/v3/crop_0.6_old/hah/checkpoint-195.pth.tar')
  106. args = checkpoint['args']
  107. print(args)
  108. model = DepthCompletionFrontNet(args)
  109. print(model.keys())
  110. model.load_state_dict(checkpoint['model'])
  111. model.to(device)
  112. exact_list = ["conv1","conv2","conv3","conv4","convt4","convt3","convt2_","convt1_","lane"]
  113. # myexactor = FeatureExtractor(model, exact_list)
  114. img1 = {
  115. 'pc': img_pc, 'rgb': img_rgb
  116. }
  117. # print(img1['pc'])
  118. # x = myexactor(img1)
  119. result,all_dict = model(img1)
  120. outputs = []
  121. # 挑选exact_list的层
  122. for item in exact_list:
  123. x = all_dict[item]
  124. outputs.append(x)
  125. # 特征输出可视化
  126. x = outputs
  127. k=0
  128. print(x[0].shape[1])
  129. for item in x:
  130. c = item.shape[1]
  131. plt.figure()
  132. name = exact_list[k]
  133. plt.suptitle(name)
  134. for i in range(c):
  135. wid = math.ceil(math.sqrt(c))
  136. ax = plt.subplot(wid, wid, i + 1)
  137. ax.set_title('Feature {}'.format(i))
  138. ax.axis('off')
  139. figure_map = item.data.cpu().numpy()[0, i, :, :]
  140. plt.imshow(figure_map, cmap='jet')
  141. plt.savefig('E:/Dataset/qhms/feature_map/feature_map_' + name + '.png') # 保存图像到本地
  142. visualize_feature_map_sum(item,name)
  143. k = k + 1
  144. plt.show()
  145. # 训练
  146. if __name__ == "__main__":
  147. # get_picture_rgb(pic_dir)
  148. get_feature()

 

参考:

  1. https://blog.csdn.net/missyougoon/article/details/85645195
  2. https://blog.csdn.net/grayondream/article/details/99090247

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/514423
推荐阅读
相关标签
  

闽ICP备14008679号