For large amounts of image data, TensorFlow provides a unified storage format: TFRecord. A TFRecord file stores data in binary form and is well suited to reading large batches of data serially. Although its internal format is complex, it makes good use of memory, is easy to copy and move, and fits the way the TensorFlow execution engine works.
The data in a TFRecord file is stored using the tf.train.Example Protocol Buffer format.
message Example {
    Features features = 1;
};

message Features {
    map<string, Feature> feature = 1;
};

message Feature {
    oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
tf.train.Example contains a dictionary mapping attribute names to values. Each attribute name is a string, and its value can be a byte string (BytesList), a list of real numbers (FloatList), or a list of integers (Int64List).
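As a minimal sketch of what this looks like from Python (the feature names 'image_raw', 'label', and 'scores' here are illustrative, chosen to mirror the fields used later in this article), an Example can be built and serialized like this:

import tensorflow as tf

# Build one tf.train.Example with the three supported feature types.
example = tf.train.Example(features=tf.train.Features(feature={
    'image_raw': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'\x00\x01\x02'])),  # byte string
    'label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[7])),                # integer list
    'scores': tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.1, 0.9]))          # float list
}))
# This serialized byte string is what gets written to a TFRecord file.
serialized = example.SerializeToString()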
The image dataset used in this article comes from the Stanford Cars dataset.
Put all of the dataset's images into the data folder, and place the label file (which I have renamed to label.txt) in the same directory that contains the data folder.
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import numpy as np
import tensorflow as tf
import time
from scipy.misc import imread, imresize
from os import walk
from os.path import join

# Directory containing the images
DATA_DIR = 'data/'

# Image information
IMG_HEIGHT = 227
IMG_WIDTH = 227
IMG_CHANNELS = 3
NUM_TRAIN = 7000
NUM_VALIDATION = 1144

# Read the images
def read_images(path):
    filenames = next(walk(path))[2]
    num_files = len(filenames)
    images = np.zeros((num_files, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
    labels = np.zeros((num_files, ), dtype=np.uint8)
    f = open('label.txt')
    lines = f.readlines()
    # Iterate over all images and labels, resizing each image to [227, 227, 3]
    for i, filename in enumerate(filenames):
        img = imread(join(path, filename))
        img = imresize(img, (IMG_HEIGHT, IMG_WIDTH))
        images[i] = img
        labels[i] = int(lines[i])
    f.close()
    return images, labels

# Generate an integer feature
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Generate a byte-string feature
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert(images, labels, name):
    # Number of images to convert into the TFRecord file
    num = images.shape[0]
    # Name of the output TFRecord file
    filename = name + '.tfrecords'
    print('Writing', filename)
    # Create a writer for the TFRecord file
    writer = tf.python_io.TFRecordWriter(filename)
    for i in range(num):
        # Convert the image matrix to a byte string
        img_raw = images[i].tostring()
        # Convert one sample into an Example Protocol Buffer and fill in all required fields
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(int(labels[i])),
            'image_raw': _bytes_feature(img_raw)}))
        # Write the example to the TFRecord file
        writer.write(example.SerializeToString())
    writer.close()
    print('Writing End')

def main(argv):
    print('reading images begin')
    start_time = time.time()
    train_images, train_labels = read_images(DATA_DIR)
    duration = time.time() - start_time
    print("reading images end , cost %d sec" % duration)
    # Split off the validation set
    validation_images = train_images[:NUM_VALIDATION, :, :, :]
    validation_labels = train_labels[:NUM_VALIDATION]
    train_images = train_images[NUM_VALIDATION:, :, :, :]
    train_labels = train_labels[NUM_VALIDATION:]
    # Convert to TFRecords
    print('convert to tfrecords begin')
    start_time = time.time()
    convert(train_images, train_labels, 'train')
    convert(validation_images, validation_labels, 'validation')
    duration = time.time() - start_time
    print('convert to tfrecords end , cost %d sec' % duration)

if __name__ == '__main__':
    tf.app.run()
This article uses 7,000 of the dataset's images for training and 1,144 for validation.
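If you want to sanity-check the generated files, the following minimal sketch (using TensorFlow 1.x's tf.python_io.tf_record_iterator; it assumes train.tfrecords was produced by the script above) counts the records and parses the first one:

import tensorflow as tf

count = 0
for record in tf.python_io.tf_record_iterator('train.tfrecords'):
    if count == 0:
        example = tf.train.Example()
        example.ParseFromString(record)
        # The parsed Example exposes the 'label' and 'image_raw' features written earlier.
        print('first label:', example.features.feature['label'].int64_list.value[0])
        print('image bytes:', len(example.features.feature['image_raw'].bytes_list.value[0]))
    count += 1
print('total records:', count)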
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import numpy as np
from os.path import join
import tensorflow as tf
import convert_to_tfrecords

# TFRecord files
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'

# Image information
NUM_CLASSES = 196
IMG_HEIGHT = convert_to_tfrecords.IMG_HEIGHT
IMG_WIDTH = convert_to_tfrecords.IMG_WIDTH
IMG_CHANNELS = convert_to_tfrecords.IMG_CHANNELS
IMG_PIXELS = IMG_HEIGHT * IMG_WIDTH * IMG_CHANNELS
NUM_TRAIN = convert_to_tfrecords.NUM_TRAIN
NUM_VALIDATION = convert_to_tfrecords.NUM_VALIDATION

def read_and_decode(filename_queue):
    # Create a reader to read examples from the TFRecord file
    reader = tf.TFRecordReader()
    # Read one example from the file
    _, serialized_example = reader.read(filename_queue)
    # Parse the example that was read in
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string)
    })
    # Decode the byte string into the image's pixel array
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    image.set_shape([IMG_PIXELS])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label

# Fetch a batch_size worth of images and labels
def inputs(data_set, batch_size, num_epochs):
    if not num_epochs:
        num_epochs = None
    if data_set == 'train':
        file = TRAIN_FILE
    else:
        file = VALIDATION_FILE
    with tf.name_scope('input') as scope:
        filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue)
        # Randomly shuffle and batch batch_size images and labels
        images, labels = tf.train.shuffle_batch([image, label],
                                                batch_size=batch_size,
                                                num_threads=64,
                                                capacity=1000 + 3 * batch_size,
                                                min_after_dequeue=1000)
    return images, labels
Reading a batch of images and labels is then just a matter of calling the inputs() function.
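For completeness, here is a minimal usage sketch. The module name car_input is an assumption for wherever the reader code above is saved; note that because string_input_producer stores num_epochs in a local variable, the local variable initializer must be run along with the global one before starting the queue runners:

import tensorflow as tf
import car_input  # assumed module name for the reader code above

images, labels = car_input.inputs('train', batch_size=32, num_epochs=1)

with tf.Session() as sess:
    # num_epochs creates a local variable, so local variables must be initialized too.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        img_batch, label_batch = sess.run([images, labels])
        print(img_batch.shape, label_batch.shape)  # (32, 227, 227, 3) (32,)
    finally:
        coord.request_stop()
        coord.join(threads)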
The result is a 1 GB train.tfrecords file and a 168 MB validation.tfrecords file.
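These sizes are roughly what you would expect for raw uint8 pixel data: each image is stored as 227 × 227 × 3 = 154,587 bytes, so 7,000 training images come to about 1.08 GB (≈ 1.0 GiB) and 1,144 validation images come to about 177 MB (≈ 168 MiB), plus a small per-record overhead.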