赞
踩
def read_and_decode(filename, batch_size=10, min_after_dequeue=3, capacity=None):
    """Build a queue-based pipeline that yields shuffled batches from a TFRecord file.

    Each record must contain an ``'image'`` bytes feature (raw uint8 pixels,
    reshaped here to 144x144x3) and an int64 ``'label'`` feature. No
    ``num_epochs`` is passed to ``string_input_producer``, so the file is
    cycled indefinitely.

    Args:
        filename: Path to the TFRecord file.
        batch_size: Number of examples per batch.
        min_after_dequeue: Minimum number of examples kept in the shuffle
            queue after each dequeue (more = better mixing, more memory).
        capacity: Maximum number of examples in the shuffle queue. Defaults
            to ``min_after_dequeue + 3 * batch_size``, the sizing rule given
            in the ``tf.train.shuffle_batch`` documentation.

    Returns:
        ``(images, labels)``: a float32 tensor of shape
        ``[batch_size, 144, 144, 3]`` and an int32 tensor of shape
        ``[batch_size]``. Queue runners must be started before evaluating them.
    """
    if capacity is None:
        # The original hard-coded capacity=4 with batch_size=10. The queue could
        # never hold a full batch, so each batch needed ~10 refill cycles and
        # every element was chosen among at most 4 candidates.
        capacity = min_after_dequeue + 3 * batch_size

    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(image, [144, 144, 3])
    image = tf.cast(image, tf.float32)
    label = tf.cast(features['label'], tf.int32)
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return images, labels


if __name__ == '__main__':
    filename = 'D:/项目实现/compare_insightface/data/lfw_one.tfrecord'
    images, labels = read_and_decode(filename)
    with tf.Session() as sess:
        start_time = time.time()
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(300):
                image_raw, labels_raw = sess.run([images, labels])
                print(image_raw.shape, labels_raw.shape)
            end_time = time.time()
            print('the time is', end_time - start_time)
        finally:
            # Always stop and join the reader threads, even if sess.run raised,
            # so they are not left running against a closing session.
            coord.request_stop()
            coord.join(threads)
2. 使用 tf.data.Dataset
def parse_function(example_proto):
    """Parse one serialized example into an ``(img, label)`` pair.

    Args:
        example_proto: Scalar string tensor holding one serialized record with
            an ``'image'`` bytes feature and an int64 ``'label'`` feature.

    Returns:
        ``(img, label)``: a uint8 tensor of shape ``(144, 144, 3)`` whose
        channel order has been reversed, and a scalar int64 label tensor.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(example_proto, feature_spec)
    img = tf.decode_raw(parsed['image'], tf.uint8)
    img = tf.reshape(img, shape=(144, 144, 3))
    # Reverse the channel axis: RGB -> BGR, assuming the stored bytes are RGB
    # as the original r/g/b naming implies (TODO confirm against the writer).
    img = tf.reverse(img, axis=[-1])
    # The feature is already int64, so the original int64 cast was a no-op.
    label = parsed['label']
    return img, label


filename = 'D:/项目实现/compare_insightface/data/lfw_one.tfrecord'
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse_function)
dataset = dataset.shuffle(buffer_size=10)
dataset = dataset.batch(10).repeat(2)  # iterate over the dataset twice (2 epochs)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    start_time = time.time()
    for _ in range(300):
        try:
            images_batch, labels_batch = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            # repeat(2) makes the dataset finite: stop cleanly when it runs
            # out instead of crashing if it yields fewer than 300 batches.
            break
        print(images_batch[0, 0, 0, 0], labels_batch[0])
    end_time = time.time()
    print('the time is ', end_time - start_time)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。