当前位置:   article > 正文

语义分割数据集图像映射、ValueError: Target size (torch.Size([4, 20, 320, 320])) must be the same as input size (_语义分割数据集颜色映射

语义分割数据集颜色映射

目录

在处理语义分割数据集时需要将RGB的mask图像按照类别进一步处理成对应的灰度图。

1.结果图

2.做法(RGBmask得到label灰度图):

2.1 修改自己的Clss字典

2.2 修改问价夹

3. 灰度图label恢复RGBmask


在处理语义分割数据集时需要将RGB的mask图像按照类别进一步处理成对应的灰度图

1.结果图

以下图为例。

原图:

在这里插入图片描述

使用labelme打完标签之后的RGBmask图: 


把上图按像素点类别映射成灰度图,得到label图如下(由于灰度值是0-4,灰度值较小,所以肉眼看起来看不到区别):

2.做法(RGBmask得到label灰度图):

2.1 修改自己的Clss字典

多类别的时候,指定自己的mask图每一种颜色(RGB值)代表什么类别。

  1. Cls = namedtuple('cls', ['name', 'id', 'color']) # 姓名元组
  2. Clss = [
  3. Cls('_background_', 0, (0, 0, 0)),
  4. Cls('wall', 1, (0, 0, 128)),
  5. Cls('car', 2, (128, 0, 0)),
  6. Cls('ground', 3, (128, 128, 0)),
  7. Cls('door', 4, (0, 128, 0)),
  8. ] # 把每一个类别使用id与像素值对应上,并且颜色通道值按照r,g,b的顺序排列
  9. # 其中(0,0,0)代表黑色
  10. # (0,0,128)代表蓝色
  11. # (128,0,0)代表红色
  12. # (128,128,0)表示红绿混合色黄色
  13. # (0,128,0)代表绿色

2.2 修改问价夹

修改放置灰度图的文件夹名字和放置RGB图的文件夹名字(我这里是把RGB图放到了F盘的mask目录里),之后通过

color_gray(color_dict)

函数run一下后就在gray文件夹里面得到了label图。

  1. gts_gray_path = 'F:\gray' # 放置灰度图的文件夹名字
  2. gts_color_path = 'F:\mask' # 放置RGB图的文件夹名字

完整代码如下: 

  1. import os, sys, time, cv2
  2. import numpy as np
  3. from collections import namedtuple
  4. # 类别信息
  5. gts_gray_path = 'F:\gray' # 放置灰度图的文件夹名字
  6. gts_color_path = 'F:\mask' # 放置RGB图的文件夹名字
  7. # gts_color_path = 'color' # 放置RGB图的文件夹名字
  8. Cls = namedtuple('cls', ['name', 'id', 'color']) # 姓名元组
  9. Clss = [
  10. Cls('_background_', 0, (0, 0, 0)),
  11. Cls('wall', 1, (0, 0, 128)),
  12. Cls('car', 2, (128, 0, 0)),
  13. Cls('ground', 3, (128, 128, 0)),
  14. Cls('door', 4, (0, 128, 0)),
  15. ] # 把每一个类别使用id与像素值对应上,并且颜色通道值按照r,g,b的顺序排列
  16. # 其中(0,0,0)代表黑色
  17. # (0,0,128)代表蓝色
  18. # (128,0,0)代表红色
  19. # (128,128,0)表示红绿混合色黄色
  20. # (0,128,0)代表绿色
  21. def gray_color(color_dict, gray_path=gts_gray_path, color_path=gts_color_path):
  22. '''
  23. swift gray image to color, by color mapping relationship
  24. :param color_dict:color mapping relationship, dict format
  25. :param gray_path:gray imgs path
  26. :param color_path:color imgs path
  27. :return:
  28. '''
  29. pass
  30. t1 = time.time()
  31. gt_list = os.listdir(gray_path)
  32. for index, gt_name in enumerate(gt_list):
  33. gt_gray_path = os.path.join(gray_path, gt_name)
  34. gt_color_path = os.path.join(color_path, gt_name)
  35. gt_gray = cv2.imread(gt_gray_path, cv2.IMREAD_GRAYSCALE)
  36. assert len(gt_gray.shape) == 2 # make sure gt_gray is 1band
  37. # # region method 1: swift by pix, slow
  38. # gt_color = np.zeros((gt_gray.shape[0],gt_gray.shape[1],3),np.uint8)
  39. # for i in range(gt_gray.shape[0]):
  40. # for j in range(gt_gray.shape[1]):
  41. # gt_color[i][j] = color_dict[gt_gray[i][j]] # gray to color
  42. # # endregion
  43. # region method 2: swift by array
  44. # gt_color = np.array(np.vectorize(color_dict.get)(gt_gray),np.uint8).transpose(1,2,0)
  45. # endregion
  46. # region method 3: swift by matrix, fast
  47. gt_color = matrix_mapping(color_dict, gt_gray)
  48. # endregion
  49. gt_color = cv2.cvtColor(gt_color,
  50. cv2.COLOR_RGB2BGR) # cv2.cvtColor(p1,p2) 是颜色空间转换函数,p1是需要转换的图片,p2是转换成何种格式。,得到的图矩阵是按R,G,B排列的,CV2是按B,G,R排列的
  51. cv2.imwrite(gt_color_path, gt_color, )
  52. cv2.imshow('color image:', gt_color)
  53. cv2.waitKey(0)
  54. process_show(index + 1, len(gt_list))
  55. print(time.time() - t1)
  56. def color_gray(color_dict, color_path=gts_color_path, gray_path=gts_gray_path, ):
  57. '''
  58. swift color image to gray, by color mapping relationship
  59. :param color_dict:color mapping relationship, dict format
  60. :param gray_path:gray imgs path
  61. :param color_path:color imgs path
  62. :return:
  63. '''
  64. gray_dict = {}
  65. for k, v in color_dict.items():
  66. gray_dict[v] = k
  67. t1 = time.time()
  68. gt_list = os.listdir(color_path)
  69. for index, gt_name in enumerate(gt_list):
  70. gt_gray_path = os.path.join(gray_path, gt_name)
  71. gt_color_path = os.path.join(color_path, gt_name)
  72. color_array = cv2.imread(gt_color_path, cv2.IMREAD_COLOR)
  73. assert len(color_array.shape) == 3
  74. print(color_array.shape)
  75. gt_gray = np.zeros((color_array.shape[0], color_array.shape[1]), np.uint8)
  76. b, g, r = cv2.split(color_array)
  77. # zeros = np.zeros(color_array.shape[:2], dtype="uint8")
  78. # merged_b= cv2.merge([zeros, g, zeros])
  79. # cv2.imshow('b',merged_b)
  80. # cv2.waitKey(0)
  81. color_array = np.array([r, g, b])
  82. for cls_color, cls_index in gray_dict.items(): # 将图像的像素点的三个通道值与约定的类别代表的像素值逐个逐个地比较,直到可以将整个图像遍历
  83. cls_pos = arrays_jd(color_array, cls_color) # 将图像的像素点的三个通道的像素与约定类别对应的三通道像素值进行比较
  84. # print(cls_pos.shape,gt_gray.shape,cls_index)
  85. gt_gray[cls_pos] = cls_index # 如果比较发现是这个类别的三通道像素值,那么就把这个地方的灰度值换成对应的类别值
  86. # 其中cls_pos的形状必须要与gt_gray一致,cls_pos是有False和True构成的数组,True代表这个位置的数字替换,False代表这个位置的数字不替换。
  87. # print(gt_gray)
  88. cv2.imwrite(gt_gray_path, gt_gray)
  89. # cv2.imshow('gray image:', gt_gray)
  90. # cv2.waitKey(0)
  91. process_show(index + 1, len(gt_list))
  92. print(time.time() - t1)
  93. def arrays_jd(arrays, cond_nums): # 用于比较像素通道值是否属于某一个类别
  94. r = arrays[0] == cond_nums[0] # 比较r通道的像素值
  95. g = arrays[1] == cond_nums[1] # 比较g通道的像素值
  96. b = arrays[2] == cond_nums[2] # 比较b通道的像素值
  97. # print(r & g & b)
  98. return r & g & b # 只有三个通道都与约定的id对应的三个通道颜色像素值都一致时才认为这个像素点是属于这个id对应的类别
  99. def matrix_mapping(color_dict, gt): # 用于将灰度图按照类别转化为三通道图
  100. colorize = np.zeros([len(color_dict), 3], 'uint8')
  101. # print(colorize.shape)
  102. for cls, color in color_dict.items():
  103. colorize[cls, :] = list(color) # 直接替换行,得到一个代表像素点类别的带三通道值的矩阵。
  104. # print(colorize)
  105. # print(gt)
  106. ims = colorize[gt,
  107. :] # 将gt里面的元素作为索引逐个放入colorize中,当成访问colorize中的每一行,从而按照colorize的每一行按照灰度值代表的id逐个填入,从而将colorize从二维变成三维,按照灰度图的id索引将图像还原为BGR图像
  108. # print(ims)
  109. # ims = ims.reshape([gt.shape[0], gt.shape[1], 3])
  110. return ims
  111. def nt_dic(nt=Clss):
  112. '''
  113. swift nametuple to color dict
  114. :param nt: nametuple
  115. :return:
  116. '''
  117. pass
  118. color_dict = {}
  119. for cls in nt:
  120. color_dict[cls.id] = cls.color
  121. return color_dict
  122. def process_show(num, nums, pre_fix='', suf_fix=''):
  123. '''
  124. auxiliary function, print work progress
  125. :param num:
  126. :param nums:
  127. :param pre_fix:
  128. :param suf_fix:
  129. :return:
  130. '''
  131. rate = num / nums
  132. ratenum = round(rate, 3) * 100
  133. bar = '\r%s %g/%g [%s%s]%.1f%% %s' % \
  134. (pre_fix, num, nums, '#' * (int(ratenum) // 5), '_' * (20 - (int(ratenum) // 5)), ratenum, suf_fix)
  135. sys.stdout.write(bar)
  136. sys.stdout.flush()
  137. if __name__ == '__main__':
  138. pass
  139. color_dict = nt_dic()
  140. # gray_color(color_dict)
  141. color_gray(color_dict)
  142. gt_gray = np.zeros((2, 2), np.uint8)
  143. cls_pos = np.array([[True, False], [True, True]])
  144. gt_gray[cls_pos] = 9
  145. print(gt_gray)

3. 灰度图label恢复RGBmask

代码中使用函数:

gray_color(color_dict)

color_gray(color_dict)注释掉

并指定好映射和文件夹位置

pycharm中点击运行if __name__ == '__main__':就可以将灰度label恢复为RGBmask图了。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/511645
推荐阅读
相关标签
  

闽ICP备14008679号