当前位置:   article > 正文

车道线识别 tusimple 数据集介绍_tusimple数据集

tusimple数据集

1、tusimple 数据集介绍

标注json 文件中每一行包括三个字段 :

raw_file : 每一个数据段的第20帧图像的的 path 路径

lanes 和 h_samples 是数据具体的标注内容,为了压缩,h_sample 是纵坐标(等分确定),lanes 是每个车道的横坐标,是个二维数组。-2 表示这个点是无效的点。

标注的过程应该是,将图片的下半部分如70%*height 等分成N份。然后取车道线(如论虚实)与该标注线交叉的点

上面的数据就有 4 条车道线,第一条车道线的第一个点的坐标是(632,280)。 

 

2、下载数据集

LaneNet车道线检测使用的是Tusimple数据集,下载它

https://github.com/TuSimple/tusimple-benchmark/issues/3

3、样本处理

利用以下脚本可以处理得到标注的数据,这个脚本稍微改动下也可以作为深度学习输入的图像。

  1. # -*- coding: utf-8 -*-
  2. import cv2
  3. import json
  4. import numpy as np
  5. import os
  6. base_path = r"C:\Users\Downloads"
  7. file = open(base_path + '\label_data_0601.json', 'r')
  8. image_num = 0
  9. for line in file.readlines():
  10. data = json.loads(line)
  11. # print data['raw_file']
  12. # 取第 29 帧 看一下处理的效果
  13. if image_num == 2:
  14. image = cv2.imread(os.path.join(base_path, data['raw_file']))
  15. # 二进制图像数组初始化
  16. binaryimage = np.zeros((image.shape[0], image.shape[1], 1), np.uint8)
  17. # 实例图像数组初始化
  18. instanceimage = binaryimage.copy()
  19. arr_width = data['lanes']
  20. arr_height = data['h_samples']
  21. width_num = len(arr_width) # 标注的道路条数
  22. height_num = len(arr_height)
  23. # print width_num
  24. # print height_num
  25. # 遍历纵坐标
  26. for i in range(height_num):
  27. lane_hist = 40
  28. # 遍历各个车道的横坐标
  29. for j in range(width_num):
  30. # 端点坐标赋值
  31. if arr_width[j][i - 1] > 0 and arr_width[j][i] > 0:
  32. binaryimage[int(arr_height[i]), int(arr_width[j][i])] = 255 # 255白色,0是黑色
  33. instanceimage[int(arr_height[i]), int(arr_width[j][i])] = lane_hist
  34. if i > 0:
  35. # 画线,线宽10像素
  36. cv2.line(binaryimage, (int(arr_width[j][i - 1]), int(arr_height[i - 1])),
  37. (int(arr_width[j][i]), int(arr_height[i])), 255, 10)
  38. cv2.line(instanceimage, (int(arr_width[j][i - 1]), int(arr_height[i - 1])),
  39. (int(arr_width[j][i]), int(arr_height[i])), lane_hist, 10)
  40. lane_hist += 50
  41. # cv2.imshow('image.jpg', image)
  42. # cv2.waitKey()
  43. # cv2.imshow('binaryimage.jpg', binaryimage)
  44. # cv2.waitKey()
  45. # cv2.imshow('instanceimage.jpg', instanceimage)
  46. # cv2.waitKey()
  47. string1 = base_path + "\\" + str(image_num+10) + ".png"
  48. string2 = base_path + "\\" + str(image_num+11) + ".png"
  49. string3 = base_path + "\\" + str(image_num+12) + ".png"
  50. cv2.imwrite(string1, binaryimage)
  51. cv2.imwrite(string2, instanceimage)
  52. cv2.imwrite(string3, image)
  53. break
  54. image_num = image_num + 1
  55. file.close()
  56. print("total image_num:" + str(image_num))

处理完之后图片输出如下所示:

 

 

 

 

 

 

Tusimple 数据的标注特点:

