当前位置:   article > 正文

车道线分割项目记录-tusimple数据集处理

车道线分割项目记录-tusimple数据集处理

一、数据集包含信息

该项目训练所使用的数据集是tusimple数据集,其中用于训练及验证的有约3500张图,测试的有2000多张图。数据集中,除了图片,还包含了json文件,携带了车道线信息、文件路径。每一条数据如下所示:

{"lanes": [[-2, -2, -2, 348, 358, 357, 352, 347, 341, 331, 316, 301, 286, 271, 256, 241, 226, 211, 196, 182, 167, 152, 137, 122, 107, 92, 77, 62, 47, 32, 17, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, 427, 451, 469, 487, 504, 520, 526, 533, 539, 545, 551, 557, 564, 570, 576, 582, 588, 595, 601, 607, 613, 619, 626, 632, 638, 644, 650, 657, 663, 669, 675, 681, 688, 694, 700, 706, 712, 719, 725, 731, 737, 743, 750, -2, -2], [-2, -2, -2, 274, 263, 253, 227, 200, 173, 146, 118, 91, 64, 37, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, 552, 601, 649, 690, 721, 751, 782, 813, 844, 874, 905, 936, 967, 998, 1028, 1059, 1090, 1121, 1151, 1182, 1213, 1244, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2]], "h_samples": [240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710], "raw_file": "clips/0313-2/35080/20.jpg"}

可以看到,这里面有"lanes", "h_samples", "raw_file"三部分数据。

其中,"lanes"记录的是车道线的横坐标,就是图片的宽的坐标,"h_samples"记录的是纵坐标,就是图片的高的坐标,"raw_file"记录的是文件路径。

车道线可能有多条,比如上面这个例子里面就有4条,那么横坐标就有4组,纵坐标有1组,因此横纵坐标合在一起,构成了车道线所对应的点。

 

二、数据集创建

1.创建标签

标签只有车道线的几个点,但是模型需要的,一个是二值的语义分割标签,里面1的地方是车道线,其他地方是0,另一个是实例分割标签,比如不是车道线的地方是0,第一条是1,第二条是2等等。

这里,我创建语义分割标签图的时候,是通过cv2.line,把前后点连接起来,设定一个宽度,使标签变为以下这样:

 同样,创建实例分割标签的时候,方式一致,但是填的不全是1,而是1,2,3,4,就像下面这样(示例):

