
Converting Datasets for AI Frameworks in C++ (Part 1): TFRecord Datasets


I have recently been working on converting various datasets into the native data formats of different AI frameworks, including Caffe, MXNet, and TensorFlow. For such a general and powerful language, C++ leaves its users with a real pain point in the AI era: the frameworks are almost all implemented in C/C++ underneath, yet the interfaces they expose are mostly Python, and those Python APIs are very nicely packaged. MXNet is the friendliest, since it ships C/C++ sources such as im2rec.cc; with Caffe, and especially with a framework like TensorFlow, converting data from C++ takes some extra effort. This article therefore starts with converting datasets into TensorFlow's format.

1. What does each framework's native data format look like?

MXNet's native dataset: the rec format

Caffe's native dataset: the LMDB format

TensorFlow's native dataset: the TFRecord format

2. What is the TFRecord format?

The TensorFlow documentation describes three ways of getting data into a program:
1. Feeding: Python code supplies the data at every step while the TensorFlow program runs.
2. Reading from files: an input pipeline reads the data from files.
3. Preloaded data: if the dataset is small enough, constants or variables defined in the program can hold all of the data.

TFRecord is the standard format officially recommended by TensorFlow. A TFRecord file is a binary file that stores the image data and the labels together; it makes good use of memory and can be copied, moved, read, and stored quickly inside TensorFlow.
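For reference, the RecordWriter used later frames every record on disk as a length-prefixed, CRC-checked block whose payload here is a serialized Example:

// On-disk layout of one record in a TFRecord file (little-endian):
//   uint64 length                  // number of payload bytes
//   uint32 masked_crc32_of_length
//   byte   data[length]            // here, a serialized Example protobuf
//   uint32 masked_crc32_of_data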

The dataset is defined by an example.proto file:

syntax = "proto3";


message Example {
  Features features = 1;
}
message Features {
  map<string, Feature> feature = 1;
}

// Containers to hold repeated fundamental values.
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

This is a proto3 definition. Use the following command to generate the header example.pb.h and the source file example.pb.cc from it:

protoc -I=. --cpp_out=./ example.proto
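As a quick check that the generated classes work, the following minimal sketch (the key name "label" and the value 42 are arbitrary) builds one Example, serializes it, and parses it back; it only needs example.pb.cc and -lprotobuf to build:

#include <iostream>
#include <string>
#include "example.pb.h"  // generated by the protoc command above

int main() {
  // Build an Example with a single int64 feature under an arbitrary key.
  Example example;
  Feature feature;
  feature.mutable_int64_list()->add_value(42);
  (*example.mutable_features()->mutable_feature())["label"] = feature;

  // Serialize and parse back to confirm the round trip works.
  std::string serialized;
  example.SerializeToString(&serialized);
  Example parsed;
  parsed.ParseFromString(serialized);
  std::cout << parsed.features().feature().at("label").int64_list().value(0) << std::endl;
  return 0;
}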

3. How should your own dataset be prepared?

Here the detection task is illustrated with the VOC2007 dataset and the classification task with the LFW dataset.

For classification, build a single list like the one below; its construction mirrors the list used for Caffe classification tasks (the file name and the label are separated by a tab, \t, not a space):

/output/oldFile/1000015_10/wKgB5Fr6WwWAJb7iAAABKohu5Nw109.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAEbg6AAABC_mxdD8880.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAUGTdAAAA8wVERrQ677.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAPJ-lAAABPYAoeuY242.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWARVIWAAABCK2alGs331.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwWAV3R5AAAA5573dko147.png   0
/output/oldFile/1000015_10/wKgB5Fr6WwaAUjQRAAABIkYxqoY008.png   0
...
/output/oldFile/1000015_10/wKgB5Vr6YF-AALG-AAAA-qStI_Q208.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAe1VYAAABN5fz53Y240.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAQo7fAAABVFasXJ4223.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAL00yAAABJdrU4U0508.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAFjTyAAABJVgoCrU242.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAKmMMAAABMd1_pJg240.png   1
/output/oldFile/1000015_10/wKgB5Vr6YGCAR2FqAAABFCQ7LRY651.png   1

For the VOC2007 detection dataset, the list is built as follows (again, the image path and the annotation path are separated by \t, not a space):

/home/test/data/VOC2007/JPEGImages/004379.jpg /home/xbx/data/VOC2007/Annotations/004379.xml
/home/test/data/VOC2007/JPEGImages/001488.jpg /home/xbx/data/VOC2007/Annotations/001488.xml
/home/test/data/VOC2007/JPEGImages/004105.jpg /home/xbx/data/VOC2007/Annotations/004105.xml
/home/test/data/VOC2007/JPEGImages/006146.jpg /home/xbx/data/VOC2007/Annotations/006146.xml
/home/test/data/VOC2007/JPEGImages/004295.jpg /home/xbx/data/VOC2007/Annotations/004295.xml
/home/test/data/VOC2007/JPEGImages/001360.jpg /home/xbx/data/VOC2007/Annotations/001360.xml
/home/test/data/VOC2007/JPEGImages/003468.jpg /home/xbx/data/VOC2007/Annotations/003468.xml
...
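If you do not already have such a list, one simple way to generate it, assuming the standard VOC2007 layout where ImageSets/Main/trainval.txt lists the image IDs, is a small sketch like the following (the root path and the output file name are placeholders):

#include <fstream>
#include <string>

// Sketch: emit "<jpg path>\t<xml path>" lines from a VOC2007 image-ID list.
int main() {
  const std::string voc_root = "/home/test/data/VOC2007";  // placeholder
  std::ifstream ids(voc_root + "/ImageSets/Main/trainval.txt");
  std::ofstream list("voc2007_trainval.lst");
  std::string id;
  while (ids >> id) {
    list << voc_root << "/JPEGImages/" << id << ".jpg" << "\t"
         << voc_root << "/Annotations/" << id << ".xml" << "\n";
  }
  return 0;
}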

4. What does the conversion flow look like?

Once the data list is ready, the conversion flow can be laid out. For the classification task it is, roughly:

1. Initialize a RecordWriter.
2. Process the list line by line; every line (an image path plus its labels) becomes one Example.
3. Read the image into a cv::Mat with OpenCV and convert it to a std::string (not a char*, because the pixel data may contain '\0'), then store it in the Example's feature map under the key image_raw.
4. Store the image width, height, channel count, and label in the same feature map under the keys width, height, depth, and label.
5. Serialize each Example with SerializeToString and write it out with writer_->WriteRecord.

The detection task differs only in that the XML annotation file is also parsed and the bounding-box information is stored as additional features.

The headers needed are:

#include <fcntl.h>
#include <stdio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <boost/foreach.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
#include <fstream>
#include <iostream>
#include <map>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <vector>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include <boost/lexical_cast.hpp>
#include "rng.hpp"
using namespace tensorflow::io;
using namespace tensorflow;

The dispatch in the main function:

if ((dataset_type == "object_detect") && (label_map_file.length() > 0)) {
  // Detection task: datalist_file is the list file, label_map_file maps class
  // names to integer labels, output_dir/output_name are the output path and
  // file name prefix of the TFRecord files, samples_pre is the number of
  // records per TFRecord file, and Shuffle controls whether the list is shuffled.
  if (!detecteddata_to_tfrecords(datalist_file, label_map_file, output_dir, output_name,
                                 samples_pre, Shuffle)) {
    printf("convert wrong!!!\n");
    return false;
  }
} else if ((dataset_type == "classification") && (label_width > 0)) {
  // Classification task: same parameters as above, plus label_width, the number
  // of labels per line (single-label vs. multi-label).
  if (!clsdata_to_tfrecords(datalist_file, output_dir, output_name, samples_pre, label_width,
                            Shuffle)) {
    printf("convert wrong!!!\n");
    return false;
  }
} else {
  printf(
      "dataset type is not object_detect or classification, or label_width [%lu], label_map_file "
      "[%s] is wrong!!!\n",
      label_width, label_map_file.c_str());
  return false;
}
// Optional: delete all global objects allocated by libprotobuf, releasing the
// protobuf resources opened in the helper functions.
google::protobuf::ShutdownProtobufLibrary();

For the classification task, the code is as follows:

bool clsdata_to_tfrecords(string datalist_file, string output_dir, string output_name,
                          int samples_pre, size_t label_width, int Shuffle) {
  std::ifstream infile(datalist_file.c_str());
  std::string line;
  std::vector<std::pair<string, std::vector<int> > > dataset;
  // Read the list file and store each (filename, labels) pair in dataset.
  while (getline(infile, line)) {
    vector<string> tmp_str = param_split(line, "\t");
    std::string filename;
    std::vector<int> label_v;
    if (tmp_str.size() != (label_width + 1)) {
      std::cout << "line " << line << " has the wrong number of fields!!!" << std::endl;
      return false;
    }
    for (size_t i = 0; i < (label_width + 1); ++i) {
      if (i == 0) {
        filename = tmp_str[0];
      } else {
        try {
          int label = boost::lexical_cast<int>(tmp_str[i]);
          label_v.push_back(label);
        } catch (boost::bad_lexical_cast& e) {
          printf("%s\n", e.what());
          return false;
        }
      }
    }
    if (filename.size() > 0) dataset.push_back(std::make_pair(filename, label_v));
  }

  // Shuffle the dataset; this reuses the rng.hpp code from Caffe.
  if (Shuffle) {
    printf("tensorflow task will be shuffled!!!\n");
    caffe::shuffle(dataset.begin(), dataset.end());
  }
  printf("A total of %lu images.\n", dataset.size());
  // Create the RecordWriter.
  std::unique_ptr<WritableFile> file;
  RecordWriterOptions options = RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  RecordWriter* writer_ = NULL;
  int j = 0, fidx = 0;
  size_t line_id = 0;
  for (line_id = 0; line_id < dataset.size(); ++line_id) {
    if (line_id == 0 || j >= samples_pre) {
      // On the first record, or when the current file has reached the
      // samples_pre limit, close the old writer and start a new TFRecord file.
      if (writer_ != NULL) {
        delete writer_;
        writer_ = NULL;
      }
      if (file) file->Close();  // flush the previous file to disk
      char output_file[1024];
      memset(output_file, 0, 1024);
      sprintf(output_file, "%s/%s_%03d.tfrecord", output_dir.c_str(), output_name.c_str(), fidx);
      printf("create new tfrecord file: [%s] \n", output_file);
      Status s = Env::Default()->NewWritableFile((string)output_file, &file);
      if (!s.ok()) {
        printf("create write record file [%s] wrong!!!\n", output_file);
        return false;
      }
      writer_ = new RecordWriter(file.get(), options);
      j = 0;
      fidx += 1;
    }
    // Read the image.
    cv::Mat image = ReadImageToCVMat(dataset[line_id].first);
    if (!image.data) {
      printf("image: [%s] could not be read\n", dataset[line_id].first.c_str());
      continue;
    }
    // Convert the Mat to a string.
    std::string image_b = matToBytes(image);
    int height = image.rows;
    int width = image.cols;
    int depth = image.channels();
    // Each line of the list becomes one Example.
    Example example1;
    Features* features1 = example1.mutable_features();
    ::google::protobuf::Map<string, Feature>* feature1 = features1->mutable_feature();
    Feature feature_tmp;
    feature_tmp.Clear();
    if (!bytes_feature(feature_tmp, image_b)) {
      printf("image: [%s] wrong\n", dataset[line_id].first.c_str());
      continue;
    }
    (*feature1)["image_raw"] = feature_tmp;
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, height)) {
      printf("image: [%s] , height [%d] wrong\n", dataset[line_id].first.c_str(), height);
      continue;
    }
    (*feature1)["height"] = feature_tmp;
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, width)) {
      printf("image: [%s] , width [%d] wrong\n", dataset[line_id].first.c_str(), width);
      continue;
    }
    (*feature1)["width"] = feature_tmp;
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, depth)) {
      printf("image: [%s] , depth [%d] wrong\n", dataset[line_id].first.c_str(), depth);
      continue;
    }
    (*feature1)["depth"] = feature_tmp;
    // This assumes the labels in the list are already integers such as 0,1,2,...;
    // otherwise a name-to-label conversion step is needed here.
    feature_tmp.Clear();
    if (!int64_feature(feature_tmp, dataset[line_id].second)) {
      printf("image: [%s] wrong\n", dataset[line_id].first.c_str());
      continue;
    }
    (*feature1)["label"] = feature_tmp;
    // Serialize the Example to a string and write it with the RecordWriter.
    std::string str;
    example1.SerializeToString(&str);
    writer_->WriteRecord(str);
    ++j;
    if (line_id % 1000 == 0) {
      printf("Processed %lu files.\n", line_id);
    }
  }
  printf("Processed %lu files, finished.\n", line_id);
  if (writer_ != NULL) {
    delete writer_;
    writer_ = NULL;
  }
  if (file) file->Close();
  return true;
}
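The list-reading loop above calls a helper named param_split that is not shown in this post (it presumably lives in the common.cpp compiled by the Makefile below). A minimal sketch, assuming it simply splits a line on the given delimiter, would be:

#include <string>
#include <vector>

// Split a line on the given delimiter and return the fields.
// (Sketch only; the original helper is not shown in this article.)
std::vector<std::string> param_split(const std::string& line, const std::string& delim) {
  std::vector<std::string> fields;
  size_t start = 0, pos = 0;
  while ((pos = line.find(delim, start)) != std::string::npos) {
    fields.push_back(line.substr(start, pos - start));
    start = pos + delim.size();
  }
  fields.push_back(line.substr(start));  // last field (or the whole line if no delimiter)
  return fields;
}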

The matToBytes function used above is defined as follows:

std::string matToBytes(cv::Mat image) {
  // Copy the raw pixel buffer into a std::string; a std::string can safely hold
  // '\0' bytes. This assumes a continuous Mat, which is what cv::imread returns.
  int size = image.total() * image.elemSize();
  return std::string(reinterpret_cast<const char*>(image.data), size);
}
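ReadImageToCVMat is likewise used without being shown; it comes from Caffe's image I/O utilities. For this purpose a minimal stand-in (a sketch without Caffe's resize and grayscale options) is just a wrapper around cv::imread:

#include <cstdio>
#include <opencv2/highgui/highgui.hpp>
#include <string>

// Minimal stand-in for Caffe's ReadImageToCVMat: load the image in color.
cv::Mat ReadImageToCVMat(const std::string& filename) {
  cv::Mat image = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
  if (!image.data) {
    printf("Could not open or find file %s\n", filename.c_str());
  }
  return image;
}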

The helpers that turn a string, an int, or a vector<int> into a Feature are defined as follows:

// Overloads so that both int and vector<int> can be converted to a Feature.
bool int64_feature(Feature& feature, int value) {
  Int64List* i_list1 = feature.mutable_int64_list();
  i_list1->add_value(value);
  return true;
}
bool int64_feature(Feature& feature, std::vector<int> value) {
  if (value.size() < 1) {
    printf("value int64 is wrong!!!");
    return false;
  }
  Int64List* i_list1 = feature.mutable_int64_list();
  for (size_t i = 0; i < value.size(); ++i) i_list1->add_value(value[i]);
  return true;
}
bool float_feature(Feature& feature, std::vector<double> value) {
  if (value.size() < 1) {
    printf("value float is wrong!!!");
    return false;
  }
  FloatList* f_list1 = feature.mutable_float_list();
  for (size_t i = 0; i < value.size(); ++i) f_list1->add_value(value[i]);
  return true;
}
// Convert the raw image bytes to a Feature.
bool bytes_feature(Feature& feature, std::string value) {
  BytesList* b_list1 = feature.mutable_bytes_list();
  // The image data may contain '\0' bytes, which is why a std::string is used here.
  b_list1->add_value(value);
  return true;
}

The detection task follows largely the same flow. The list-reading code differs slightly, and the XML annotation files additionally have to be parsed, which can be done with Boost's property_tree XML parser. The core per-image routine looks roughly like this:

bool ReadXMLToExample(const string& image_file, const string& xmlfile, const int img_height,
                      const int img_width, const std::map<string, int>& name_to_label,
                      RecordWriter* writer_) {
  // Read the image.
  cv::Mat image = ReadImageToCVMat(image_file);
  if (!image.data) {
    cout << "Could not open or find file " << image_file;
    return false;
  }
  // Convert the Mat to a string.
  std::string image_b = matToBytes(image);
  Example example1;
  Features* features1 = example1.mutable_features();
  ::google::protobuf::Map<string, Feature>* feature1 = features1->mutable_feature();
  Feature feature_tmp;
  feature_tmp.Clear();
  if (!bytes_feature(feature_tmp, image_b)) {
    printf("image: [%s] wrong\n", image_file.c_str());
    return false;
  }
  (*feature1)["image/encoded"] = feature_tmp;
  ptree pt;
  read_xml(xmlfile, pt);
  // Parse the annotation.
  int width = 0, height = 0, depth = 0;
  try {
    height = pt.get<int>("annotation.size.height");
    width = pt.get<int>("annotation.size.width");
    depth = pt.get<int>("annotation.size.depth");
  } catch (const ptree_error& e) {
    std::cout << "when parsing " << xmlfile << ":" << e.what() << std::endl;
    height = img_height;
    width = img_width;
    return false;
  }
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, height)) {
    printf("xml : [%s] 's height wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/height"] = feature_tmp;
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, width)) {
    printf("xml : [%s] 's width wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/width"] = feature_tmp;
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, depth)) {
    printf("xml : [%s] 's depth wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/depth"] = feature_tmp;
  std::vector<int> v_label;
  std::vector<int> v_difficult;
  std::vector<double> v_xmin;
  std::vector<double> v_ymin;
  std::vector<double> v_xmax;
  std::vector<double> v_ymax;
  BOOST_FOREACH (ptree::value_type& v1, pt.get_child("annotation")) {
    ptree pt1 = v1.second;
    if (v1.first == "object") {
      bool difficult = false;
      ptree object = v1.second;
      BOOST_FOREACH (ptree::value_type& v2, object.get_child("")) {
        ptree pt2 = v2.second;
        if (v2.first == "name") {
          string name = pt2.data();
          if (name_to_label.find(name) == name_to_label.end()) {
            std::cout << "file : [" << xmlfile << "] Unknown name: " << name << std::endl;
            return true;
          }
          int label = name_to_label.find(name)->second;
          v_label.push_back(label);
        } else if (v2.first == "difficult") {
          difficult = pt2.data() == "1";
          v_difficult.push_back(difficult);
        } else if (v2.first == "bndbox") {
          int xmin = pt2.get("xmin", 0);
          int ymin = pt2.get("ymin", 0);
          int xmax = pt2.get("xmax", 0);
          int ymax = pt2.get("ymax", 0);
          if ((xmin > width) || (ymin > height) || (xmax > width) || (ymax > height) ||
              (xmin < 0) || (ymin < 0) || (xmax < 0) || (ymax < 0)) {
            std::cout << "bounding box exceeds image boundary." << std::endl;
            return false;
          }
          v_xmin.push_back(xmin);
          v_ymin.push_back(ymin);
          v_xmax.push_back(xmax);
          v_ymax.push_back(ymax);
        }
      }
    }
  }
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, v_label)) {
    printf("xml : [%s]'s label wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/label"] = feature_tmp;
  feature_tmp.Clear();
  if (!int64_feature(feature_tmp, v_difficult)) {
    printf("xml : [%s]'s difficult wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/difficult"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_xmin)) {
    printf("xml : [%s]'s v_xmin wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/xmin"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_ymin)) {
    printf("xml : [%s]'s v_ymin wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/ymin"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_xmax)) {
    printf("xml : [%s]'s v_xmax wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/xmax"] = feature_tmp;
  feature_tmp.Clear();
  if (!float_feature(feature_tmp, v_ymax)) {
    printf("xml : [%s]'s v_ymax wrong\n", xmlfile.c_str());
    return false;
  }
  (*feature1)["image/object/bbox/ymax"] = feature_tmp;
  // Serialize the Example and write it with the RecordWriter.
  std::string str;
  example1.SerializeToString(&str);
  writer_->WriteRecord(str);
  return true;
}
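The name_to_label map passed into ReadXMLToExample is built from label_map_file. Its real format in this project is defined by label.proto (compiled in the Makefile below) and is not shown here; purely as an illustration, a map loaded from a plain text file with one name<TAB>label pair per line could look like this:

#include <fstream>
#include <map>
#include <sstream>
#include <string>

// Illustration only: load "name<TAB>label" pairs from a plain text file.
// The real label_map_file in this project follows label.proto instead.
bool LoadNameToLabel(const std::string& label_map_file, std::map<std::string, int>* name_to_label) {
  std::ifstream infile(label_map_file.c_str());
  if (!infile) return false;
  std::string line;
  while (std::getline(infile, line)) {
    std::istringstream iss(line);
    std::string name;
    int label;
    if (iss >> name >> label) (*name_to_label)[name] = label;
  }
  return !name_to_label->empty();
}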

Finally, the Makefile used to build everything (here ${TENSORFLOW_HOME} stands for the path to the TensorFlow source tree):

all:
	rm -rf example.pb*
	${PROTOBUF_HOME}/bin/protoc -I=. --cpp_out=./ example.proto		
	${PROTOBUF_HOME}/bin/protoc -I=. --cpp_out=./ label.proto		
	g++ -std=c++11 -o dataset_to_tfrecord dataset_to_tfrecord.cc example.pb.cc common.cpp -I/usr/local/opencv2/include -L/usr/local/opencv2/lib -L. -lopencv_core -lopencv_highgui -lopencv_imgproc -I${TENSORFLOW_HOME} -I${TENSORFLOW_HOME}/bazel-genfiles -I${PROTOBUF_HOME}/include -I/usr/local/include/eigen3 -L${PROTOBUF_HOME}/lib -L${TENSORFLOW_HOME}/bazel-bin/tensorflow/ -lprotobuf -ltensorflow_framework -I${JSONCPP_HOME}/include -L${JSONCPP_HOME}/lib -ljson_linux-gcc-5.4.0_libmt
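Once the converter has run, it is worth reading a few records back to confirm that the shards parse. Below is a minimal sketch for the classification output, assuming the same TensorFlow 1.x headers used above plus record_reader.h (on newer TensorFlow versions ReadRecord takes a tstring rather than a std::string); the shard file name is a placeholder:

#include <iostream>
#include <memory>
#include <string>
#include "example.pb.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"

using namespace tensorflow;
using namespace tensorflow::io;

int main() {
  // Open one of the generated shards (placeholder name).
  std::unique_ptr<RandomAccessFile> file;
  Status s = Env::Default()->NewRandomAccessFile("lfw_train_000.tfrecord", &file);
  if (!s.ok()) { std::cerr << s.ToString() << std::endl; return 1; }

  RecordReaderOptions options = RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  SequentialRecordReader reader(file.get(), options);

  std::string record;
  int count = 0;
  while (reader.ReadRecord(&record).ok()) {
    Example example;
    if (!example.ParseFromString(record)) break;
    // Print the height/width features written by the classification converter.
    const auto& fmap = example.features().feature();
    std::cout << "record " << count++
              << " height=" << fmap.at("height").int64_list().value(0)
              << " width=" << fmap.at("width").int64_list().value(0) << std::endl;
  }
  return 0;
}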

