赞
踩
从根本上讲,tf.Example 是 {"string": tf.train.Feature} 映射
类型 | 可转换自 |
---|---|
tf.train.BytesList | String Byte |
tf.train.FloatList | float double |
tf.train.Int64List | bool enum int32 uint32 int64 uint64 |
import tensorflow as tf
import numpy as np


## Converting a single value into a tf.train.Feature

def _bytes_feature(value):
    """Wrap a string / byte value in a `tf.train.Feature` holding a bytes_list."""
    # BytesList won't unpack a string from an EagerTensor, so extract the
    # underlying value first. `type(tf.constant(0))` is the EagerTensor class.
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)


def _float_feature(value):
    """Wrap a float / double value in a `tf.train.Feature` holding a float_list."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)


def _int64_feature(value):
    """Wrap a bool / enum / int / uint value in a `tf.train.Feature` holding an int64_list."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)


# Smoke test: print one feature of each kind.
print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))

# Every feature message can be serialized to a binary string.
feature = _float_feature(np.exp(1))
feature.SerializeToString()
# Build a synthetic dataset that mixes several data types.
n_observations = int(1e4)

feature0 = np.random.choice([False, True], n_observations)  # bool
feature1 = np.random.randint(0, 5, n_observations)  # int
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]  # byte (feature1 is used as an index into `strings`)
feature3 = np.random.randn(n_observations)  # float


def serialize_example(feature0, feature1, feature2, feature3):
    """Serialize one observation as a `tf.train.Example` byte string.

    Each argument is wrapped in the tf.Example-compatible feature type
    (int64, int64, bytes, float respectively) and stored under the keys
    'feature0' .. 'feature3'.

    Returns:
        The serialized `tf.train.Example` message, ready to be written to
        a file.
    """
    # Map every feature name to its tf.Example-compatible representation.
    features = tf.train.Features(feature={
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    })
    return tf.train.Example(features=features).SerializeToString()
# NOTE(review): `example_observation` is not used anywhere in the visible code.
example_observation = []
# Serialize one hand-written observation into a tf.Example byte string.
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
# Decode the bytes back into a message; this is the counterpart of
# .SerializeToString().
example_proto = tf.train.Example.FromString(serialized_example)
# Bare expression: presumably shown as notebook cell output.
example_proto
tf.data 模块还提供用于在 TensorFlow 中读取和写入数据的工具
# Turn the plain in-memory arrays into a dataset; it is then processed with
# `map`.
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))

# Print the first element to inspect what the dataset yields.
for f0, f1, f2, f3 in features_dataset.take(1):
    print(f0, f1, f2, f3)
def tf_serialize_example(f0, f1, f2, f3):
    """Serialize one dataset element from inside a `Dataset.map` call.

    `serialize_example` is plain Python, so it is wrapped in
    `tf.py_function` to make it callable on the tensors that `map` passes in.

    Args:
        f0, f1, f2, f3: The four feature tensors of a single element.

    Returns:
        A scalar `tf.string` tensor holding the serialized `tf.train.Example`.
    """
    tf_string = tf.py_function(
        serialize_example,
        (f0, f1, f2, f3),  # Arguments forwarded to `serialize_example`.
        tf.string)  # The return type is `tf.string`.
    return tf.reshape(tf_string, ())  # The result is a scalar.
# Apply the serializer to every element of the dataset.
serialized_features_dataset = features_dataset.map(tf_serialize_example)
# Bare expression: presumably shown as notebook cell output.
serialized_features_dataset
# Build a generator whose output is written to a TFRecord file.
def generator():
    """Yield every element of `features_dataset` as a serialized tf.Example."""
    for features in features_dataset:
        yield serialize_example(*features)
# Wrap the generator in a dataset whose elements are scalar strings.
serialized_features_dataset = tf.data.Dataset.from_generator(
generator, output_types=tf.string, output_shapes=())
# Write the generator-backed dataset to a TFRecord file.
filename = 'test.tfrecord'
# NOTE(review): tf.data.experimental.TFRecordWriter appears to be deprecated
# in recent TF 2.x releases — confirm against the TF version in use.
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)
注:在 tf.data.Dataset 上进行迭代仅在启用了 Eager Execution 时有效。
# Create a TFRecordDataset; each listed file is one TFRecord file.
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)

