赞
踩
目录
TFRecord 是 TensorFlow 中的数据集中存储格式,TFRecord是一种二进制文件。
将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而更高效地进行大规模的模型训练。
TFRecord 内部使用了二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可。简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个 TFRecord 文件,来提高处理效率。
将形式各样的数据集整理为 TFRecord 格式,可以对数据集中的每个元素进行以下步骤:
(1)读取该数据元素到内存;
(2)将该元素转换为 tf.train.Example 对象(每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成,因此需要先建立 Feature 的字典);
(3)将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。
读取 TFRecord 数据可按照以下步骤:
(1)通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件(此时文件中的 tf.train.Example 对象尚未被反序列化),获得一个 tf.data.Dataset数据集对象;
(2)通过 Dataset.map 方法,对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数,从而实现反序列化。
TFRecord内部包含了多个tf.train.Example
, 而Example
是protocol buffer
(protobuf) 数据标准的实现,在一个Example
消息体中包含了一系列的tf.train.feature
属性,而每一个feature
是一个key-value
的键值对,其中,key
是string类型,而value
的取值有三种:
bytes_list:
可以存储string
和byte
两种数据类型。float_list:
可以存储float(float32)
与double(float64)
两种数据类型 。int64_list:
可以存储bool, enum, int32, uint32, int64, uint64
。tf.train.Feature 支持三种数据格式:
如果只希望保存一个元素而非数组,传入一个只有一个元素的数组即可
- import os
- import tensorflow as tf
-
- # 读取数据集中图片文件名和标签
- def read_image_filenames (data_dir) :
- cat_dir = data_dir + "cat/"
- dog_dir = data_dir + "dog/"
-
- cat_filenames = [cat_dir + fn for fn in os.listdir(cat_dir)]
- dog_filenames = [dog_dir + fn for fn in os.listdir(dog_dir)]
- filenames = cat_filenames + dog_filenames
-
- # 将cat类的标签设为0, dog类的标签设为1
- labels = [0]* len(cat_filenames) + [1] *len(dog_filenames)
-
- return filenames,labels
- # 定义生成TFRecord格式数据文件函数
- def write_TFRecord_file(filenames,labels,tfrecord_file):
- with tf.io.TFRecordWriter(tfrecord_file) as writer:
- for filename,label in zip(filenames,labels) :
- # 读取数据集图片到内存,image 为一个 Byte类型的字符串
- image = open(filename,"rb").read()
- # 建立tf.train.Feature字典
- feature = {
- # 图片是一个Bytes对象
- 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
- # 标签是一个Int对象
- 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
- }
- # 通过feature字典建立Example
- example = tf.train.Example(features=tf.train.Features(feature=feature))
- # 将Example序列化并写入TFRecord 文件
- writer.write(example.SerializeToString())
- train_data_dir = './data_small/train/' # 数据集路径
- tfrecord_file = train_data_dir + 'train.tfrecords' # 生成的tfrecord路径
-
- if not os.path.isfile(tfrecord_file): # 判断train.tfrecord是否存在
- train_filenames,train_labels = read_image_filenames(train_data_dir)
- write_TFRecord_file(train_filenames,train_labels,tfrecord_file)
- print('write TFRecord file:',tfrecord_file)
- else:
- print(tfrecord_file,'already exists.')
- # 定义Feature结构,告诉解码器每个Feature的类型是什么,要与生成的TFrecord的类型一致
- feature_description = {
- "image":tf.io.FixedLenFeature([],tf.string),
- "label":tf.io.FixedLenFeature([],tf.int64)
- }
-
- # 将TFRecord 文件中的每一个序列化的 tf.train.Example 解码
- def parse_example(example_string):
- feature_dict = tf.io.parse_single_example(example_string,feature_description)
- feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码JPEG图片
- feature_dict['image'] = tf.image.resize(feature_dict['image'],[224,224])/ 255.0 # 改变图片尺寸并进行归一化
- return feature_dict['image'],feature_dict['label']
- def read_TFRecond_file(tfrecord_file):
- # 读取TFRecord 文件
- raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
- # 解码
- dataset = raw_dataset.map(parse_example)
-
- return dataset
tfrecord文件创建一个TFRecordDataset类的实例对象
参数:tf.data.TFRecordDataset(filenames,compression_type=None,
buffer_size=None,num_parallel_reads=None)
一般只传第一个参数filenames即可 ,生成的tfrecord文件
- # Dataset的数据缓冲器大小,和数据集大小及规律有关
- buffer_size = 20000
- # Dataset的数据批次大小,每批次多少个样本数
- batch_size = 8
- dataset_train = read_TFRecond_file(tfrecord_file) # 解码
- dataset_train = dataset_train.shuffle(buffer_size) # 打乱数据
- dataset_train = dataset_train.batch(batch_size) # 分批次进行读取
- import matplotlib.pyplot as plt
-
- sub_dataset = dataset_train.take(1) # 读取第一个批次
-
- for images,labels in sub_dataset:
- fig,axs = plt.subplots(1, batch_size)
- for i in range(batch_size):
- axs[i].set_title(labels.numpy()[i])
- axs[i].imshow(images.numpy()[i])
- axs[i].set_xticks([])
- axs[i].set_yticks([])
- plt.show()
案例实例地址:Tfrecord介绍以及实例· GitHub
链接:猫狗大战数据集
提取码:kqgt
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。