赞
踩
TensorFlow现在有三种方法读取数据:
现在我们来讲述最后一种方法:TFRecord
1.什么是TFRecord?
tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord文件。我们只关心如何使用Tensorflow生成TFRecord,并且读取它。
2.如何使用TFRecord?
因为深度学习很多都是与图片集打交道,那么,我们可以尝试下把一张张的图片转换成 TFRecord 文件。
不说很多原理,直接看代码,代码全部亲测可用。TensorFlow小白,有任何问题请及时指出。
本数据集采用kaggle的猫狗大战数据集中的的训练集(即train)。
数据集名称 | 说明 |
---|---|
train | 训练集 |
test | 测试集 |
我们将一个文件下的所有猫狗图片的位置和对应的标签分别存放到两个list中。
def get_files(file_dir,is_random=True): image_list=[] label_list=[] dog_count=0 cat_count=0 for file in os.listdir(file_dir): name=file.split(sep='.') if(name[0]=='cat'): image_list.append(file_dir+file) label_list.append(0) cat_count+=1 else: image_list.append(file_dir+file) label_list.append(1) dog_count+=1 print('%d cats and %d dogs'%(cat_count,dog_count)) image_list=np.asarray(image_list) label_list=np.asarray(label_list) if is_random: rnd_index=np.arange(len(image_list)) np.random.shuffle(rnd_index) image_list=image_list[rnd_index] label_list=label_list[rnd_index] return image_list,label_list
How to use?
get_files(file_dir,is_random=True)
<!--file_dir:图片文件中的所在位置-->
在保存图片信息的时候,需要先将这些图片的信息转换为byte数据才能写入到tfrecord文件中。属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)可以看见TFRecord是以字典的形式存储的,这里我们存储了image、label、width、height的信息。
def int64_feature(values): if not isinstance(values,(tuple,list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def float_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) def image_to_tfexample(image_data, label,size): return tf.train.Example(features=tf.train.Features(feature={ 'image': bytes_feature(image_data), 'label': int64_feature(label), 'image_width':int64_feature(size[0]), 'image_height':int64_feature(size[1]) }))
将之前的两个list中的信息转化为我们需要的TFRecord数据类型文件
def _convert_dataset(image_list, label_list, tfrecord_dir): """ Convert data to TFRecord format. """ with tf.Graph().as_default(): with tf.Session() as sess: if not os.path.exists(tfrecord_dir): os.makedirs(tfrecord_dir) output_filename = os.path.join(tfrecord_dir, "train.tfrecord") tfrecord_writer = tf.python_io.TFRecordWriter(output_filename) length = len(image_list) for i in range(length): # 图像数据 image_data = Image.open(image_list[i],'r') size = image_data.size image_data = image_data.tobytes() label = label_list[i] example = image_to_tfexample(image_data, label,size) tfrecord_writer.write(example.SerializeToString()) sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, length)) sys.stdout.flush() sys.stdout.write('\n') sys.stdout.flush()
How to use?
_convert_dataset(image_list, label_list, tfrecord_dir)
<!--image_list,label_list:上述产生的两个list-->
<!--tfrecord_dir:你要保存TFRecord文件的位置-->
你们不禁要问了:怎么解析这么复杂的数据呢?我们使用tf.parse_single_example() 将存储为字典形式的TFRecord数据解析出来。这样我们就将image、label、width、height的信息就原封不动“拿”出来了。
def read_and_decode(tfrecord_path): data_files = tf.gfile.Glob(tfrecord_path) #data_path为TFRecord格式数据的路径 filename_queue = tf.train.string_input_producer(data_files,shuffle=True) reader = tf.TFRecordReader() _,serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label':tf.FixedLenFeature([],tf.int64), 'image':tf.FixedLenFeature([],tf.string), 'image_width': tf.FixedLenFeature([],tf.int64), 'image_height': tf.FixedLenFeature([],tf.int64), }) image = tf.decode_raw(features['image'],tf.uint8) image_width = tf.cast(features['image_width'],tf.int32) image_height = tf.cast(features['image_height'],tf.int32) image = tf.reshape(image,[image_height,image_width,3]) label = tf.cast(features['label'], tf.int32) return image,label
How to use?
read_and_decode(tfrecord_path)
<!--tfrecord_path:就是你刚刚存放TFRecord的文件位置,我们将它取出来就好了。-->
数据拿出来了,那我们就要用它来组成一个个batch方便我们训练模型。
def batch(image,label):
# Load training set.
#一定要reshape一下image,不然会报错。
image = tf.image.resize_images(image, [128, 128])
with tf.name_scope('input_train'):
image_batch, label_batch = tf.train.shuffle_batch(
[image, label],
batch_size=30,
capacity=2000,
min_after_dequeue=1500)
return image_batch, label_batch
How to use?
batch(image,label)
<!--image,label:我们刚刚解析出来的图片和标签-->
总结
代码链接(https://github.com/MagaretJi/TFRecord)
链接: https://pan.baidu.com/s/1AgHPMMkLZzR4HrEdWfuNhw 密码: vtg6
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。