1、车道线实际上不只是道路上的标线,虚线被当作了一种实线做处理的。这里面双实线、白线、黄线这类信息也是没有被标注的。

2、每条线实际上是点序列的坐标集合,而不是区域集合

4、创建自己的tusimple数据集格式

第一步:原始数据集标注

1、使用labelme进行数据标注:

在conda里使用指令进行安装labelme

 pip install labelme

2、在环境下使用指令进行启动labelme

labelme

3、进入界面后选择图片,进行线段标记

在顶部edit菜单栏中选择不同的标记方案,依次为:多边形(默认),矩形,圆、直线,点。点击 Create Point,回到图片,左键点击会生成一个点,标记完成后,会形成一个标注区域,同时弹出labelme的框,输入标注名,点击ok,标注完成

注意:要标注的车道线,一般会有多条,需要不同的命名加以区分,lane1,lane2等

标注完成后,会生成一个json文件。

4、将json转换为dataset

labelme_json_to_dataset xxx.json

生成一个文件夹,里面包含五个文件(只能转换一个json)

批量转换json:

在labelme的安装目录下可以看到json_to_dataset文件,默认只提供单个文件转换,我们只需要修改此代码,修改为批量转换

  1. import argparse
  2. import json
  3. import os
  4. import os.path as osp
  5. import warnings
  6. import PIL.Image
  7. import yaml
  8. from labelme import utils
  9. import base64
  10. #批量转换代码
  11. def main():
  12.     warnings.warn("This script is aimed to demonstrate how to convert the\n"
  13.                   "JSON file to a single image dataset, and not to handle\n"
  14.                   "multiple JSON files to generate a real-use dataset.")
  15.     parser = argparse.ArgumentParser()
  16.     parser.add_argument('json_file')
  17.     parser.add_argument('-o', '--out', default=None)
  18.     args = parser.parse_args()
  19.     json_file = args.json_file
  20.     if args.out is None:
  21.         out_dir = osp.basename(json_file).replace('.', '_')
  22.         out_dir = osp.join(osp.dirname(json_file), out_dir)
  23.     else:
  24.         out_dir = args.out
  25.     if not osp.exists(out_dir):
  26.         os.mkdir(out_dir)
  27.     count = os.listdir(json_file)
  28.     for i in range(0, len(count)):
  29.         path = os.path.join(json_file, count[i])
  30.         if os.path.isfile(path):
  31.             data = json.load(open(path))
  32.             if data['imageData']:
  33.                 imageData = data['imageData']
  34.             else:
  35.                 imagePath = os.path.join(os.path.dirname(path), data['imagePath'])
  36.                 with open(imagePath, 'rb') as f:
  37.                     imageData = f.read()
  38.                     imageData = base64.b64encode(imageData).decode('utf-8')
  39.             img = utils.img_b64_to_arr(imageData)
  40.             label_name_to_value = {'_background_': 0}
  41.             for shape in data['shapes']:
  42.                 label_name = shape['label']
  43.                 if label_name in label_name_to_value:
  44.                     label_value = label_name_to_value[label_name]
  45.                 else:
  46.                     label_value = len(label_name_to_value)
  47.                     label_name_to_value[label_name] = label_value
  48.             # label_values must be dense
  49.             label_values, label_names = [], []
  50.             for ln, lv in sorted(label_name_to_value.items(), key=lambda x: x[1]):
  51.                 label_values.append(lv)
  52.                 label_names.append(ln)
  53.             assert label_values == list(range(len(label_values)))
  54.             lbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)
  55.             captions = ['{}: {}'.format(lv, ln)
  56.                 for ln, lv in label_name_to_value.items()]
  57.             lbl_viz = utils.draw_label(lbl, img, captions)
  58.             out_dir = osp.basename(count[i]).replace('.', '_')
  59.             out_dir = osp.join(osp.dirname(count[i]), out_dir)
  60.             if not osp.exists(out_dir):
  61.                 os.mkdir(out_dir)
  62.             PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))
  63.             #PIL.Image.fromarray(lbl).save(osp.join(out_dir, 'label.png'))
  64.             utils.lblsave(osp.join(out_dir, 'label.png'), lbl)
  65.             PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, 'label_viz.png'))
  66.             with open(osp.join(out_dir, 'label_names.txt'), 'w') as f:
  67.                 for lbl_name in label_names:
  68.                     f.write(lbl_name + '\n')
  69.             warnings.warn('info.yaml is being replaced by label_names.txt')
  70.             info = dict(label_names=label_names)
  71.             with open(osp.join(out_dir, 'info.yaml'), 'w') as f:
  72.                 yaml.safe_dump(info, f, default_flow_style=False)
  73.             print('Saved to: %s' % out_dir)
  74. if __name__ == '__main__':
  75.     main()

