当前位置:   article > 正文

将图像转为TFRecord文件并读取TFRecord文件_将文件存入record.txt

将文件存入record.txt

1 TFRecord格式介绍

 对于大量的图像数据,TensorFlow提供了一种统一的格式来存储数据——TFRecord。TFRecord文件是以二进制进行存储数据的,适合以串行的方式读取大批量数据,虽然它的内部格式复杂,但是它可以很好地利用内存,方便地复制和移动,更符合TensorFlow执行引擎的方式。
 TFReocrd文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。

message Example{
    Features features = 1;
};
message Features{
    map<string,Feature> feature = 1;
};
message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

 tf.train.Example中包含了一个从属性名称到取值的字典。其中属性名称为一个字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。

2 图像数据集

本文采用的图像数据集来自stanford car dataset
将数据集的图片全部放入data文件夹下,label文件(我已改名为label.txt)放在与data文件夹同根目录下。

3 将图像转为TFRecord

# -*- coding = utf-8 -*-

from __future__ import absolute_import,division,print_function

import numpy as np
import tensorflow as tf
import time
from scipy.misc import imread,imresize
from os import  walk
from os.path import join

#图片存放位置
DATA_DIR = 'data/'

#图片信息
IMG_HEIGHT = 227
IMG_WIDTH = 227
IMG_CHANNELS = 3
NUM_TRAIN = 7000
NUM_VALIDARION = 1144

#读取图片
def read_images(path):
    filenames = next(walk(path))[2]
    num_files = len(filenames)
    images = np.zeros((num_files,IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS),dtype=np.uint8)
    labels = np.zeros((num_files, ), dtype=np.uint8)
    f = open('label.txt')
    lines = f.readlines()
    #遍历所有的图片和label,将图片resize到[227,227,3]
    for i,filename in enumerate(filenames):
        img = imread(join(path,filename))
        img = imresize(img,(IMG_HEIGHT,IMG_WIDTH))
        images[i] = img
        labels[i] = int(lines[i])
    f.close()
    return images,labels

#生成整数型的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert(images,labels,name):
    #获取要转换为TFRecord文件的图片数目
    num = images.shape[0]
    #输出TFRecord文件的文件名
    filename = name+'.tfrecords'
    print('Writting',filename)
    #创建一个writer来写TFRecord文件
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(num):
        #将图像矩阵转化为一个字符串
        img_raw = images[i].tostring()
        #将一个样例转化为Example Protocol Buffer,并将所有需要的信息写入数据结构
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(int(labels[i])),
            'image_raw': _bytes_feature(img_raw)}))
        #将example写入TFRecord文件
        writer.write(example.SerializeToString())
    writer.close()
    print('Writting End')

def main(argv):
    print('reading images begin')
    start_time = time.time()
    train_images,train_labels = read_images(DATA_DIR)
    duration = time.time() - start_time
    print("reading images end , cost %d sec" %duration)

    #get validation
    validation_images = train_images[:NUM_VALIDARION,:,:,:]
    validation_labels = train_labels[:NUM_VALIDARION]
    train_images = train_images[NUM_VALIDARION:,:,:,:]
    train_labels = train_labels[NUM_VALIDARION:]

    #convert to tfrecords
    print('convert to tfrecords begin')
    start_time = time.time()
    convert(train_images,train_labels,'train')
    convert(validation_images,validation_labels,'validation')
    duration = time.time() - start_time
    print('convert to tfrecords end , cost %d sec' %duration)

if __name__ == '__main__':
    tf.app.run()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89

 本文将数据集中的7000张用于训练,1144张用于验证。

4 读取TFRecord文件

# -*- coding = utf-8 -*-

from __future__ import absolute_import,division,print_function

import numpy as np
from os.path import join
import tensorflow as tf
import convert_to_tfrecords

#TFRcord文件
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'

#图片信息
NUM_CLASSES = 196
IMG_HEIGHT = convert_to_tfrecords.IMG_HEIGHT
IMG_WIDTH = convert_to_tfrecords.IMG_WIDTH
IMG_CHANNELS = convert_to_tfrecords.IMG_CHANNELS
IMG_PIXELS = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELS

NUM_TRAIN = convert_to_tfrecords.NUM_TRAIN
NUM_VALIDARION = convert_to_tfrecords.NUM_VALIDARION

def read_and_decode(filename_queue):
    #创建一个reader来读取TFRecord文件中的样例
    reader = tf.TFRecordReader()
    #从文件中读出一个样例
    _,serialized_example = reader.read(filename_queue)
    #解析读入的一个样例
    features = tf.parse_single_example(serialized_example,features={
        'label':tf.FixedLenFeature([],tf.int64),
        'image_raw':tf.FixedLenFeature([],tf.string)
        })
    #将字符串解析成图像对应的像素数组
    image = tf.decode_raw(features['image_raw'],tf.uint8)
    label = tf.cast(features['label'],tf.int32)

    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image,[IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS])
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    return image,label

#用于获取一个batch_size的图像和label
def inputs(data_set,batch_size,num_epochs):
    if not num_epochs:
        num_epochs = None
    if data_set == 'train':
        file = TRAIN_FILE
    else:
        file = VALIDATION_FILE

    with tf.name_scope('input') as scope:
        filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs)
    image,label = read_and_decode(filename_queue)
    #随机获得batch_size大小的图像和label
    images,labels = tf.train.shuffle_batch([image, label], 
        batch_size=batch_size,
        num_threads=64,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000
    )

    return images,labels
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

读取一个batch的图像和label只需要调用inputs()函数就行了。

5 结果

结果生成了一个1GB的train.tfrecords和168MB的validation.tfrecords

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/123766
推荐阅读
相关标签
  

闽ICP备14008679号