注意:创建完标签,可以保存为图像,保存为png格式,读取的时候,使用cv2.imread的时候,注意第二个参数传-1,cv2.imread(img_path, -1),就可以按照存的方式读取了。否则你如果保存的是二值图,读取的时候可能会变成三通道的。代码如下

  1. def get_img_path_lanes(json_path):
  2. path_lane_data = []
  3. for file_path in json_path:
  4. with open(file_path,'r',encoding='utf-8') as f:
  5. data = f.readlines()
  6. for line in data:
  7. dicts = json.loads(line)
  8. lane_xy = []
  9. for lane in dicts['lanes']:
  10. y = np.array([dicts['h_samples']]).T
  11. x = np.array([lane]).T
  12. lane_xy.append(np.hstack((x,y)))
  13. path_lane_data.append([dicts['raw_file'],lane_xy])
  14. return path_lane_data
  15. def generate_labels():
  16. train_list = ['./data/train_set/label_data_0313.json', './data/train_set/label_data_0601.json']
  17. val_list = ['./data/train_set/label_data_0531.json']
  18. test_list = ['./data/test_set/test_tasks_0627.json']
  19. path_lane_train = get_img_path_lanes(train_list)
  20. path_lane_val = get_img_path_lanes(val_list)
  21. path_lane_test = get_img_path_lanes(test_list)
  22. paths = [path_lane_train, path_lane_val]
  23. label_folder = 'seg_label'
  24. # 创建语义分割标签图并保存
  25. if not os.path.exists('./seg_label/0313-1/6040'):
  26. print('未找到语义分割标签图路径,正在处理......')
  27. for one_path in paths:
  28. for item in one_path:
  29. item_save_path = label_folder+'/'+'/'.join(item[0].split('/')[1:3])
  30. # 找到并读取图片
  31. path = os.path.join('./data/train_set/', item[0])
  32. img = cv2.imread(path)
  33. # 获取高和宽
  34. H,W = img.shape[0], img.shape[1]
  35. # 创建标签图,全0
  36. mask = np.zeros((H,W))
  37. for k,lane in enumerate(item[1]):
  38. # 选择第一张图的车道线来画,就画4根
  39. if k == 4:
  40. continue
  41. for j in range(len(lane)-1):
  42. if lane[j][0] != -2 and lane[j+1][0] != -2:
  43. cv2.line(mask, tuple(lane[j]),tuple(lane[j+1]),1,8)
  44. # cv_show(mask)
  45. # 存储Mask
  46. if not os.path.exists(item_save_path):
  47. os.makedirs(item_save_path)
  48. cv2.imwrite(item_save_path+'/'+(item[0].split('/')[3]).split('.')[0]+'.png',mask)
  49. print('训练+验证集语义分割标签图保存完毕!')
  50. # 创建实例分割标签图并保存
  51. if not os.path.exists('./seg_label/instance_div/0313-1/6040'):
  52. print('未找到实例分割标签图路径,正在处理......')
  53. for one_path in paths:
  54. for item in one_path:
  55. item_save_path = label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])
  56. # 找到并读取图片
  57. path = os.path.join('./data/train_set/', item[0])
  58. img = cv2.imread(path)
  59. # 获取高和宽
  60. H,W = img.shape[0], img.shape[1]
  61. # 创建标签图,全0
  62. mask = np.zeros((H,W))
  63. for k,lane in enumerate(item[1]):
  64. # 选择第一张图的车道线来画,就画4根
  65. if k == 4:
  66. continue
  67. for j in range(len(lane)-1):
  68. if lane[j][0] != -2 and lane[j+1][0] != -2:
  69. cv2.line(mask, tuple(lane[j]),tuple(lane[j+1]),k+1,8)
  70. # cv_show(mask)
  71. # 存储Mask
  72. if not os.path.exists(item_save_path):
  73. os.makedirs(item_save_path)
  74. cv2.imwrite(item_save_path+'/'+(item[0].split('/')[3]).split('.')[0]+'.png',mask)
  75. print('训练+验证集实例分割标签图保存完毕!')
  76. # 创建txt文件路径并保存
  77. if not os.path.exists('./seg_label/train.txt'):
  78. print('未找到图片及标签汇总文件路径,正在处理......')
  79. file_paths = ['./seg_label/train.txt','./seg_label/val.txt','./seg_label/test.txt']
  80. for i,one_path in enumerate([path_lane_train,path_lane_val,path_lane_test]):
  81. if i == 0:
  82. with open(file_paths[i],'w',encoding='utf-8') as f:
  83. for item in one_path:
  84. num = '/'+(item[0].split('/')[3]).split('.')[0]
  85. f.write('data/train_set/'+item[0]+' '+label_folder+'/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+' '
  86. + label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+'\n')
  87. if i == 1:
  88. with open(file_paths[i],'w',encoding='utf-8') as f:
  89. for item in one_path:
  90. num = '/'+(item[0].split('/')[3]).split('.')[0]
  91. f.write('data/train_set/'+item[0]+' '+label_folder+'/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+' '
  92. + label_folder+'/instance_div/'+'/'.join(item[0].split('/')[1:3])+num+'.png'+'\n')
  93. if i == 2:
  94. with open(file_paths[i],'w',encoding='utf-8') as f:
  95. for item in one_path:
  96. f.write('data/test_set/'+item[0]+'\n')
  97. print('图片、语义分割标签、实例分割标签路径保存完毕!')

2.创建dataset类