进入到保存json文件的目录,执行labelme_json_to_dataset  path

将标注之后的数据批量处理之后,生成文件夹形式如下图所示

打开文件夹里面有五个文件,分别是

5、数据格式转换

根据tuSimple数据集形式,需要得到二值化和实例化后的图像数据,也就是gt_binary_image和gt_instance_image文件中的显示结果。需要将标注之后的数据进行转换

  1. import cv2
  2. from skimage import measure, color
  3. from skimage.measure import regionprops
  4. import numpy as np
  5. import os
  6. import copy
  7. def skimageFilter(gray):
  8. binary_warped = copy.copy(gray)
  9. binary_warped[binary_warped > 0.1] = 255
  10. gray = (np.dstack((gray, gray, gray))*255).astype('uint8')
  11. labels = measure.label(gray[:, :, 0], connectivity=1)
  12. dst = color.label2rgb(labels,bg_label=0, bg_color=(0,0,0))
  13. gray = cv2.cvtColor(np.uint8(dst*255), cv2.COLOR_RGB2GRAY)
  14. return binary_warped, gray
  15. def moveImageTodir(path,targetPath,name):
  16. if os.path.isdir(path):
  17. image_name = "gt_image/"+str(name)+".png"
  18. binary_name = "gt_binary_image/"+str(name)+".png"
  19. instance_name = "gt_instance_image/"+str(name)+".png"
  20. train_rows = image_name + " " + binary_name + " " + instance_name + "\n"
  21. origin_img = cv2.imread(path+"/img.png")
  22. origin_img = cv2.resize(origin_img, (1280,720))
  23. cv2.imwrite(targetPath+"/"+image_name, origin_img)
  24. img = cv2.imread(path+'/label.png')
  25. img = cv2.resize(img, (1280,720))
  26. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  27. binary_warped, instance = skimageFilter(gray)
  28. cv2.imwrite(targetPath+"/"+binary_name, binary_warped)
  29. cv2.imwrite(targetPath+"/"+instance_name, instance)
  30. print("success create data name is : ", train_rows)
  31. return train_rows
  32. return None
  33. if __name__ == "__main__":
  34. count = 1
  35. with open("./train.txt", 'w+') as file:
  36. for images_dir in os.listdir("./images"):
  37. dir_name = os.path.join("./images", images_dir + "/annotations")
  38. for annotations_dir in os.listdir(dir_name):
  39. json_dir = os.path.join(dir_name, annotations_dir)
  40. if os.path.isdir(json_dir):
  41. train_rows = moveImageTodir(json_dir, "./", str(count).zfill(4))
  42. file.write(train_rows)
  43. count += 1

转换之后的显示结果:

由于lanenet模型处理需要按照tusimple数据进行,首先需要将上一步处理的数据生成tfrecords格式,调用laneNet中lanenet_data_feed_pipline.py文件。

  1. python data_provider/lanenet_data_feed_pipline.py
  2. --dataset_dir ../dataset/lane_detection_dataset/
  3. --tfrecords_dir ../dataset/lane_detection_dataset/tfrecords

欢迎关注公众号:算法工程师的学习日志

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/345229
推荐阅读
相关标签
  

闽ICP备14008679号