赞
踩
三种获取数据到TensorFlow程序的方法
多线程+队列的方式
第一阶段 构造文件名队列
第二阶段 读取与解码
第三阶段 批处理
注:这些操作需要启动运行这些队列操作的线程,以便我们在进行文件读取的过程中能够顺利进行入队出队操作
将需要读取的文件的文件名放入文件名队列
从队列当中读取文件内容,并进行解码操作
他们有共同的读取方法,read(file_queue),并且都会返回一个Tensor元祖(key文件名字,value默认的内容(一个样本))
由于默认只会读取一个样本,所以如果想要进行批处理,需要使用tf.train.batch或tf.train,shuffle_batch进行批处理操作,便于之后指定每批次多个样本的训练
解码阶段。默认所有的内容多解码成tf.unit8类型,如果之后需要转换成指定类型则可使用tf.cast()进行相应转换
解码之后,可以直接获取默认的一个样本内容了,但如果想要获取多个样本,需要加入到新的队列进行批处理
以上用到的队列都是tf.train.QueueRunner对象
每个QueueRunner都负责一个阶段,tf.train.start_queue_runners函数会要求图中的每个QueueRunner启动它的运行队列操作的线程。(这些操作需要在会话中开启)
特征抽取:
文本—数值(二维数组shape(n_samples,m_features))
字典—数值(二维数组shape(n_samples,m_features))
两种图片:黑白图片,彩色图片
组成图片的最基本单位是像素
图片长度,图片宽度,图片通道数
灰度图:每一个像素点[0,255]的数,灰度图[长,宽,1]
彩色图:每一个像素点用3个[0,255]的数表示,彩色图[长,宽,3]
一张图片可以被表示成一个3D张量,即其形状为[height,width,channel],height就表示高,width表示宽,channel表示通道数
准备100张狗图片
import os import tensorflow.compat.v1 as tf tf.disable_v2_behavior() def picture_read(file_list): """ 狗图片读取案例 :return: """ # 1.构造文件名队列 file_queue=tf.train.string_input_producer(file_list) #2.读取与解码 reader=tf.WholeFileReader() #key文件名 value一张图片原始编码形式 key,value=reader.read(file_queue) print("key:\n",key) print("value:\n",value) #解码阶段 image=tf.image.decode_jpeg(value) print("image:\n",image) #图像的形状、类型修改 image_resize=tf.image.resize_images(image,[200,200]) print("image_resize:\n",image_resize) #静态形状修改 image_resize.set_shape(shape=[200,200,3]) print("image_resized:\n",image_resize) #3.批处理 image_batch=tf.train.batch([image_resize],batch_size=100,num_threads=1,capacity=100) print("image_batch:\n",image_batch) #开启会话 with tf.Session() as sess: #开启线程 #线程协调员 coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(sess=sess,coord=coord) key_new,value_new,image_new,image_resize_new,image_batch_new=sess.run([key,value,image,image_resize,image_batch]) print("key_new:\n",key_new) print("value_new:\n",value_new) print("image_new\n",image_new) print("image_resize_new:\n",image_resize_new) print("image_batch_new:\n",image_batch_new) #回收线程 coord.request_stop() coord.join(threads) if __name__ == '__main__': #构造路径+文件名列表 filename=os.listdir("./dog") # print(filename) #拼接文件+路径名 file_list=[os.path.join("./dog/",file)for file in filename] print(file_list) picture_read(file_list)
<1×标签><3072×像素>
…
<1×标签><3072×像素>
每3073个字节是一个样本,1个目标值+3072个像素,第一个字节是第一个图像的标签,它是一个0-9范围的数字,接下来的3072个字节是图像像素的值。前1024个字节是红色通道值,下1024个绿色,最后1024个蓝色
这里的图片形状设置从1维的排列到3维数据的时候,涉及到NHWC与NCHW的概念
在读取设置图片形状的时候有两种格式:
设置为“NHWC”时,排列顺序为[batch,height,width,channels]
设置为“NCHW”时,排列顺序为[batch,channels,height,width]
N表示这批图像有几张,H表示图像在竖直方向有多少像素,W表示水平方向像素数,C表示通道数
TensorFlow默认的[height,width,channel]
假设RGB三通道两种格式的区别如下图所示:
二进制数据文件:
import tensorflow.compat.v1 as tf tf.disable_v2_behavior() import os class Cifar(object): def __init__(self): #初始化操作 self.height=32 self.width=32 self.channels=3 #字节数 self.image_bytes=self.height*self.width*self.channels self.label_bytes=1 self.all_bytes=self.label_bytes+self.image_bytes def read_and_decode(self,file_list): #1.构造文件名队列 file_queue=tf.train.string_input_producer(file_list) #2.读取与解码 reader=tf.FixedLengthRecordReader(self.all_bytes) #key文件名 value一个样本 key,value=reader.read(file_queue) print("key:\n",key) print("value:\n",value) #解码阶段 decoded=tf.decode_raw(value,tf.uint8) print("decoded:\n",decoded) #将目标值和特征值切片切开 label=tf.slice(decoded,[0],[self.label_bytes]) image=tf.slice(decoded,[self.label_bytes],[self.image_bytes]) print("label:\n",label) print("image\n",image) #调整图片形状 image_reshaped=tf.reshape(image,shape=[self.channels,self.height,self.width]) print("image_reshaped:\n",image_reshaped) #转置,将图片的顺序转为height,width,channels image_transposed=tf.transpose(image_reshaped,[1,2,0]) print("image_transposed:\n",image_transposed) #调整图像类型 image_cast=tf.cast(image_transposed,tf.float32) #3.批处理 label_batch,image_batch=tf.train.batch([label,image_cast],batch_size=100,num_threads=1,capacity=100) print("label_batch:\n",label_batch) print("image_batch:\n",image_batch) #开启会话 with tf.Session() as sess: #开启线程 coord=tf.train.Coordinator() threads=tf.train.start_queue_runners(sess=sess,coord=coord) key_new,value_new,decoded_new,label_new,image_new,image_reshaped_new,image_transposed_new=sess.run([key,value,decoded,label,image,image_reshaped,image_transposed]) label_value,image_value=sess.run([label_batch,image_batch]) print("key_new:\n",key_new) print("value_new:\n",value_new) print("decoded_new:\n",decoded_new) print("label_new:\n",label_new) print("image_new:\n",image_new) print("image_reshaped_new:\n",image_reshaped_new) print("image_transposed_new:\n",image_transposed_new) print("label_value:\n",label_value) print("image_value:\n",image_value) #回收线程 coord.request_stop() coord.join(threads) return None; if __name__ == "__main__": file_name=os.listdir("./cifar-10-batches-bin") print("file_name:\n",file_name) #构造文件名路径列表 file_list=[os.path.join("./cifar-10-batches-bin/",file) for file in file_name if file[-3:]=="bin"] print("file_list:\n",file_list) #实例化Cifar cifar=Cifar() cifar.read_and_decode(file_list)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。