赞
踩
参考文章:
将数据导入TensorFlow
使用tensorflow训练自己的数据集(一)——制作数据集
用Tensorflow处理自己的数据:制作自己的TFRecords数据集
在用tensorflow来进行网络模型的训练时,我们总是需要先输入数据,这样才能对网络模型进行训练。这就涉及将数据导入TensorFlow的问题了
总共有四种方法将数据导入到TensorFlow中:
在这里我们使用第三种方法,将自己的数据集先写入TFRecord文件,然后从TFRecord文件将数据导入TensorFlow。
【1】首先编写一个小程序来获取自己的数据,将它放在一个示例协议缓冲区中,将缓冲区序列化为一个字符串,然后使用tf.python_io.TFRecordWriter将该字符串写入TFRecords文件。下面直接上代码,根据代码注释很容易理解:
def create_train_record(): """创建训练集tfrecord""" writer = tf.python_io.TFRecordWriter(train_record_path) # 创建一个writer for index, name in enumerate(classes): #生成由二元组构成的一个迭代对象,每个二元组由可迭代参数的索引号及其对应的元素组成 class_path = cwd + "\\" + name + "\\" l = int(len(os.listdir(class_path)) * 0.6) # 取前60%创建训练集 for img_name in os.listdir(class_path)[:l]: img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) # resize图片大小 img_raw = img.tobytes() # 将图片转化为原生bytes即二进制格式 example = tf.train.Example( # 封装到示例协议缓冲区Example中 features=tf.train.Features(feature={ "label":_int64list(index), # label必须为整数类型属性 'img_raw':_byteslist(img_raw) # 图片必须为二进制属性 })) writer.write(example.SerializeToString()) #序列化为字符串,将字符串写入TFRecords文件 writer.close() # 关闭writer def create_test_record(): """创建测试集tfrecord""" writer = tf.python_io.TFRecordWriter(test_record_path) for index, name in enumerate(classes): class_path = cwd + "\\" + name + "\\" l = int(len(os.listdir(class_path)) * 0.6) h = int(len(os.listdir(class_path)) * 0.9) for img_name in os.listdir(class_path)[l:h]: # 中间30%作为测试集 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) img_raw = img.tobytes() # 将图片转化为原生bytes # print(index,img_raw) example = tf.train.Example( features=tf.train.Features(feature={ "label":_int64list(index), 'img_raw':_byteslist(img_raw) })) writer.write(example.SerializeToString()) writer.close() def create_val_record(): """创建验证集tfrecord""" writer = tf.python_io.TFRecordWriter(val_record_path) for index, name in enumerate(classes): class_path = cwd + "\\" + name + "\\" h = int(len(os.listdir(class_path)) * 0.9) for img_name in os.listdir(class_path)[h:]: # 剩余10%作为验证集 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) img_raw = img.tobytes() # 将图片转化为原生bytes # print(index,img_raw) example = tf.train.Example( features=tf.train.Features(feature={ "label":_int64list(index), 'img_raw':_byteslist(img_raw) })) writer.write(example.SerializeToString()) writer.close()
我们将数据集分为训练集+测试集+验证集三部分,占总数据集的比例分别为:60%、30%、10%。根据自己的需要可以自己任意比例的分配自己的数据集,很多时候我们可能只需要将数据集分为训练集+测试集。
【2】再就是读取TFRecords文件,使用tf.TFRecordReader与tf.parse_single_example解码器。tf.parse_single_example操作将示例协议缓冲区解码为张量。
def read_record(filename): """读取tfrecord""" filename_queue = tf.train.string_input_producer([filename]) # 创建文件队列 reader = tf.TFRecordReader() # 创建reader _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) } ) label = features['label'] img = features['img_raw'] img = tf.decode_raw(img, tf.uint8) #tf.decode_raw函数的意思是将原来编码为字符串类型的变量重新变回来 img = tf.reshape(img, [128, 128, 3]) # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # 归一化 归一化之后图片会变成那种看不清的图片,未归一化的话则是完好的图片 label = tf.cast(label, tf.int32) #这个函数主要用于数据类型的转变,不会改变原始数据的值还有形状的, return img, label
【3】在管道的最后,使用一个队列来作为训练,评估或判断一起批处理示例。为此,在这里使用一个随机化的示例顺序的队列:tf.train.shuffle_batch。
def get_batch_record(filename,batch_size):
"""获取batch"""
image,label = read_record(filename)
image_batch,label_batch = tf.train.shuffle_batch([image,label], # 随机抽取batch size个image、label
batch_size=batch_size,
capacity=2000,
min_after_dequeue=1000)
return image_batch,label_batch#tf.reshape(label_batch,[batch_size])
【4】调用以上函数,制作数据集并用自己制作的数据集进行训练
create_train_record()
create_test_record()
create_val_record()
image_batch,label_batch = get_batch_record(train_record_path,32)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator() #1
threads = tf.train.start_queue_runners(sess=sess,coord=coord) #2
for i in range(1):
image,label = sess.run([image_batch,label_batch])
print(image.shape,1)
coord.request_stop() #3
coord.join(threads) #4
【5】整合以上的代码,写入dateset.py模块
#coding="utf-8" import os import tensorflow as tf from PIL import Image os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 源数据地址 cwd = os.getcwd() + "\\train" # 生成record路径及文件名 train_record_path = os.getcwd() + "\\train.tfrecords" test_record_path = os.getcwd() + "\\test.tfrecords" val_record_path = os.getcwd() + "\\val.tfrecords" # 分类 根据自己的需求进行分类 classes = {'0-gjfd','1-gjqzm','2-gjxbs','3-cmj','4-gjsjj','5-hlj','6-gjhyj','7-hdjqc','8-dlfj','9-hcfs','10-gjdsy','11-jxsy','12-mjdsy' ,'13-xwye','14-kyye','15-gjqye','16-hjyc','17-jecy','18-xxjy'} def _byteslist(value): """二进制属性""" return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) def _int64list(value): """整数属性""" return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) def create_train_record(): """创建训练集tfrecord""" writer = tf.python_io.TFRecordWriter(train_record_path) # 创建一个writer for index, name in enumerate(classes): #生成由二元组构成的一个迭代对象,每个二元组由可迭代参数的索引号及其对应的元素组成 class_path = cwd + "\\" + name + "\\" l = int(len(os.listdir(class_path)) * 0.6) # 取前60%创建训练集 for img_name in os.listdir(class_path)[:l]: img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) # resize图片大小 img_raw = img.tobytes() # 将图片转化为原生bytes即二进制格式 example = tf.train.Example( # 封装到示例协议缓冲区Example中 features=tf.train.Features(feature={ "label":_int64list(index), # label必须为整数类型属性 'img_raw':_byteslist(img_raw) # 图片必须为二进制属性 })) writer.write(example.SerializeToString()) #序列化为字符串,将字符串写入TFRecords文件 writer.close() # 关闭writer def create_test_record(): """创建测试tfrecord""" writer = tf.python_io.TFRecordWriter(test_record_path) for index, name in enumerate(classes): class_path = cwd + "\\" + name + "\\" l = int(len(os.listdir(class_path)) * 0.6) h = int(len(os.listdir(class_path)) * 0.9) for img_name in os.listdir(class_path)[l:h]: #取中间30%作为测试集 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) img_raw = img.tobytes() # 将图片转化为原生bytes # print(index,img_raw) example = tf.train.Example( features=tf.train.Features(feature={ "label":_int64list(index), 'img_raw':_byteslist(img_raw) })) writer.write(example.SerializeToString()) writer.close() def create_val_record(): """创建验证集tfrecord""" writer = tf.python_io.TFRecordWriter(val_record_path) for index, name in enumerate(classes): class_path = cwd + "\\" + name + "\\" h = int(len(os.listdir(class_path)) * 0.9) for img_name in os.listdir(class_path)[h:]: # 剩余10%作为验证集 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) img_raw = img.tobytes() # 将图片转化为原生bytes # print(index,img_raw) example = tf.train.Example( features=tf.train.Features(feature={ "label":_int64list(index), 'img_raw':_byteslist(img_raw) })) writer.write(example.SerializeToString()) writer.close() def read_record(filename): """读取tfrecord""" filename_queue = tf.train.string_input_producer([filename]) # 创建文件队列 reader = tf.TFRecordReader() # 创建reader _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) } ) label = features['label'] img = features['img_raw'] img = tf.decode_raw(img, tf.uint8) #tf.decode_raw函数的意思是将原来编码为字符串类型的变量重新变回来 img = tf.reshape(img, [128, 128, 3]) # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # 归一化 归一化之后图片会变成那种看不清的图片,未归一化的话则是完好的图片 label = tf.cast(label, tf.int32) #这个函数主要用于数据类型的转变,不会改变原始数据的值还有形状的, return img, label def get_batch_record(filename,batch_size): """获取batch""" image,label = read_record(filename) image_batch,label_batch = tf.train.shuffle_batch([image,label], # 随机抽取batch size个image、label batch_size=batch_size, capacity=2000, min_after_dequeue=1000) return image_batch,label_batch#tf.reshape(label_batch,[batch_size]) def main(): create_train_record() create_test_record() create_val_record() if __name__ == '__main__': main() #create_train_record() #create_test_record() #create_val_record() image_batch,label_batch = get_batch_record(train_record_path,32) #此处是调用的train.tfrcord文件,根据需要进行调用tfrecord文件 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() #1 threads = tf.train.start_queue_runners(sess=sess,coord=coord) #2 for i in range(1): image,label = sess.run([image_batch,label_batch]) print(image.shape,label.shape) coord.request_stop() #3 coord.join(threads) #4
在下篇文章中,我们将会把我们制作好的数据集用于神经网络的训练
附:
在运行代码时可能会出现如下错误:
【UnicodeEncodeError】: ‘utf-8’ codec can’t encode character ‘\udcd5’ in position 2189: surrogates not allowed
这个错误是因为编码方式不一致造成的,但是解决这个问题的办法,是检查路径,仔细检查一下文件路径,看看是否正确。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。