当前位置:   article > 正文

tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator_tf.data.dataset.from_generator

tf.data.dataset.from_generator

tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练。

也有tensorflow中的 tf.data.DataSet的使用。并且由于是tensorflow框架的内容,会让工程看起来更加连贯流畅。

这里我们需要先了解 tf.data 下的两个类:

  • tf.data.DataSet:将我们的numpy数据 转换成 tensorflow的DataSet数据
  • tf.data.Iterator:生成DataSet的迭代器,来源源不断的获取数据传送送神经网络。

接下来我们用例子说明:

  1. import tensorflow as tf
  2. import numpy as np
  3. data_num = 100
  4. batch = 5
  5. def dealdata(image, label):
  6. """该函数下使用的是python的语法进行处理数据
  7. 注意:返回数据基本需要加上 .astype(np.float32) 或 .astype(np.int32),否则可能会类型问题上报错,"""
  8. image = image * 2
  9. label = label * 2
  10. return image.astype(np.int32), label.astype(np.int32)
  11. def map_func(image, label):
  12. """
  13. 该函数下的内容是要使用tensorflow的语法实现的数据处理。
  14. 如果数据处理比较复杂,可以使用tf.py_func调用python的处理函数
  15. 这里 tf.py_func的入参最后一项,数据的类型和个数,是根据调用的函数具体返回数据填写的,不必与get_batch_gen保持一致!!!
  16. """
  17. image = image + 1
  18. image, label = tf.py_func(dealdata, [image, label], [tf.int32, tf.int32])
  19. return image, label
  20. def get_batch_gen(split):
  21. """
  22. 该函数返回的内容,是 tf.data.Dataset.from_generator()函数的三个入参
  23. Returns:
  24. batch_gen: 数据信息的迭代器的函数,
  25. gen_types:tf.data.Dataset.from_generator()迭代出的数据的类型
  26. gen_shape:f.data.Dataset.from_generator()迭代出的数据的shape
  27. """
  28. if (split == "train"):
  29. image = list(range(0, data_num)) # 这里用数字代表输入的数据信息,实际使用可以是数据的路径等
  30. label = list(range(0, data_num)) # 使用数字代表数据的标签
  31. if (split == "val"):
  32. # image = ...
  33. # label = ...
  34. print()
  35. def batch_gen():
  36. for i in range(data_num): # 这里循环的次数,是网络训练的次数(具体前向传播的次数)
  37. yield ([image[i]], [label[i]])
  38. gen_types = (tf.int32, tf.int32)
  39. gen_shape = ([None], [None])
  40. return batch_gen, gen_types, gen_shape
  41. class Dataset:
  42. def init_input_pipeline(self):
  43. gen_function, gen_types, gen_shapes = get_batch_gen("train")
  44. """将python定义的 数据信息的迭代器的内容,转换成tensorflow的 DataSet,该 DataSet有两个属性:.output_shapes/.output_types"""
  45. data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
  46. # print(data.output_shapes)
  47. # print(data.output_types)
  48. """设置了输出数据的batch。虽然返回的data 是tensorflow中另外一种类 BatchDataSet,其实就是设置batch"""
  49. # data = data.batch(batch)
  50. data = data.shuffle(data_num).batch(batch) # 读取出了 data_num 个数据后,就打乱。
  51. data = data.map(map_func=map_func, num_parallel_calls=4) # 进行多进程数量,来并行处理数据。map_func为处理函数
  52. data = data.prefetch(buffer_size=batch * 10) # 提前准备数据的数量,提前处理了数据,就可以让gpu尽可能少的处于等待数据的状态
  53. """构造一个DataSet的迭代器"""
  54. iter = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
  55. self.init_op = iter.make_initializer(data) # 迭代器的初始化
  56. self.flat_inputs = iter.get_next() # 使用get_next源源不断的获取数据
  57. """需要说明的,这里的self.flat_inputs,就可以想tf.placehoder一样,直接作为神经网络的输入节点,在其后面一层一层的定义网络结构。"""
  58. class network():
  59. def __init__(self, flat_inputs):
  60. self.input = {}
  61. self.input["image"] = flat_inputs[0] # 类比于输入的placehoder的地位
  62. self.input["label"] = flat_inputs[1] # 类比于label的placehoder的地位
  63. # ... 具体神神经网络结构
  64. self.out = ...
  65. if __name__ == '__main__':
  66. trian_Dataset = Dataset()
  67. trian_Dataset.init_input_pipeline()
  68. # model = network(trian_Dataset.flat_inputs) # 初始化神经网络
  69. with tf.Session() as sess:
  70. for i in range(2):
  71. sess.run(trian_Dataset.init_op)
  72. try:
  73. while True:
  74. a, b = sess.run(trian_Dataset.flat_inputs) # 直接运行 数据读取的get_next() 节点,即可获取到数据。
  75. print(a.reshape(-1), b.reshape(-1))
  76. # out = sess.run(model.out) # 运行网络的输出节点,即可得到网络输出数据
  77. except tf.errors.OutOfRangeError:
  78. # 迭代器数据取完时,直接跳转在这里,防止运行中断
  79. print("outOfRange")

其中,最固定化的流程是 dataset.init_input_pipeline()。

  • 主要根据【python的迭代器】定义一个【tf.data.DataSet】,然后设置【batch】,再通过【map】来处理多进程处理数据(复杂的数据处理可以通过tf.py_func来调用python语法的处理函数),并使用【prepare】设置提前准备数据的个数,使gpu尽量少的处于等待状态。
  • 定义 tf.data.DataSet 的迭代器 【tf.data.Iterator】,设置初始化节点【make_initializer】以及获取数据节点【get_next】

上面的例子的打印结果如下:

关于自己的测试和问题:

本人使用以上的代码来进行数据读取,在map_func函数中,添加【time.slaeep(2)】,模拟处理复杂数据耗时;然后在map的入参中设置了多进程的数量。然后测试发现,数据获取运行时间并未得到减少。

另外在github上下载论文开源工程,修改map的多进程数量入参,测试用时还是相同的情况。

这里我挠挠脑袋,有点疑惑不知道为什么。这里就先记录到这,没准后面那次再测其他工程就OK了

上面的形式,在我们阅读源码网络结构时是不方便的,直接打印神经网络的某一层,得到tensor的shape基本是 (?,?,?...)的形式,这不是我们需要看到的。那么阅读网络代码时,如何显示化的tensor的shape呢?如下代码,

  1. ...
  2. dataset.init_input_pipeline()
  3. #
  4. with tf.Session() as sess:
  5. sess.run(dataset.train_init_op)
  6. a = sess.run(dataset.flat_inputs)
  7. for j in range(len(a)):
  8. dataset.flat_inputs[j].set_shape(list(a[j].shape))
  9. #
  10. model = Network(dataset, cfg)
  11. ...

如果需要使用tf.placehoder()代替输入的位置,继而能够进行另外的操作:

  1. ...
  2. dataset.init_input_pipeline()
  3. #
  4. PlaceHolder = []
  5. with tf.Session() as sess:
  6. sess.run(dataset.train_init_op)
  7. a = sess.run(dataset.flat_inputs)
  8. for j in range(len(a)):
  9. PlaceHolder.append(tf.placeholder(dataset.flat_inputs[j].dtype, a[j].shape, "pl_{}".format(j)))
  10. print(PlaceHolder[-1])
  11. dataset.flat_inputs = PlaceHolder
  12. #
  13. model = Network(dataset, cfg)
  14. ...
 


 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/123853
推荐阅读
相关标签
  

闽ICP备14008679号