
TensorFlow (4): TFRecord data storage and reading

# TFRecord file read/write
tf.train.Example 	tf.train.Features 	tf.train.Feature	tf.train.BytesList 	tf.train.FloatList 	tf.train.Int64List
tf.train.Example.SerializeToString() tf.train.Example.FromString()

tf.io.TFRecordWriter() tf.python_io.TFRecordWriter() tf.data.experimental.TFRecordWriter()

tf.data.TFRecordDataset() 	tf.data.Dataset.from_tensor_slices()   tf.data.Dataset.from_generator() 

dataset.take(1)  
dataset.map()                 # apply a function to each element of a Dataset.
tf.io.parse_single_example()  # parse a single record
tf.parse_example()            # parse the whole batch at once

# Build a dataset
tf.data.Dataset.from_tensor_slices()
dataset = tf.data.TFRecordDataset(filenames)    # filenames can be a string, a list of strings, or a tf.Tensor of strings.
dataset = dataset.map()
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

sess.run(iterator.initializer)

1. Writing TFRecord data

TFRecord is a format for storing a sequence of binary records; data is written sequentially and read back sequentially, which is very efficient. Each record in the file is a byte-string, and the most common way of producing such byte-strings is tf.train.Example. A tf.train.Example (a protobuf) stores a message as a {"string": value} dictionary, a storage scheme that is easy to read and inspect.

User data has to be converted into the layout TFRecord expects before it can be written in the TFRecord format.

step1: convert Python data to tfrecord data. tfrecord supports three coarse data types: string, int64, and float32; the corresponding tfrecord types are tf.train.BytesList, tf.train.Int64List, and tf.train.FloatList.

step2: wrap tf.train.BytesList / tf.train.Int64List / tf.train.FloatList in a tf.train.Feature.

step3: build a feature dictionary, tf.train.Features (the Features message), from the tf.train.Feature objects.

step4: wrap tf.train.Features in a tf.train.Example.

step5: serialize each tf.train.Example and write it to the TFRecord file with an IO writer.

Note:

  1. A TFRecord file does not have to contain serialized tf.train.Example messages; tf.train.Example is just one way of serializing a dictionary. Any byte-string can be stored in a TFRecord file (see the sketch below); each record is laid out on disk in the format shown in part4 further down.
  2. If data processing speed is not the bottleneck of model training, there is no need to convert the data into the proto format; the plain binary-string form can already be written and read efficiently.
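
As noted above, any byte-string can go into a TFRecord file, not only serialized tf.train.Example messages. A minimal sketch of that, assuming TF 2.x eager execution (the file name raw_demo.tfrecord is made up for illustration): a tensor is serialized with tf.io.serialize_tensor, written as a raw record, and restored with tf.io.parse_tensor.

import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Write one raw byte-string record (no tf.train.Example involved).
with tf.io.TFRecordWriter('raw_demo.tfrecord') as writer:
    writer.write(tf.io.serialize_tensor(t).numpy())

# Read the record back and restore the tensor.
for raw in tf.data.TFRecordDataset(['raw_demo.tfrecord']):
    restored = tf.io.parse_tensor(raw, out_type=tf.float32)
    print(restored)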

Writing a single record of a single type

# part1: convert Python data types to tf data types
# tf.train.BytesList: string, byte
# tf.train.FloatList: float (float32), double (float64)
# tf.train.Int64List: bool, enum, int32, uint32, int64, uint64
value = 1
value_ed = tf.train.Int64List(value=[value])

# part2: wrap tf.train.BytesList / tf.train.FloatList / tf.train.Int64List in a tf.train.Feature.
# The helpers below are shortcuts for scalars; non-scalar data has to be turned into a binary string
# with np.array().tobytes() / tf.io.serialize_tensor first, and then wrapped into a tf.train.Feature.
def _bytes_feature(value):
  """Returns a tf.train.Feature wrapping a string / byte value.

  value can be obtained from e.g. np.array(data, dtype=np.float32).tobytes().
  """
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a tf.train.Feature from a float_list of a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns a tf.train.Feature from an int64_list of a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
feature = _float_feature(np.exp(1))
feature.SerializeToString()                # a tf.train.Feature can also be serialized on its own
  

# part3: dump a single-type record to tfrecord; value can be a num / list / array
pybyte_value = np.array(value).tobytes()                      # 0. convert to Python bytes
tfbyte_value = tf.train.BytesList(value=[pybyte_value])       # 1. convert to tf.train bytes
feature_dict[key] = tf.train.Feature(bytes_list=tfbyte_value) # 2. wrap in tf.train.Feature() -- note: Feature, no "s"
..........
feature_example = tf.train.Example(features=tf.train.Features(feature=feature_dict))  # 3. wrap in tf.train.Example() -- note: tf.train.Features() with "s"
exmp_serial = feature_example.SerializeToString()            # serialize feature_example

tf_writer = tf.python_io.TFRecordWriter(tfrecord_path)       # open the TFRecord writer (tf.io.TFRecordWriter in TF 2.x)
tf_writer.write(exmp_serial)                                 # write to the TFRecord file
tf_writer.close()                                            # close the writer

# part4: the on-disk layout of each record in a tfrecord file
uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data
# Note that the tf.train.Example message is just a wrapper around the Features message.

Writing and reading a single record with combined feature types

The demos from here on use the following dataset: 10,000 observations, each with 4 features: [bool, label_index, label_string, random_score].

# part5: use serialize_example() to wrap combined-type data for writing
n_observations = int(1e4)                                                # The number of observations in the dataset.
feature0 = np.random.choice([False, True], n_observations)               # Boolean feature, encoded as False or True.
feature1 = np.random.randint(0, 5, n_observations)                       # Integer feature, random from 0 to 4.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])      # String feature.
feature2 = strings[feature1]
feature3 = np.random.randn(n_observations)                               # Float feature, from a standard normal distribution.

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.train.Example message ready to be written to a file.
  Create a dictionary mapping the feature name to the tf.train.Example-compatible data type.
  Create a Features message using tf.train.Example.
  """
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()    # serialize the tf.train.Example
  
# serialize example_proto into a binary string
serialized_example = serialize_example(False, 4, b'goat', 0.9876)      
print(serialized_example)

# deserialize the binary string back into an example_proto
example_proto = tf.train.Example.FromString(serialized_example)        
print(example_proto)

# Output:

b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\.....

features {
 feature {
   key: "feature0"
   value {
     int64_list {
       value: 0
     }
   }
 }
 feature {
   key: "feature1"
   value {
     int64_list {
       value: 4
     }
   }
 }
 feature {
   key: "feature2"
   value {
     bytes_list {
       value: "goat"
     }
   }
 }
 feature {
   key: "feature3"
   value {
     float_list {
       value: 0.9876000285148621
     }
   }
 }
}



Writing binary strings one record at a time

## write to disk
filename = 'test.tfrecord'
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
    
## read from disk
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())

np.array().tobytes() builds a Python bytes object containing the raw data of the array.
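
np.array().tobytes() and tf.io.serialize_tensor (both mentioned in part2) behave differently for non-scalar arrays; a minimal sketch, assuming numpy and TF 2.x (the array and variable names are made up for illustration):

import numpy as np
import tensorflow as tf

arr = np.arange(6, dtype=np.float32).reshape(2, 3)

# Route 1: raw bytes -- shape and dtype are lost, so they must be known or stored separately.
raw = arr.tobytes()
restored = np.frombuffer(raw, dtype=np.float32).reshape(2, 3)

# Route 2: tf.io.serialize_tensor keeps shape and dtype inside the serialized string.
ser = tf.io.serialize_tensor(arr)
restored_t = tf.io.parse_tensor(ser, out_type=tf.float32)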

Writing multiple records (a tf dataset) of combined types

A neural-network dataset supports iteration, random shuffling, batched reads, and so on. tf.data provides several interfaces for building a tf dataset, and the whole serialized dataset can then be written to disk.

# part6: build a tf dataset from numpy arrays
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
for f0, f1, f2, f3 in features_dataset.take(1):
  print(f0, f1, f2, f3)
  break
  
# To turn the custom serialization function serialize_example into a node of the TensorFlow graph, map it over the dataset's elements in one of the following two ways.
# part6.1: element mapping, option 1 -- tf.data.Dataset.map() + tf.py_function() wrapping the serialization
def tf_serialize_example(f0, f1, f2, f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0, f1, f2, f3),              # Pass these args to the above function.
    tf.string)                     # The return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar.
serialized_features_dataset = features_dataset.map(tf_serialize_example)

# part6.2: element mapping, option 2 -- wrap in a generator, then serialize
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())

## 6.3 write the whole serialized_features_dataset to disk
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)


Three TFRecordWriter variants for writing to disk

# 1. tf.io.TFRecordWriter(): pure-Python interface for reading and writing TFRecord files.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)

# 2. tf.data.experimental.TFRecordWriter(): writes a whole serialized dataset to disk.
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

# From the official docs:
# To write TFRecords to disk, use `tf.io.TFRecordWriter`.
# To save and load the contents of a dataset, use `tf.data.experimental.save` and `tf.data.experimental.load` (see the sketch after this block).


# 3. tf.python_io.TFRecordWriter() (TF 1.x API)
tf_writer = tf.python_io.TFRecordWriter(tfrecord_path)      # open the TFRecord writer
tf_writer.write(exmp_serial)                                # write the serialized example
tf_writer.close()                                           # close the writer
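
The save/load pair quoted from the docs above serializes a whole dataset to a directory; a minimal sketch, assuming TF 2.3+ (newer releases expose the same functionality as tf.data.Dataset.save / tf.data.Dataset.load) and a made-up path:

# Save the element-wise serialized dataset and load it back.
tf.data.experimental.save(serialized_features_dataset, '/tmp/serialized_ds')
restored = tf.data.experimental.load(
    '/tmp/serialized_ds',
    element_spec=tf.TensorSpec(shape=(), dtype=tf.string))
for item in restored.take(1):
    print(item)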



2. Reading TFRecord data

feature_description = {                           
    # the description gives the shape and type signature needed to parse the data; tf.data.Dataset uses graph execution
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
} 

def _parse_function(example_proto):
     # Parse the input `tf.train.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description) 

# case1 tf.data.TFRecordDataset() + feature_description + map()
file_name = 'test.tfrecord'
filenames = [file_name]
raw_dataset = tf.data.TFRecordDataset(filenames)     # the records loaded from the tfrecord are binary strings
parsed_dataset = raw_dataset.map(_parse_function)    # apply func to each element of a dataset

# Use eager execution to display the observations in the dataset (the concrete data values are not visible here).
for parsed_record in parsed_dataset.take(10):    
    print(repr(parsed_record))
    break      


# case2: tf.data.TFRecordDataset() + tf.train.Example.ParseFromString(), returns a tf.train.Example
# (did not work in this experiment: in graph mode raw_record is a symbolic Tensor, so .numpy() is
#  unavailable; under eager execution (TF 2.x) raw_record.numpy() can be parsed)
raw_dataset2 = tf.data.TFRecordDataset(filenames)
for raw_record in raw_dataset2.take(1):
    example = tf.train.Example()
    # example.ParseFromString(raw_record.numpy())   # AttributeError: 'Tensor' object has no attribute 'numpy'
    # example.ParseFromString(raw_record)           # expected a readable buffer object

    # turn example.features.feature (a dict) into a dictionary of NumPy arrays
    result = {}
    for key, feature in example.features.feature.items():
        # each value is a Feature object whose `kind` is one of three fields:
        # bytes_list, float_list, int64_list
        kind = feature.WhichOneof('kind')
        result[key] = np.array(getattr(feature, kind).value)
          
# case3: tf.python_io.tf_record_iterator() (TF 1.x)
record_iterator = tf.python_io.tf_record_iterator(path=file_name)
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)
    print("cyy", example)
    # Exit after 1 iteration as this is purely demonstrative.
    break
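
The cheat sheet at the top also lists tf.io.parse_example() for parsing a whole batch at once; a minimal sketch, reusing filenames and feature_description from above and assuming TF 2.x eager execution (the batch size 32 is arbitrary):

# case4 (sketch): batch the raw serialized records, then parse each batch in one call
batched_raw = tf.data.TFRecordDataset(filenames).batch(32)
parsed_batches = batched_raw.map(
    lambda batch: tf.io.parse_example(batch, feature_description))
for batch in parsed_batches.take(1):
    print(batch['feature1'].shape)   # (32,)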

Reference: TFRecord and tf.train.Example

3. Building datasets in TensorFlow

TF 1.x
Each of the functions below takes a dataset as input and returns a dataset.

3.1 Dataset structure

Dataset.output_types gives the data type of each component of an element
Dataset.output_shapes gives the shape of each component of an element
tf.data.Dataset.from_tensor_slices() treats axis 0 of the tensor as the number of examples

"""
tf.data 支持建立大规模、复杂的数据流pipelines
    tf.data.Dataset()  构建数据集
    tf.data.Iterator() 迭代数据集 Iterator.get_next() yields the next element of a Dataset when executed, 两种迭代模式
Dataset 数据集中的每一个数据称为一个elements,
        一个elements可能由single tensor, a tuple of tensor, a nested tuple of tensor构成, every tensor calls components
"""                                                                                                                                                                                                                                                                                                                                                                                                                                        
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types)  # ==> "tf.float32"
print(dataset1.output_shapes)  # ==> "(10,)"

dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([4]),
     tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes)  # ==> "((), (100,))"

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types)  # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"
    # give each component a name, which makes reading easier
dataset = tf.data.Dataset.from_tensor_slices(
    {"a": tf.random_uniform([4]),
     "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"
    # Dataset.map(), Dataset.flat_map(), Dataset.filter() apply a given transformation to every element

3.2 Creating an iterator

  1. one-shot iterator: a single pass over the dataset (epoch = 1), no explicit initialization needed
  2. initializable iterator: requires sess.run(iterator.initializer)
  3. reinitializable
  4. feedable
# one-shot iterator
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_one_shot_iterator() # a single pass is enough here since we are not training
next_element = iterator.get_next()          # tf.Session.run() get data, push iter to next data

sess = tf.Session()
for i in range(100):
    value = sess.run(next_element) 
    print("dataset1", i, value)
    assert i == value
    break
    
# initializable: supports multiple passes over the dataset; run iterator.initializer before each pass
   # a placeholder can parameterize the dataset (here its size), so different data can be fed in
max_value = tf.placeholder(tf.int64, shape=[])
dataset2 = tf.data.Dataset.range(max_value)
iterator2 = dataset2.make_initializable_iterator()
next_element2 = iterator2.get_next()

# re-initialize the same initializable iterator over datasets of different sizes
   # Initialize the iterator over a dataset with 10 elements.
sess.run(iterator2.initializer, feed_dict={max_value: 10})
for i in range(10):
    value = sess.run(next_element2)
    print("dataset2", i, value)
    assert i == value
    break

    # Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator2.initializer, feed_dict={max_value: 100})
for i in range(100):
    value = sess.run(next_element2)
    print("dataset3", i, value)
    assert i == value
    break

# A feedable iterator can be used together .......

3.3 Try-except block - wrap the “training loop”

If the iterator reaches the end of the dataset, executing sess.run(next_element) raises a tf.errors.OutOfRangeError.
After that the iterator is in an unusable state and must be initialized again before further use.

dataset4 = tf.data.Dataset.range(5)
iterator4 = dataset4.make_initializable_iterator()
next_element4 = iterator4.get_next()
   # Typically `result` will be the output of a model, or an optimizer's training operation.
   # travel to the dataset ends
result = tf.add(next_element4, next_element4)
sess.run(iterator4.initializer)
print(sess.run(result))  # ==> "0"
print(sess.run(result))  # ==> "2"
print(sess.run(result))  # ==> "4"
print(sess.run(result))  # ==> "6"
print(sess.run(result))  # ==> "8"
try:
   sess.run(result)
except tf.errors.OutOfRangeError:
   print("End of dataset")  # ==> "End of dataset"

   # wrap the training loop
sess.run(iterator4.initializer)
epoch = 0
while True:
   try:
       data = sess.run(result)
       print("dataset4", epoch, data)
   except tf.errors.OutOfRangeError:
        # a one-shot iterator doesn't support reinitialization, so it can't wrap the training loop
       sess.run(iterator4.initializer)  
       epoch += 1
       if epoch == 3:
           break

# nested structure for Iterator.get_next()
feature1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
feature2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset5 = tf.data.Dataset.zip((feature1, feature2))

iterator5 = dataset5.make_initializable_iterator()

sess.run(iterator5.initializer)
next1, (next2, next3) = iterator5.get_next() # sess.run() on any of these tensors advances the iterator for all components

3.4 Saving iterator state
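
In TF 1.x, the iterator position can be checkpointed together with the graph's variables via tf.data.experimental.make_saveable_from_iterator; a minimal sketch, assuming a TF 1.x release where that symbol exists (older releases exposed it as tf.contrib.data.make_saveable_from_iterator) and a made-up checkpoint path:

dataset8 = tf.data.Dataset.range(100)
iterator8 = dataset8.make_initializable_iterator()
next_element8 = iterator8.get_next()

# Register the iterator state as a saveable object so tf.train.Saver checkpoints it.
saveable = tf.data.experimental.make_saveable_from_iterator(iterator8)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(iterator8.initializer)
    sess.run(next_element8)                      # advance the iterator by one element
    saver.save(sess, '/tmp/iter_ckpt')           # the iterator position is saved with the checkpoint
    saver.restore(sess, '/tmp/iter_ckpt')        # restoring resumes from the saved position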

3.5 Dataset.map() - preprocessing data

 # Parsing tf.Example protocol buffer messages 
      # Transforms a scalar string `example_proto` into a pair of a scalar string and a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

      # Creates a dataset that reads all of the examples from two files, and extracts the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset6 = tf.data.TFRecordDataset(filenames)
dataset6 = dataset6.map(_parse_function)

  # Decoding image data and resizing it.....
  # Applying arbitrary Python logic with tf.py_func()


3.6 Dataset.batch() - Batching dataset elements

dataset.batch(B) adds a batch dimension to the data: [P, V, F] -> [B, P, V, F]

# Batching dataset elements, Dataset.batch() -- stacks n consecutive elements of a dataset into a single element.
    # Simple batching
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset7 = tf.data.Dataset.zip((inc_dataset, dec_dataset))    # zip output is a tuple
batched_dataset = dataset7.batch(4)

iterator7 = batched_dataset.make_one_shot_iterator()
next_element7 = iterator7.get_next()

print(sess.run(next_element7))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element7))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element7))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])
data7 = sess.run(next_element7) # data7[0].shape = (4,), data7[1].shape = (4,)
    # Batching tensors with padding ........

3.7 Training workflows - Dataset.repeat(), Dataset.shuffle()

# Processing multiple epochs, 
    # way1 Dataset.repeat(n), no arguments will repeat the input indefinitely
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
    # way2, try-except, Compute for 100 epochs.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
for _ in range(100):
    sess.run(iterator.initializer)

    while True:
        try:
            sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break

# Dataset.shuffle() -  Randomly shuffling input data
     ##  it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

3.8 Other - Reading input data, tf2.x

# # Reading input data
#     # Consuming NumPy arrays
#     # Consuming TFRecord data: tf.data.TFRecordDataset enable you to stream over the contents of one or more TFRecord files as part of an input pipeline.
# filenames = tf.placeholder(tf.string, shape=[None])
# dataset6 = tf.data.TFRecordDataset(filenames)
# dataset6 = dataset6.map(...)  # Parse the record into tensors.
# dataset6 = dataset6.repeat()  # Repeat the input indefinitely.
# dataset6 = dataset6.batch(32)
# iterator6 = dataset6.make_initializable_iterator()

#         # You can feed the initializer with the appropriate filenames for the current phase of execution, e.g. training vs. validation.
#         # Initialize `iterator` with training data.
# training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
# sess.run(iterator6.initializer, feed_dict={filenames: training_filenames})
#         # Initialize `iterator` with validation data.
# validation_filenames = ["/var/data/validation1.tfrecord", ...]
# sess.run(iterator6.initializer, feed_dict={filenames: validation_filenames})

# tf2.x: data can be fetched without session.run()
# data_iterator = dataset.as_numpy_iterator()
# data_iterator.next()
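
Under TF 2.x eager execution a dataset can also be iterated directly; a minimal sketch, assuming TF 2.x and the parsed_dataset built in section 2:

# Iterate a tf.data.Dataset directly under eager execution (TF 2.x).
for record in parsed_dataset.take(3):
    print(record['feature2'].numpy(), record['feature3'].numpy())

# Or turn the dataset into an iterator over NumPy values.
it = parsed_dataset.as_numpy_iterator()
print(next(it))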