这里还是创建的时候继承自torch.utils.data.Dataset类就可以,在类中实现__getitem__以及__len__,这里由于训练数据有3000多张,一次性导入内存的话,电脑会卡死,因此我改变了导入方式,只导入路径,__len__这里返回的也是路径列表,而在__getitem__里面,则是去索引路径信息,传入到一个专门读取并处理图片的函数中就可以了,处理的时候增加一些随机性,这里我加了亮度对比度的随机调整,随机旋转。

  1. class LaneDataset(Dataset):
  2. '''
  3. 输入训练数据路径、验证数据路径、测试数据路径,初始化的时候会自动创建索引文件,根据输入的mode不同,返回不同的处理好的样本。
  4. 如果是train或者val,返回的是处理过的图片+处理过的标签(2个标签,一个是语义分割,一个是实例分割)
  5. 如果是test,返回的是处理过的图片
  6. 通过DataLoader之后,取出的是一个list,第一个元素是(B,3,H,W),第二个元素是List,有B个元素,每一个都是包含(2,H,W)的标签
  7. '''
  8. def __init__(self,resize_shape=(640,360), transform=None,rotate_theta=2, mode='train'):
  9. super(LaneDataset, self).__init__()
  10. self.transforms = transform
  11. self.mode = mode
  12. self.resize_shape = resize_shape
  13. self.rotate_theta=rotate_theta
  14. prepared_file_paths = ['./seg_label/train.txt','./seg_label/instance_div','./seg_label/0313-1']
  15. for prepared in prepared_file_paths:
  16. if not os.path.exists(prepared):
  17. print('预备文件路径缺少{},开始准备...'.format(prepared))
  18. generate_labels()
  19. self.data_list = self.get_path()[:800]
  20. def __len__(self):
  21. return len(self.data_list)
  22. def __getitem__(self, idx):
  23. if self.mode == 'train' or self.mode == 'val':
  24. processed_img, processed_labels = self.preprocess_data(self.data_list[idx])
  25. return processed_img,processed_labels
  26. elif self.mode=='test':
  27. processed_img = self.preprocess_data(self.data_list[idx])
  28. return processed_img
  29. def get_path(self):
  30. if self.mode not in ['train','val','test']:
  31. raise Exception('数据应当是train, val, test三者之一')
  32. # 根据mode取出标签
  33. modes = ['train','val']
  34. if self.mode in modes:
  35. data_list = []
  36. with open('./seg_label/{}.txt'.format(self.mode), 'r', encoding='utf-8') as f:
  37. all_paths = f.readlines()
  38. for path in all_paths:
  39. # 每一个path.split,都是[“图片路径”,“语义分割标签路径”,“实例分割标签路径”]
  40. data_list.append(path.strip().split())
  41. else:
  42. data_list = []
  43. with open('./seg_label/test.txt', 'r', encoding='utf-8') as f:
  44. all_paths = f.readlines()
  45. for path in all_paths:
  46. data_list.append(path.strip())
  47. return data_list
  48. def preprocess_data(self, data_list):
  49. '''
  50. :return: 如果是train和val,就返回图片+标签数据,如果是test,就只返回测试图片
  51. '''
  52. if self.mode not in ['train','val','test']:
  53. raise Exception('数据应当是train, val, test三者之一')
  54. # 根据mode取出标签
  55. modes = ['train','val']
  56. if self.mode in modes:
  57. label_list = []
  58. img = cv2.imread(data_list[0])
  59. label_list.append(cv2.imread(data_list[1],-1))
  60. label_list.append(cv2.imread(data_list[2],-1))
  61. # 使用图本身的均值方差进行标准化,这块效果不好,还是使用网上通用的均值和方差吧。
  62. # 需要注意的是,如果以下面的方式操作,数据类型会发生改变,导致在dataloader拿数据
  63. # 的时候,出现错误,因此一定要记得把数据类型还原回uint8
  64. mean,std = cv2.meanStdDev(img)
  65. b,g,r = cv2.split(img)
  66. b1 = (b - mean[0]) / (1.e-6 + std[0])
  67. g1 = (g - mean[1]) / (1.e-6 + std[1])
  68. r1 = (r - mean[2]) / (1.e-6 + std[2])
  69. img = cv2.merge([b1, g1, r1]).astype('uint8')
  70. # # 转成RGB
  71. # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  72. # 预处理:亮度对比度
  73. img = self.bright_contra_adjust(img)
  74. # resize
  75. img = cv2.resize(img, self.resize_shape, interpolation=cv2.INTER_CUBIC)
  76. label1 = cv2.resize(label_list[0], self.resize_shape, interpolation=cv2.INTER_NEAREST)
  77. label2 = cv2.resize(label_list[1], self.resize_shape, interpolation=cv2.INTER_NEAREST)
  78. # Rotation
  79. u = np.random.uniform()
  80. degree = (u-0.5) * self.rotate_theta
  81. R = cv2.getRotationMatrix2D((img.shape[1]//2, img.shape[0]//2),degree,1)
  82. img_rotate = cv2.warpAffine(img,R,(img.shape[1], img.shape[0]), flags=cv2.INTER_LINEAR)
  83. label1_rotate = cv2.warpAffine(label1,R,(label1.shape[1], label1.shape[0]), flags=cv2.INTER_NEAREST)
  84. label2_rotate = cv2.warpAffine(label2, R, (label2.shape[1], label2.shape[0]), flags=cv2.INTER_NEAREST)
  85. # transform
  86. img = self.transforms(img_rotate)
  87. return img, [label1_rotate, label2_rotate]
  88. else:
  89. img = cv2.imread(data_list)
  90. mean,std = cv2.meanStdDev(img)
  91. b,g,r = cv2.split(img)
  92. b1 = (b - mean[0]) / (1.e-6 + std[0])
  93. g1 = (g - mean[1]) / (1.e-6 + std[1])
  94. r1 = (r - mean[2]) / (1.e-6 + std[2])
  95. img = cv2.merge([b1, g1, r1]).astype('uint8')
  96. # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  97. img = cv2.resize(img, self.resize_shape, interpolation=cv2.INTER_CUBIC)
  98. img = self.transforms(img)
  99. return img
  100. def bright_contra_adjust(self, img):
  101. '''亮度对比度调整,随机增加或减少0-10'''
  102. contra = random.uniform(0.85,1.15)
  103. bright = random.randint(-30,20)
  104. if random.uniform(0,1) > 0.5:
  105. return img
  106. else:
  107. img = img.astype(np.int)
  108. return np.uint8(np.clip(img*contra+ bright, 0, 255))
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/345249
推荐阅读
  

闽ICP备14008679号