赞
踩
关于TensorFlow读取数据,官网给出了三种方法:
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yeild 使用更为简洁)。但是如果数据量较大,这样的方法就不适用了。因为太耗内存,所以这时最好使用TensorFlow提供的队列queue,也就是第二种方法:从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这里我们学习一种比较通用的,高效的读取方法,即使用TensorFlow内定标准格式——TFRecords。
TensorFlow提供了一种统一的格式来存储数据,这个格式就是TFRecords。
一种保存记录的方法可以允许你讲任意的数据转换为TensorFlow所支持的格式,这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件。
TFRecord是谷歌推荐的一种二进制文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。
1 2 3 4 |
|
但是对于我们普通开发者而言,我们并不需要关心这些,TensorFlow提供了丰富的API可以帮助我们轻松地读写TFRecord文件。
TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:
1 2 3 4 5 6 7 |
|
通过上述操作,我们以dict的形式把要写入的数据汇总,并构建 tf.train.Features,然后构建 tf.train.Example。如下:
1 2 3 4 5 6 7 8 9 10 11 |
|
把创建的tf.train.Example序列化下,便可以通过 tf.python_io.TFRecordWriter 写入 tfrecord文件中,如下:
1 2 3 4 5 6 7 8 9 10 |
|
TFRecord 的核心内容在于内部有一系列的Example,Example 是protocolbuf 协议(protocolbuf 是通用的协议格式,对主流的编程语言都适用。所以这些 List对应到Python语言当中是列表。而对于Java 或者 C/C++来说他们就是数组)下的消息体。
一个Example消息体包含了一系列的feature属性。每一个feature是一个map,也就是 key-value 的键值对。key 取值是String类型。而value是Feature类型的消息体。下面代码给出了 tf.train.Example的定义:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
|
从上面的代码可以看出 tf.train.example 的数据结构是比较简洁的。tf.train>example中包含了一个从属性名称到取值的字典。其中属性名称为一个字符串,属性的取值为字符串(ByteList),实数列表(FloatList)或者整数列表(Int64List),举个例子,比如将一张解码前的图像存为一个字符串,图像所对应的类别编码存为整数列表,所以可以说TFRecord 可以存储几乎任何格式的信息。
TFRerecord也不是非用不可,但确实是谷歌官网推荐的文件格式。
TFRecords 其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便赋值和移动,并且不需要单独的标签文件,理论上,它能保存所有的信息。总而言之,这样的文件格式好处多多,所以让我们利用起来。
使用TensorFlow进行网格训练时,为了提高读取数据的效率,一般建议将训练数据转化为TFrecords格式。
使用tensorflow官网例子练习,我们会发现基本都是MNIST,CIFAR_10这种做好的数据集说事。所以对于我们这些初学者,完全不知道图片该如何输入。这时候学习自己制作数据集就非常有必要了。
我们可以使用TFWriter轻松的完成这个任务。但是制作之前,我们要明确自己的目的。我们必须要想清楚,需要把什么信息存储到TFRecord 文件当中,这其实是最重要的。
下面我们将一张图片转化为TFRecord,然后读取一张TFRecord文件,并展示为图片。
4.1 将一张图片转化成TFRecord 文件
下面举例说明尝试把图片转化成TFRecord 文件。
首先定义Example 消息体。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
|
上面的Example表示,要将一张 cat 图片信息写进了 TFRecord 当中。而图片信息包含了图片的名字,图片的维度信息还有图片的数据,分别对应了 name,shape,content 3个feature。
下面我们尝试使用代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
|
上述代码注释比较详细,所以我们就重点说一下下面三点:
4.2 TFRecord 文件读取为图片
我们将图片的信息写入到一个tfrecord文件当中。现在我们需要检验它是否正确。这就需要用到如何读取TFRecord 文件的知识点了。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
|
下面解释一下代码:
1,首先使用dataset去读取tfrecord文件
2,在解析example 的时候,用现成的API:tf.parse_single_example
3,用 np.fromstring() 方法就可以获取解析后的string数据,记得把数据还原成 np.uint8
4,用 tf.image.encode_jepg() 方法可以将图片数据编码成 jpeg 格式
5,用 tf.gfile.GFile 对象可以把图片数据保存到本地
6,因为将图片 shape 写入了example 中,所以解析的时候必须指定维度,在这里 [3],不然程序会报错。
运行程序后,可以看到图片显示如下:
下面我们将一个文件夹的图片转化为TFRecord,然后再将TFRecord读取为图片。
5.1 将一个文件夹下多张图片转化为一个TFRecord文件
下面举例说明尝试把图片转化成TFRecord 文件。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
|
5.2 将一个TFRecord文件转化为图片显示
因为图片太多,所以我们这里只展示每个文件夹中第一张图片即可。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
|
部分代码解析:
5.3,加入队列
1 2 3 4 5 6 |
|
注意,启动队列那条code不能忘记,不然会卡死,这样加入后,就可以做到和tensorflow官网一样的二进制数据集了。
当图片数据很多时候,会导致单个record文件超级巨大的情况,解决方法就是,将数据分成多个record文件保存,读取时,只需要将多个record文件的路径列表交给“tf.train.string_input_producer”,
完整代码如下:(此处来自 此博客)
之前,我们都是将数据转存为tfrecord文件,训练时候再去读取,如果不想转为record文件,想直接读取图像文件进行训练,可以使用下面的方法:
filename.txt
1 2 3 4 5 6 7 8 9 10 |
|
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
|
TensorFlow引入了tf.data.Dataset模块,使其数据读入的操作变得更为方便,而支持多线程(进程)的操作,也在效率上获得了一定程度的提高。使用tf.data.Dataset模块的pipline机制,可实现CPU多线程处理输入的数据,如读取图片和图片的一些的预处理,这样GPU可以专注于训练过程,而CPU去准备数据。
参考资料:
1 2 3 |
|
从tfrecord文件创建TFRecordDataset方式如下:
1 2 |
|
解析tfrecord 文件的每条记录,即序列化后的 tf.train.Example;使用 tf.parse_single_example 来解析:
1 |
|
其中,data_dict 是一个dict,包含的key 是写入tfrecord文件时用的key ,相应的value是对应不同的数据类型,我们直接使用代码看,如下:
1 2 3 4 5 6 7 8 |
|
解析tfrecord文件中的所有记录,我们需要使用dataset 的map 方法,如下:
1 |
|
Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。常用的Transformation有:map、batch、shuffle和repeat。
map方法可以接受任意函数对dataset中的数据进行处理;另外可以使用repeat,shuffle,batch方法对dataset进行重复,混洗,分批;用repeat赋值dataset以进行多个epoch;如下:
1 |
|
解析完数据后,便可以取出数据进行使用,通过创建iterator来进行,如下:
1 2 3 |
|
下面分别介绍
8.1,map
使用 tf.data.Dataset.map
,我们可以很方便地对数据集中的各个元素进行预处理。因为输入元素之间时独立的,所以可以在多个 CPU 核心上并行地进行预处理。map
变换提供了一个 num_parallel_calls
参数去指定并行的级别。
1 |
|
8.2,prefetch
tf.data.Dataset.prefetch 提供了 software pipelining 机制。该函数解耦了 数据产生的时间 和 数据消耗的时间。具体来说,该函数有一个后台线程和一个内部缓存区,在数据被请求前,就从 dataset 中预加载一些数据(进一步提高性能)。prefech(n) 一般作为最后一个 transformation,其中 n 为 batch_size。 prefetch 的使用方法如下:
1 2 3 |
|
8.3,repeat
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:
如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常
8.4,完整代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
|
当我们使用tf 中的 dataset时,可能会出现如下错误:
原因是tf 版本不同导致的错误。
在编写代码的时候,使用的tf版本不同,可能导致其Dataset API 放置的位置不同。当使用TensorFlow1.3的时候,Dataset API是放在 contrib 包里面,而当使用TensorFlow1.4以后的版本,Dataset API已经从contrib 包中移除了,而变成了核心API的一员。故会产生报错。
解决方法:
将下面代码:
1 2 |
|
改为此代码:
1 2 |
|
问题解决。
函数如下:
1 |
|
函数功能:实现对图片的读取
函数参数:path:图片所在路径
decodestyle:图片的解码方式(‘r’:UTF-8编码; ‘rb’:非UTF-8编码)
例子如下:
1 |
|
zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用*号操作符,可以将元组解压为列表。
在 Python 3.x 中为了减少内存,zip() 返回的是一个对象。如需展示列表,需手动 list() 转换。
1 2 3 4 5 |
|
实例:
1 2 3 4 5 6 7 8 9 10 11 12 |
|
1,为什么前面使用Dataset,而用大多数博文中的 QueueRunner 呢?
A:这是因为 Dataset 比 QueueRunner 新,而且是官方推荐的,Dataset 比较简单。
2,学习了 TFRecord 相关知识,下一步学习什么?
A:可以尝试将常见的数据集如 MNIST 和 CIFAR-10 转换成 TFRecord 格式。
A:可以尝试将常见的数据集如 MNIST 和 CIFAR-10 转换成 TFRecord 格式。
参考文献:https://blog.csdn.net/u012759136/article/details/52232266
https://blog.csdn.net/tengxing007/article/details/56847828/
https://blog.csdn.net/briblue/article/details/80789608 (五星推荐)
https://blog.csdn.net/happyhorizion/article/details/77894055 (五星推荐)
不经一番彻骨寒 怎得梅花扑鼻香
TensorFlow直接读取图片和读写TFRecords速度对比
https://www.cnblogs.com/wj-1314/p/11211333.html
https://blog.csdn.net/kwame211/article/details/78579035
https://www.jianshu.com/p/15e3f74180fc
https://www.2cto.com/kf/201702/604326.html
https://www.2cto.com/kf/201702/604326.html
https://blog.csdn.net/lingtianyulong/article/details/80555908
https://www.sohu.com/a/219765050_717210
https://github.com/YJango/TFRecord-Dataset-Estimator-API/blob/master/tfrecorder.py
https://www.jiqizhixin.com/articles/2018-07-06-4
https://blog.csdn.net/weixin_42111770/article/details/87920048
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。