赞
踩
想记录一下自己制作训练集并训练的过程,希望踩过的坑能帮助后面入坑的人。
本次使用的训练集是 Kaggle 中经典的猫狗大战数据集(提取码:ufz5)。因为本人笔记本配置很差,而且还不是 N 卡,所以把 train 的数据分成了训练集和测试集,并没有使用原数据集中的 test。在 TensorFlow 中使用 TFRecord 格式把数据喂给神经网络;虽然现在官方推荐使用 tf.data,
但这个 API 还没看,所以还是使用了 TFRecord。
代码注释还挺清楚,就直接上代码了。
import os

import tensorflow as tf
from PIL import Image

# Directory holding the source images, one sub-directory per class name.
cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train'

# Output paths of the generated TFRecord files.
train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords"
test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords"

# Class names; the position in this tuple is the integer label written to
# the records.  This is an ordered tuple rather than a set because set
# iteration order for strings can change between interpreter runs (hash
# randomisation), which would silently swap the cat/dog labels.
classes = ('cat', 'dog')

# Images are resized to IMAGE_SIZE x IMAGE_SIZE and stored as 3-channel RGB.
IMAGE_SIZE = 128
IMAGE_CHANNELS = 3

# Fraction of each class (taken from the front of the sorted file list)
# that goes into the training record; the rest goes into the test record.
TRAIN_FRACTION = 0.7


def _byteslist(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64list(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _split_image_names(class_path):
    """Return ``(train_names, test_names)`` for one class directory.

    The listing is read once and sorted so that the train and test splits
    are disjoint and reproducible regardless of the order in which the
    file system returns directory entries.
    """
    names = sorted(os.listdir(class_path))
    split_at = int(len(names) * TRAIN_FRACTION)
    return names[:split_at], names[split_at:]


def _make_example(img_path, label):
    """Build a ``tf.train.Example`` from one image file and its label."""
    with Image.open(img_path) as img:
        # Force 3-channel RGB: grayscale / RGBA / CMYK files would otherwise
        # produce a byte string whose length does not match the
        # [IMAGE_SIZE, IMAGE_SIZE, 3] reshape done in read_record().
        img = img.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
        img_raw = img.tobytes()
    return tf.train.Example(
        features=tf.train.Features(feature={
            "label": _int64list(label),    # label must be an int64 feature
            'img_raw': _byteslist(img_raw),  # pixels must be a bytes feature
        }))


def _write_record(record_path, use_train_split, progress_message):
    """Write one split (train or test) of every class to ``record_path``."""
    writer = tf.python_io.TFRecordWriter(record_path)
    count = 1  # running counter, only used for progress output
    try:
        for index, name in enumerate(classes):
            class_path = cwd + "/" + name + '/'
            train_names, test_names = _split_image_names(class_path)
            selected = train_names if use_train_split else test_names
            for img_name in selected:
                example = _make_example(class_path + img_name, index)
                writer.write(example.SerializeToString())
                print(progress_message, count)
                count += 1
    finally:
        # Always close the writer so the file is flushed even if an image
        # fails to load half-way through.
        writer.close()


def create_train_record():
    """Create the training TFRecord from the first 70% of each class."""
    _write_record(train_record_path, True, 'Creating train record in ')
    print("Create train_record successful!")


def create_test_record():
    """Create the test TFRecord from the remaining 30% of each class."""
    _write_record(test_record_path, False, 'Creating test record in ')
    print("Create test_record successful!")


def read_record(filename):
    """Read and decode a single example from a TFRecord file.

    Args:
        filename: Path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)`` of tensors: ``img`` is float32 with shape
        ``[128, 128, 3]`` scaled to the range [-0.5, 0.5]; ``label`` is an
        int32 scalar.
    """
    filename_queue = tf.train.string_input_producer([filename])  # file queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        }
    )
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # normalise
    label = tf.cast(label, tf.int32)
    return img, label


def get_batch_record(filename, batch_size):
    """Return shuffled ``(image_batch, label_batch)`` tensors.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.
    """
    image, label = read_record(filename)
    # Randomly draw batch_size (image, label) pairs from a shuffling queue.
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        capacity=2000,
        min_after_dequeue=1000)
    return image_batch, label_batch


def main():
    """Generate both the training and the test TFRecord files."""
    create_train_record()
    create_test_record()


if __name__ == '__main__':
    main()

### Usage example ###
# create_train_record()
# create_test_record()
# image_batch, label_batch = get_batch_record(train_record_path, 32)
# init = tf.global_variables_initializer()
#
# with tf.Session() as sess:
#     sess.run(init)
#
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#
#     for i in range(1):
#         image, label = sess.run([image_batch, label_batch])
#         print(image.shape, label.shape)
#
#     coord.request_stop()
#     coord.join(threads)
下一篇将介绍如何定义神经网络。
如有错误望多多指教~~
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。