# Print the first ten records as read from the file, before any parsing.
for raw_record in raw_dataset.take(10):
    print(repr(raw_record))
# Describe the features stored in each record: name -> (shape, dtype, default).
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}


# Parse one record at a time.
def _parse_function(example_proto):
    """Parse one serialized `tf.Example` using `feature_description`."""
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    return parsed


parsed_dataset = raw_dataset.map(_parse_function)

# Print the first ten parsed records.
for parsed_record in parsed_dataset.take(10):
    print(repr(parsed_record))
# Write the records one at a time with a plain Python loop.
with tf.io.TFRecordWriter(filename) as writer:
    # Walk the four feature arrays in lockstep instead of indexing each one.
    for row in zip(feature0, feature1, feature2, feature3):
        example = serialize_example(*row)
        writer.write(example)
# Read the file back and decode the first record into a tf.train.Example.
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)

for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    # `.numpy()` extracts the record's value from the tensor for the parser.
    example.ParseFromString(raw_record.numpy())
    print(example)
上述读写操作,粗看类似“八股文”,但在实际操作过程中会遇到不少问题,个人总结了几点需要注意的地方
# Raw observations: scalars (0-D).
def parse_exmp(example_proto, column_list, data_type, default_value):
    """Parse scalar columns of one serialized example into a 1-D tensor.

    Args:
        example_proto: One serialized `tf.Example` record.
        column_list: Names of the features to extract.
        data_type: dtype shared by all listed features.
        default_value: Value used when a feature is missing from the record.

    Returns:
        A tensor of shape (len(column_list),) holding the values in the
        order given by `column_list`.
    """
    fea_description = {
        name: tf.io.FixedLenFeature(
            shape=(), dtype=data_type, default_value=default_value)
        for name in column_list
    }
    out = tf.io.parse_single_example(example_proto, fea_description)
    # Iterate `column_list`: the previous code referenced an undefined
    # `col_list` here and raised NameError.
    fea_list = [tf.reshape(out[name], (1,)) for name in column_list]
    return tf.concat(fea_list, axis=0)


# Raw observations: 1*5 arrays.
def parse_exmp_array(example_proto, column_list, data_type, default_value):
    """Parse 1x5 array columns of one serialized example into a 2-D tensor.

    Args:
        example_proto: One serialized `tf.Example` record.
        column_list: Names of the features to extract.
        data_type: dtype shared by all listed features.
        default_value: Padding value handed to `FixedLenSequenceFeature`.

    Returns:
        A tensor of shape (len(column_list), 5), one row per column name.
    """
    fea_description = {
        name: tf.io.FixedLenSequenceFeature(
            shape=[5], dtype=data_type, allow_missing=True,
            # Honour the caller's value; it used to be hard-coded to -1.
            default_value=default_value)
        for name in column_list
    }
    out = tf.io.parse_single_example(example_proto, fea_description)
    # Iterate `column_list`: the previous code referenced an undefined
    # `col_list` here and raised NameError.
    fea_list = [tf.reshape(out[name], (1, 5)) for name in column_list]
    return tf.concat(fea_list, axis=0)


# NOTE(review): `file_name`, `numeric_columns`, `cate_columns` and
# `arr_columns` are not defined in the visible code — presumably they come
# from the surrounding project.
file_names = [file_name]
dataset = tf.data.TFRecordDataset(file_names)
num_fea = dataset.map(lambda x: parse_exmp(x, numeric_columns, tf.float32, 0.0))
cate_fea = dataset.map(lambda x: parse_exmp(x, cate_columns, tf.int64, 0))
arr_fea = dataset.map(lambda x: parse_exmp_array(x, arr_columns, tf.int64, -1))
Dataset对象是一个 Python 可迭代对象,创建数据集(dataset)有两种不同的方法:
如果输入数据可以放在内存中,创建Dataset的最简单方法是使用Dataset.from_tensor_slices()
# Raw observations: scalars (0-D).
def parse_exmp(example_proto, column_list, data_type, default_value):
    """Parse scalar columns of one serialized example into a 1-D tensor.

    Args:
        example_proto: One serialized `tf.Example` record.
        column_list: Names of the features to extract.
        data_type: dtype shared by all listed features.
        default_value: Value used when a feature is missing from the record.

    Returns:
        A tensor of shape (len(column_list),) holding the values in the
        order given by `column_list`.
    """
    fea_description = {
        name: tf.io.FixedLenFeature(
            shape=(), dtype=data_type, default_value=default_value)
        for name in column_list
    }
    out = tf.io.parse_single_example(example_proto, fea_description)
    # Iterate `column_list`: the previous code referenced an undefined
    # `col_list` here and raised NameError.
    fea_list = [tf.reshape(out[name], (1,)) for name in column_list]
    return tf.concat(fea_list, axis=0)


