赞
踩
tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练。
也有tensorflow中的 tf.data.DataSet的使用。并且由于是tensorflow框架的内容,会让工程看起来更加连贯流畅。
这里我们需要先了解 tf.data 下的两个类:
- tf.data.DataSet:将我们的numpy数据 转换成 tensorflow的DataSet数据
- tf.data.Iterator:生成DataSet的迭代器,来源源不断的获取数据传送送神经网络。
接下来我们用例子说明:
import tensorflow as tf import numpy as np data_num = 100 batch = 5 def dealdata(image, label): """该函数下使用的是python的语法进行处理数据 注意:返回数据基本需要加上 .astype(np.float32) 或 .astype(np.int32),否则可能会类型问题上报错,""" image = image * 2 label = label * 2 return image.astype(np.int32), label.astype(np.int32) def map_func(image, label): """ 该函数下的内容是要使用tensorflow的语法实现的数据处理。 如果数据处理比较复杂,可以使用tf.py_func调用python的处理函数 这里 tf.py_func的入参最后一项,数据的类型和个数,是根据调用的函数具体返回数据填写的,不必与get_batch_gen保持一致!!! """ image = image + 1 image, label = tf.py_func(dealdata, [image, label], [tf.int32, tf.int32]) return image, label def get_batch_gen(split): """ 该函数返回的内容,是 tf.data.Dataset.from_generator()函数的三个入参 Returns: batch_gen: 数据信息的迭代器的函数, gen_types:tf.data.Dataset.from_generator()迭代出的数据的类型 gen_shape:f.data.Dataset.from_generator()迭代出的数据的shape """ if (split == "train"): image = list(range(0, data_num)) # 这里用数字代表输入的数据信息,实际使用可以是数据的路径等 label = list(range(0, data_num)) # 使用数字代表数据的标签 if (split == "val"): # image = ... # label = ... print() def batch_gen(): for i in range(data_num): # 这里循环的次数,是网络训练的次数(具体前向传播的次数) yield ([image[i]], [label[i]]) gen_types = (tf.int32, tf.int32) gen_shape = ([None], [None]) return batch_gen, gen_types, gen_shape class Dataset: def init_input_pipeline(self): gen_function, gen_types, gen_shapes = get_batch_gen("train") """将python定义的 数据信息的迭代器的内容,转换成tensorflow的 DataSet,该 DataSet有两个属性:.output_shapes/.output_types""" data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes) # print(data.output_shapes) # print(data.output_types) """设置了输出数据的batch。虽然返回的data 是tensorflow中另外一种类 BatchDataSet,其实就是设置batch""" # data = data.batch(batch) data = data.shuffle(data_num).batch(batch) # 读取出了 data_num 个数据后,就打乱。 data = data.map(map_func=map_func, num_parallel_calls=4) # 进行多进程数量,来并行处理数据。map_func为处理函数 data = data.prefetch(buffer_size=batch * 10) # 提前准备数据的数量,提前处理了数据,就可以让gpu尽可能少的处于等待数据的状态 """构造一个DataSet的迭代器""" iter = tf.data.Iterator.from_structure(data.output_types, data.output_shapes) self.init_op = iter.make_initializer(data) # 迭代器的初始化 self.flat_inputs = iter.get_next() # 使用get_next源源不断的获取数据 """需要说明的,这里的self.flat_inputs,就可以想tf.placehoder一样,直接作为神经网络的输入节点,在其后面一层一层的定义网络结构。""" class network(): def __init__(self, flat_inputs): self.input = {} self.input["image"] = flat_inputs[0] # 类比于输入的placehoder的地位 self.input["label"] = flat_inputs[1] # 类比于label的placehoder的地位 # ... 具体神神经网络结构 self.out = ... if __name__ == '__main__': trian_Dataset = Dataset() trian_Dataset.init_input_pipeline() # model = network(trian_Dataset.flat_inputs) # 初始化神经网络 with tf.Session() as sess: for i in range(2): sess.run(trian_Dataset.init_op) try: while True: a, b = sess.run(trian_Dataset.flat_inputs) # 直接运行 数据读取的get_next() 节点,即可获取到数据。 print(a.reshape(-1), b.reshape(-1)) # out = sess.run(model.out) # 运行网络的输出节点,即可得到网络输出数据 except tf.errors.OutOfRangeError: # 迭代器数据取完时,直接跳转在这里,防止运行中断 print("outOfRange")
其中,最固定化的流程是 dataset.init_input_pipeline()。
- 主要根据【python的迭代器】定义一个【tf.data.DataSet】,然后设置【batch】,再通过【map】来处理多进程处理数据(复杂的数据处理可以通过tf.py_func来调用python语法的处理函数),并使用【prepare】设置提前准备数据的个数,使gpu尽量少的处于等待状态。
- 定义 tf.data.DataSet 的迭代器 【tf.data.Iterator】,设置初始化节点【make_initializer】以及获取数据节点【get_next】
上面的例子的打印结果如下:
关于自己的测试和问题:
本人使用以上的代码来进行数据读取,在map_func函数中,添加【time.slaeep(2)】,模拟处理复杂数据耗时;然后在map的入参中设置了多进程的数量。然后测试发现,数据获取运行时间并未得到减少。
另外在github上下载论文开源工程,修改map的多进程数量入参,测试用时还是相同的情况。
这里我挠挠脑袋,有点疑惑不知道为什么。这里就先记录到这,没准后面那次再测其他工程就OK了
上面的形式,在我们阅读源码网络结构时是不方便的,直接打印神经网络的某一层,得到tensor的shape基本是 (?,?,?...)的形式,这不是我们需要看到的。那么阅读网络代码时,如何显示化的tensor的shape呢?如下代码,
... dataset.init_input_pipeline() # with tf.Session() as sess: sess.run(dataset.train_init_op) a = sess.run(dataset.flat_inputs) for j in range(len(a)): dataset.flat_inputs[j].set_shape(list(a[j].shape)) # model = Network(dataset, cfg) ...如果需要使用tf.placehoder()代替输入的位置,继而能够进行另外的操作:
... dataset.init_input_pipeline() # PlaceHolder = [] with tf.Session() as sess: sess.run(dataset.train_init_op) a = sess.run(dataset.flat_inputs) for j in range(len(a)): PlaceHolder.append(tf.placeholder(dataset.flat_inputs[j].dtype, a[j].shape, "pl_{}".format(j))) print(PlaceHolder[-1]) dataset.flat_inputs = PlaceHolder # model = Network(dataset, cfg) ...
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。