TensorFlow 提供了标准的 TFRecord 格式;关于 TensorFlow 读取数据,官网提供了 3 种方法:
1 Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据
2 Reader:在计算图(tf.Graph)的起始处,将文件读入到队列(queue)中
3 Preloaded data(预加载数据):在声明 tf.Variable 变量或 numpy 数组时保存数据。受限于内存大小,仅适用于数据量较小的情况
注意:下面的代码会根据输入图片所在的类别文件夹,给同一类的所有图片打上相同的标签(即该类别在 classes 中的序号)。本例代码中有 Cutting、Fatigue、Normal 三类,标签取值为 0、1、2
- # -----------------------------------------------------------------------------
- # encoding=utf-8
- import os
- import tensorflow as tf
- from PIL import Image
# Root directory that contains one sub-directory of images per class.
cwd = 'ferro/train//'
# Ordered tuple rather than a set: enumerate() over a set of strings can yield
# a different order on each run, which would silently change the label
# assigned to each class.
classes = ('Cutting', 'Fatigue', 'Normal')
- # 制作TFRecords数据
def create_record():
    """Serialize the class images under ``cwd`` into ``ferro_train.tfrecords``.

    For every class name in ``classes`` the matching sub-directory of ``cwd``
    is listed; each image is converted to RGB, resized to 64x64 and written
    as one ``tf.train.Example`` holding two features:

    * ``label``   -- int64 index of the class
    * ``img_raw`` -- raw pixel bytes of the resized image
    """
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter("ferro_train.tfrecords") as writer:
        # Sort the names so every class receives the same label on every run,
        # even when ``classes`` is an unordered set.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(cwd, name)
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                with Image.open(img_path) as src:
                    # Force three channels: read_and_decode reshapes the bytes
                    # to [64, 64, 3], which fails for grayscale/RGBA input.
                    img_raw = src.convert('RGB').resize((64, 64)).tobytes()
                # Log the label and file instead of dumping the raw bytes.
                print(index, img_path)
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[index])),
                        'img_raw': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[img_raw])),
                    }))
                writer.write(example.SerializeToString())
TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。
- # 读取二进制数据
def read_and_decode(filename):
    """Build the graph ops that produce one decoded (image, label) pair.

    Args:
        filename: Path of the TFRecords file to read from.

    Returns:
        A tuple ``(img, label)`` of tensors: ``img`` is the uint8 pixel data
        reshaped to [64, 64, 3], ``label`` is the stored label cast to int32.
    """
    # File-name queue feeding the reader; num_epochs is left unset.
    file_queue = tf.train.string_input_producer([filename])

    # Read one serialized example from the queue (the record key is unused).
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(file_queue)

    # Feature names and types must match those used when writing the file.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> uint8 vector -> 64x64 three-channel image.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    img = tf.reshape(pixels, [64, 64, 3])
    # NOTE: pixel values are left as uint8; no float normalization is applied.
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
一个 Example 中包含 Features,Features 里包含 Feature(这里没有 s)的字典。最后,Feature 里包含一个 FloatList、BytesList 或者 Int64List。另外需要注意:feature 的名称 “label” 和 “img_raw” 必须与制作 TFRecords 时保持一致,返回的 img 数据和 label 数据一一对应。
if __name__ == '__main__':
    # Write the TFRecords file, then read it back and dump the decoded images
    # to disk so the stored labels can be checked by eye.
    create_record()
    batch = read_and_decode('ferro_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # The ``with`` block closes the session, so no explicit sess.close().
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(200):
                # Fetch one (image, label) pair from the input pipeline.
                example, lab = sess.run(batch)
                img = Image.fromarray(example, 'RGB')
                # File name pattern: <index>_Label_<label>.jpg inside cwd.
                img.save(cwd + '/' + str(i) + '_Label_' + str(lab) + '.jpg')
                print(example, lab)
        finally:
            # Always stop and join the queue-runner threads, even when the
            # loop above raises, so the process can exit cleanly.
            coord.request_stop()
            coord.join(threads)
- # -----------------------------------------------------------------------------
- # encoding=utf-8
- import os
- import tensorflow as tf
- from PIL import Image
# Root directory that contains one sub-directory of images per class.
cwd = 'ferro/train//'
# Ordered tuple rather than a set: enumerate() over a set of strings can yield
# a different order on each run, which would silently change the label
# assigned to each class.
classes = ('Cutting', 'Fatigue', 'Normal')
- # 制作TFRecords数据
def create_record():
    """Serialize the class images under ``cwd`` into ``ferro_train.tfrecords``.

    For every class name in ``classes`` the matching sub-directory of ``cwd``
    is listed; each image is converted to RGB, resized to 64x64 and written
    as one ``tf.train.Example`` holding two features:

    * ``label``   -- int64 index of the class
    * ``img_raw`` -- raw pixel bytes of the resized image
    """
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter("ferro_train.tfrecords") as writer:
        # Sort the names so every class receives the same label on every run,
        # even when ``classes`` is an unordered set.
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(cwd, name)
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                with Image.open(img_path) as src:
                    # Force three channels: read_and_decode reshapes the bytes
                    # to [64, 64, 3], which fails for grayscale/RGBA input.
                    img_raw = src.convert('RGB').resize((64, 64)).tobytes()
                # Log the label and file instead of dumping the raw bytes.
                print(index, img_path)
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[index])),
                        'img_raw': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[img_raw])),
                    }))
                writer.write(example.SerializeToString())
- # -------------------------------------------------------------------------
- # 读取二进制数据
def read_and_decode(filename):
    """Build the graph ops that produce one decoded (image, label) pair.

    Args:
        filename: Path of the TFRecords file to read from.

    Returns:
        A tuple ``(img, label)`` of tensors: ``img`` is the uint8 pixel data
        reshaped to [64, 64, 3], ``label`` is the stored label cast to int32.
    """
    # File-name queue feeding the reader; num_epochs is left unset.
    file_queue = tf.train.string_input_producer([filename])

    # Read one serialized example from the queue (the record key is unused).
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(file_queue)

    # Feature names and types must match those used when writing the file.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> uint8 vector -> 64x64 three-channel image.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    img = tf.reshape(pixels, [64, 64, 3])
    # NOTE: pixel values are left as uint8; no float normalization is applied.
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
- # --------------------------------------------------------------------------
- # ---------主程序----------------------------------------------------------
if __name__ == '__main__':
    # Write the TFRecords file, then read it back and dump the decoded images
    # to disk so the stored labels can be checked by eye.
    create_record()
    # This assignment must stay active: the loop below runs ``batch`` and
    # would raise NameError if the line were commented out.
    batch = read_and_decode('ferro_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # The ``with`` block closes the session, so no explicit sess.close().
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(200):
                # Fetch one (image, label) pair from the input pipeline.
                example, lab = sess.run(batch)
                img = Image.fromarray(example, 'RGB')
                # File name pattern: <index>_Label_<label>.jpg inside cwd.
                img.save(cwd + '/' + str(i) + '_Label_' + str(lab) + '.jpg')
                print(example, lab)
        finally:
            # Always stop and join the queue-runner threads, even when the
            # loop above raises, so the process can exit cleanly.
            coord.request_stop()
            coord.join(threads)
- # -----------------------------------------------------------------------------
运行上述的完整代码,便可以 将从TFRecord中取出的文件保存下来了。如下图:
每一幅图片的命名中,第二个数字则是 label,Cut都为1,Normal都为0;通过对照图片,可以发现图片分类正确。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。