赞
踩
1、文件读取流程
2、文件读取API
3、文件读取案例
步骤1:构造一个文件名队列(路径+文件名)
步骤2:读取文件名队列
步骤3:对读取的文件进行解码
步骤4:放入到样本队列中,进行批处理
注:
tensorflow默认只读取一个样本,根据样本格式不同,情况不同,如下:
1、csv文件 读取一行数据
2、二进制文件 指定一个样本的bytes读取
3、图片文件 一个文件一个文件的读取
批处理也是构造一个队列
主线程要做的事就是取样本数据训练
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True)
将输出字符串(例如文件名)输入到管道队列
shuffleh指定是否顺序和乱序
TensorFlow默认每次只读取一个样本,具体到文本文件读取一行、二进制文件读取指定字节数(最好一个样本)、图片文件默认读取一张图片、TFRecords默认读取一个example
1、他们有共同的读取方法:read(file_queue):从队列中指定数量内容返回一个Tensors元组(key文件名字,value默认的内容(一个样本))
2、由于默认只会读取一个样本,所以通常想要进行批处理。使用tf.train.batch或tf.train.shuffle_batch进行多样本获取,便于训练时候指定每批次多个样本的训练
对于读取不同的文件类型,内容需要解码操作,解码成统一的Tensor格式
解码阶段,默认所有的内容都解码成tf.uint8格式,如果需要后续的类型处理继续处理
在解码之后,我们可以直接获取默认的一个样本内容了,但是如果想要获取多个样本,这个时候需要结合管道的末尾进行批处理
注意:
batch_size和capacity没有大小之分,谁大谁小都可以,一般相等就行
批处理大小跟队列大小和数据的数量没有关系。批处理大小只决定这批次取多少数据(若batch_size大于数据总量,则最后取出的数据有重复,而重复训练是没影响的)
capacity的大小不影响结果
tf.train.start_queue_runners(sess=None, coord=None)
收集所有图中的队列线程,并启动线程
csv文件数据展示:
完整代码如下:
- #! /usr/bin/env python
- # -*- coding:utf-8 -*-
- import tensorflow as tf
- import os
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 设置告警级别
-
- def csvread(filelist):
- """
- 读取csv文件
- :param filelist: 文件路径+名字的列表
- :return: 读取的内容
- """
- # 1、构造文件队列
- file_queue = tf.train.string_input_producer(filelist)
-
- # 2、构造csv阅读器读取队列数据(按一行)
- reader = tf.TextLineReader()
- key, value = reader.read(file_queue)
-
- # 3、对每行内容解码
- # record_defaults:指定每一个样本的每一列的类型,指定默认值
- records = [["None"], ["None"]]
- # 有几列就用几个参数接收
- example, lable = tf.decode_csv(value, record_defaults=records)
-
- # 4、想要读取多个数据,就要进行批处理 9条数据,1个线程,指定队列9个数据
- # batch_size和capacity没有大小之分,谁大谁小都可以,一般相等就行
- # 批处理大小跟队列大小和数据的数量没有关系。批处理大小只决定这批次取多少数据(若batch_size大于数据总量,则最后的数据有重复,而重复训练是没影响的)
- # capacity的大小不影响结果
- example_batch, label_batch = tf.train.batch([example, lable], batch_size=9, num_threads=1, capacity=9)
-
- return example_batch, label_batch
-
-
- if __name__ == '__main__':
-
- # 找到文件,放入列表
- file_name = os.listdir("./csvdata/")
- filelist = [os.path.join("./csvdata/", file) for file in file_name]
-
- example_batch, label_batch = csvread(filelist)
-
- # 开启会话运行结果
- with tf.Session() as sess:
- # 定义一个线程协调器
- coord = tf.train.Coordinator()
-
- # 开启读文件的线程
- threads = tf.train.start_queue_runners(sess, coord=coord)
-
- # 打印读取的内容
- print(sess.run([example_batch, label_batch]))
-
- # 回收子线程
- coord.request_stop()
- coord.join(threads)
代码运行后如下:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。