当前位置:   article > 正文

一篇搞定TFRecord(内附代码+数据集)_tfrecord数据集

tfrecord数据集

TensorFlow现在有三种方法读取数据:

  • 直接加载数据 适合数据量小的
  • 文件读取数据 从文件读取数据,如CSV文件格式
  • TFRecord

现在我们来讲述最后一种方法:TFRecord

1.什么是TFRecord?
tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord文件。我们只关心如何使用Tensorflow生成TFRecord,并且读取它。

2.如何使用TFRecord?
因为深度学习很多都是与图片集打交道,那么,我们可以尝试下把一张张的图片转换成 TFRecord 文件。

不说很多原理,直接看代码,代码全部亲测可用。TensorFlow小白,有任何问题请及时指出。

数据集

本数据集采用kaggle的猫狗大战数据集中的的训练集(即train)。

数据集名称说明
train训练集
test测试集

生成TFRecord文件

step1 数据集准备工作

我们将一个文件下的所有猫狗图片的位置和对应的标签分别存放到两个list中。

def get_files(file_dir,is_random=True):
    image_list=[]
    label_list=[]
    dog_count=0
    cat_count=0
    for file in os.listdir(file_dir):
        name=file.split(sep='.')
        if(name[0]=='cat'):
            image_list.append(file_dir+file)
            label_list.append(0)
            cat_count+=1
        else:
            image_list.append(file_dir+file)
            label_list.append(1)
            dog_count+=1
    print('%d cats and %d dogs'%(cat_count,dog_count))

    image_list=np.asarray(image_list)
    label_list=np.asarray(label_list)

    if is_random:
        rnd_index=np.arange(len(image_list))
        np.random.shuffle(rnd_index)
        image_list=image_list[rnd_index]
        label_list=label_list[rnd_index]

    return image_list,label_list
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

How to use?

get_files(file_dir,is_random=True)
<!--file_dir:图片文件中的所在位置-->
  • 1
  • 2
step2 TFRecord数据类型转换

在保存图片信息的时候,需要先将这些图片的信息转换为byte数据才能写入到tfrecord文件中。属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)可以看见TFRecord是以字典的形式存储的,这里我们存储了image、label、width、height的信息。

 def int64_feature(values):
    if not isinstance(values,(tuple,list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def image_to_tfexample(image_data, label,size):
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_data),
        'label': int64_feature(label),
        'image_width':int64_feature(size[0]),
        'image_height':int64_feature(size[1])
    }))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
step3 数据存储

将之前的两个list中的信息转化为我们需要的TFRecord数据类型文件

def _convert_dataset(image_list, label_list, tfrecord_dir):
    """ Convert data to TFRecord format. """
    with tf.Graph().as_default():
        with tf.Session() as sess:
            if not os.path.exists(tfrecord_dir):
                os.makedirs(tfrecord_dir)
            output_filename = os.path.join(tfrecord_dir, "train.tfrecord")
            tfrecord_writer = tf.python_io.TFRecordWriter(output_filename)
            length = len(image_list)
            for i in range(length):
                # 图像数据
                image_data = Image.open(image_list[i],'r')
                size = image_data.size
                image_data = image_data.tobytes()
                label = label_list[i]
                example = image_to_tfexample(image_data, label,size)
                tfrecord_writer.write(example.SerializeToString())
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, length))
                sys.stdout.flush()

    sys.stdout.write('\n')
    sys.stdout.flush()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

How to use?

_convert_dataset(image_list, label_list, tfrecord_dir)
<!--image_list,label_list:上述产生的两个list-->
<!--tfrecord_dir:你要保存TFRecord文件的位置-->
  • 1
  • 2
  • 3
解析TFRecord数据

你们不禁要问了:怎么解析这么复杂的数据呢?我们使用tf.parse_single_example() 将存储为字典形式的TFRecord数据解析出来。这样我们就将image、label、width、height的信息就原封不动“拿”出来了。

def read_and_decode(tfrecord_path):
    data_files = tf.gfile.Glob(tfrecord_path)  #data_path为TFRecord格式数据的路径
    filename_queue = tf.train.string_input_producer(data_files,shuffle=True)
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label':tf.FixedLenFeature([],tf.int64),
                                           'image':tf.FixedLenFeature([],tf.string),
                                           'image_width': tf.FixedLenFeature([],tf.int64),
                                           'image_height': tf.FixedLenFeature([],tf.int64),
                                       })

    image = tf.decode_raw(features['image'],tf.uint8)
    image_width = tf.cast(features['image_width'],tf.int32)
    image_height = tf.cast(features['image_height'],tf.int32)
    image = tf.reshape(image,[image_height,image_width,3])
    label = tf.cast(features['label'], tf.int32)
    return image,label
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

How to use?

read_and_decode(tfrecord_path)
<!--tfrecord_path:就是你刚刚存放TFRecord的文件位置,我们将它取出来就好了。-->
  • 1
  • 2
加载数据集

数据拿出来了,那我们就要用它来组成一个个batch方便我们训练模型。

def batch(image,label):
    # Load training set.
    #一定要reshape一下image,不然会报错。
    image = tf.image.resize_images(image, [128, 128])
    with tf.name_scope('input_train'):
        image_batch, label_batch = tf.train.shuffle_batch(
               [image, label],
               batch_size=30, 
               capacity=2000,                                                        
               min_after_dequeue=1500)
    return image_batch, label_batch
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

How to use?

batch(image,label)
<!--image,label:我们刚刚解析出来的图片和标签-->
  • 1
  • 2

总结
在这里插入图片描述
代码链接(https://github.com/MagaretJi/TFRecord)
链接: https://pan.baidu.com/s/1AgHPMMkLZzR4HrEdWfuNhw 密码: vtg6

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/123706?site
推荐阅读
相关标签
  

闽ICP备14008679号