当前位置:   article > 正文

深度学习(1):加载图片数据集--TFRecord(Tensorflow)_tensorflow加载图片

tensorflow加载图片

目录

1、TFRecord介绍

2、TFRecord格式数据文件处理过程

3、TFRecord格式

4、生成TFRecord格式数据

5、TFRecord数据文件解码

6、解码并生成Dataset数据集

7、查看第一批元素


1、TFRecord介绍

TFRecord 是 TensorFlow 中的数据集中存储格式,TFRecord是一种二进制文件

将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而更高效地进行大规模的模型训练

TFRecord 内部使用了二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可。简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个 TFRecord 文件,来提高处理效率。

2、TFRecord格式数据文件处理过程

将形式各样的数据集整理为 TFRecord 格式,可以对数据集中的每个元素进行以下步骤:

(1)读取该数据元素到内存;

(2)将该元素转换为 tf.train.Example 对象(每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成,因此需要先建立 Feature 的字典);

(3)将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。

读取 TFRecord 数据可按照以下步骤:

(1)通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件(此时文件中的 tf.train.Example 对象尚未被反序列化),获得一个 tf.data.Dataset数据集对象;

(2)通过 Dataset.map 方法,对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数,从而实现反序列化。

3、TFRecord格式

TFRecord内部包含了多个tf.train.Example, 而Exampleprotocol buffer(protobuf) 数据标准的实现,在一个Example消息体中包含了一系列的tf.train.feature属性,而每一个feature 是一个key-value的键值对,其中,key 是string类型,而value 的取值有三种:

  • bytes_list:可以存储stringbyte两种数据类型。
  • float_list:可以存储float(float32)double(float64) 两种数据类型 。
  • int64_list:可以存储bool, enum, int32, uint32, int64, uint64 。

tf.train.Feature 支持三种数据格式:

  • tf.train.BytesList :字符串或原始 Byte 文件(如图片),通过 bytes_list 参数传入一个由字符串数组初始化的 tf.train.BytesList 对象;
  • tf.train.FloatList :浮点数,通过 float_list 参数传入一个由浮点数数组初始化的 tf.train.FloatList 对象;
  • tf.train.Int64List :整数,通过 int64_list 参数传入一个由整数数组初始化的 tf.train.Int64List 对象。

如果只希望保存一个元素而非数组,传入一个只有一个元素的数组即可

4、生成TFRecord格式数据

  1. import os
  2. import tensorflow as tf
  1. # 读取数据集中图片文件名和标签
  2. def read_image_filenames (data_dir) :
  3. cat_dir = data_dir + "cat/"
  4. dog_dir = data_dir + "dog/"
  5. cat_filenames = [cat_dir + fn for fn in os.listdir(cat_dir)]
  6. dog_filenames = [dog_dir + fn for fn in os.listdir(dog_dir)]
  7. filenames = cat_filenames + dog_filenames
  8. # 将cat类的标签设为0, dog类的标签设为1
  9. labels = [0]* len(cat_filenames) + [1] *len(dog_filenames)
  10. return filenames,labels
  1. # 定义生成TFRecord格式数据文件函数 
  2. def write_TFRecord_file(filenames,labels,tfrecord_file):
  3. with tf.io.TFRecordWriter(tfrecord_file) as writer:
  4. for filename,label in zip(filenames,labels) :
  5. # 读取数据集图片到内存,image 为一个 Byte类型的字符串
  6. image = open(filename,"rb").read()
  7. # 建立tf.train.Feature字典
  8. feature = {
  9. # 图片是一个Bytes对象
  10. 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
  11. # 标签是一个Int对象
  12. 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
  13. }
  14. # 通过feature字典建立Example
  15. example = tf.train.Example(features=tf.train.Features(feature=feature))
  16. # 将Example序列化并写入TFRecord 文件
  17. writer.write(example.SerializeToString())
  1. train_data_dir = './data_small/train/' # 数据集路径
  2. tfrecord_file = train_data_dir + 'train.tfrecords' # 生成的tfrecord路径
  3. if not os.path.isfile(tfrecord_file): # 判断train.tfrecord是否存在
  4. train_filenames,train_labels = read_image_filenames(train_data_dir)
  5. write_TFRecord_file(train_filenames,train_labels,tfrecord_file)
  6. print('write TFRecord file:',tfrecord_file)
  7. else:
  8. print(tfrecord_file,'already exists.')

5、TFRecord数据文件解码

 1、定义TFRecord数据文件解码函数

  1. # 定义Feature结构,告诉解码器每个Feature的类型是什么,要与生成的TFrecord的类型一致
  2. feature_description = {
  3. "image":tf.io.FixedLenFeature([],tf.string),
  4. "label":tf.io.FixedLenFeature([],tf.int64)
  5. }
  6. # 将TFRecord 文件中的每一个序列化的 tf.train.Example 解码
  7. def parse_example(example_string):
  8. feature_dict = tf.io.parse_single_example(example_string,feature_description)
  9. feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码JPEG图片
  10. feature_dict['image'] = tf.image.resize(feature_dict['image'],[224,224])/ 255.0 # 改变图片尺寸并进行归一化
  11. return feature_dict['image'],feature_dict['label']

 2、定义读取TFRecord文件,解码并生成Dataset数据集的函数

  1. def read_TFRecond_file(tfrecord_file):
  2. # 读取TFRecord 文件
  3. raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
  4. # 解码
  5. dataset = raw_dataset.map(parse_example)
  6. return dataset

3、tf.data.TFRecordDataset

tfrecord文件创建一个TFRecordDataset类的实例对象 

参数:tf.data.TFRecordDataset(filenames,compression_type=None,

                buffer_size=None,num_parallel_reads=None)
一般只传第一个参数filenames即可 ,生成的tfrecord文件

6、解码并生成Dataset数据集

  1. # Dataset的数据缓冲器大小,和数据集大小及规律有关
  2. buffer_size = 20000
  3. # Dataset的数据批次大小,每批次多少个样本数
  4. batch_size = 8
  1. dataset_train = read_TFRecond_file(tfrecord_file) # 解码
  2. dataset_train = dataset_train.shuffle(buffer_size) # 打乱数据
  3. dataset_train = dataset_train.batch(batch_size) # 分批次进行读取

7、查看第一批元素

  1. import matplotlib.pyplot as plt
  2. sub_dataset = dataset_train.take(1) # 读取第一个批次
  3. for images,labels in sub_dataset:
  4. fig,axs = plt.subplots(1, batch_size)
  5. for i in range(batch_size):
  6. axs[i].set_title(labels.numpy()[i])
  7. axs[i].imshow(images.numpy()[i])
  8. axs[i].set_xticks([])
  9. axs[i].set_yticks([])
  10. plt.show()

案例实例地址:Tfrecord介绍以及实例· GitHub  

链接:猫狗大战数据集
提取码:kqgt

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/123821
推荐阅读
相关标签
  

闽ICP备14008679号