# Raw observations: 1*5 arrays.
def parse_exmp_array(example_proto, column_list, data_type, default_value):
    """Parse 1x5 array columns of one serialized example into a 2-D tensor.

    Args:
        example_proto: One serialized `tf.Example` record.
        column_list: Names of the features to extract.
        data_type: dtype shared by all listed features.
        default_value: Padding value handed to `FixedLenSequenceFeature`.

    Returns:
        A tensor of shape (len(column_list), 5), one row per column name.
    """
    fea_description = {
        name: tf.io.FixedLenSequenceFeature(
            shape=[5], dtype=data_type, allow_missing=True,
            # Honour the caller's value; it used to be hard-coded to -1.
            default_value=default_value)
        for name in column_list
    }
    out = tf.io.parse_single_example(example_proto, fea_description)
    # Iterate `column_list`: the previous code referenced an undefined
    # `col_list` here and raised NameError.
    fea_list = [tf.reshape(out[name], (1, 5)) for name in column_list]
    return tf.concat(fea_list, axis=0)


# NOTE(review): `file_name`, the `*_columns` lists, the hyper-parameters and
# the `Input` / `MyEmb` / `MMoE` / `Tower` names used below are not defined in
# the visible code — presumably they come from the surrounding project.
file_names = [file_name]
dataset = tf.data.TFRecordDataset(file_names)
num_fea = dataset.map(lambda x: parse_exmp(x, numeric_columns, tf.float32, 0.0))
cate_fea = dataset.map(lambda x: parse_exmp(x, cate_columns, tf.int64, 0))
arr_fea = dataset.map(lambda x: parse_exmp_array(x, arr_columns, tf.int64, -1))

## Arrange the inputs in the structure the model expects.
data_set = tf.data.Dataset.zip(((num_fea, cate_fea), arr_fea))
data_set = data_set.shuffle(buffer_size = buffer_size).batch(batch_size)

print("_____开始构建模型_____")

## Two inputs.
inputs_num_col = Input(shape=(len(numeric_columns),), name = "num_col", dtype='float32')
# NOTE(review): the categorical features are parsed as tf.int64 above while
# this input declares int32 — confirm the intended dtype.
inputs_cate_col = Input(shape=(len(cate_columns),), name = "cate_col", dtype='int32')

## Custom Embedding layer (implementation omitted).
emb = MyEmb(emb_map, hash_emb_map)(inputs_cate_col)
feature_input = tf.concat([emb, inputs_num_col], axis = -1, name='concat_cate_num_feature')

## Custom MMoE layer (implementation omitted).
mmoe_layers = MMoE(units=layer_units, num_experts=num_experts, num_tasks=num_tasks)(feature_input)

## Custom Tower layer (implementation omitted).
CTR_logits = Tower(layer_num=tower_num_layer, layer_units=tower_num_layer_units, activation='relu', name='CTR_logits')(mmoe_layers[0])
CTCVR_logits = Tower(layer_num=tower_num_layer, layer_units=tower_num_layer_units, activation='relu', name='CVR_logits')(mmoe_layers[1])

output = tf.concat([CTR_logits, CTCVR_logits], axis = 1, name='output')
model = tf.keras.Model(inputs=[inputs_num_col, inputs_cate_col], outputs=output)
model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate), metrics=['accuracy'])

## Train the model.
model.fit(x=data_set, epochs=epoch)
原文链接:https://www.yuque.com/docs/share/518e1fc9-b530-4e9e-a38e-6611f5c71346?# 《一、TF2.x数据加载处理pipeline》
内容多搬运自官方文档,建议具体细节、最新更新直接点击链接查看官方说明
作者邮箱:hu.yf@outlook.com 欢迎交流~
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。