当前位置:   article > 正文

tfrecord原理详解 手把手教生成tfrecord文件与解析tfrecord文件

tfrecord

1.什么是tfrecord

TFRecord 是Google官方推荐的一种数据格式,是Google专门为TensorFlow设计的一种数据格式。

TFRecord本质上是二进制文件,目的是更好的利用内存。用户可以将训练集/测试集打包成生成TFRecord文件,后续就可以配合TF中相关的API实现数据的加载,处理,训练等一系列工作,可以方便高效的训练与评估模型。

2.tfrecord原理

TFRecord 并非是TensorFlow唯一支持的数据格式,你也可以使用CSV或文本等格式,但是对于TensorFlow来说,TFRecord 是最友好也是最方便的。
tf.Example是TFRecord的基本结果,其实他就是一个Protobuffer定义的message,表示一组string到bytes value的映射。TFRecord文件里面存储的就是序列化的tf.Example。在github上tensorflow的源码就能看到其定义
message Example

message Example {
  Features features = 1;
};
  • 1
  • 2
  • 3

里面只有一个变量features。如果我们继续查看Features

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};
  • 1
  • 2
  • 3
  • 4

features里面就是一组string到Feature的映射。其中这个string表示feature name,后面的Feature又是一个message

继续查看Feature的定义

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

到这里,我们就可以看到tfrecord里存储的真正数据类型有三种
bytes_list: 可以存储string 和byte两种数据类型。
float_list: 可以存储float(float32)与double(float64) 两种数据类型 。
int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64 。

3.实操生成tfrecords文件

下面来手把手教大家如何生成tfrecords文件,并解析tfrecords文件。
我们以titanic数据为例

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",female,38,1,0,PC 17599,71.2833,C85,C
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35,1,0,113803,53.1,C123,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
7,0,1,"McCarthy, Mr. Timothy J",male,54,0,0,17463,51.8625,E46,S
8,0,3,"Palsson, Master. Gosta Leonard",male,2,3,1,349909,21.075,,S
9,1,3,"Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)",female,27,0,2,347742,11.1333,,S
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

上面是titanic部分数据,第一行为各列字段名,后面几行为具体数据。如果想看完整的titanic数据,大家可以自行网上搜索并下载。

首先定义几个辅助方法

import tensorflow as tf
import csv

# Generate Integer Features.
def build_int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))

# Generate Float Features.
def build_float_feature(data):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))

# Generate String Features.
def build_string_feature(data):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(data).encode()]))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

然后再定义生成Example的方法

# Generate a TF `Example`, parsing all features of the dataset.
def convert_to_tfexample(survived, pclass, name, sex, age, sibsp, parch, ticket, fare):
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'survived': build_int64_feature(survived),
                'pclass': build_int64_feature(pclass),
                'name': build_string_feature(name),
                'sex': build_string_feature(sex),
                'age': build_string_feature(age),
                'sibsp': build_int64_feature(sibsp),
                'parch': build_int64_feature(parch),
                'ticket': build_string_feature(ticket),
                'fare': build_float_feature(fare),
            })
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

再将其写入文件

def write_tf_records():
    writer = tf.io.TFRecordWriter('output.tfrecords')
    with open('titanic.csv') as f:
        reader = csv.reader(f, skipinitialspace=True)
        for i, record in enumerate(reader):
            if i == 0:
                continue
            survived, pclass, name, sex, age, sibsp, parch, ticket, fare = record[1:10]
            print("age, fare is: ", age, fare)
            example = convert_to_tfexample(int(survived), int(pclass), name, sex, age, int(sibsp), int(parch), ticket, float(fare))

            writer.write(example.SerializeToString())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

这样,就生成了名为output.tfrecords的文件。

4.解析tfrecords文件

接下来我们解析上面生成的文件

首先定义features字典:

features = {
        'survived': tf.io.FixedLenFeature([], tf.int64),
        'pclass': tf.io.FixedLenFeature([], tf.int64),
        'name': tf.io.FixedLenFeature([], tf.string),
        'sex': tf.io.FixedLenFeature([], tf.string),
        'age': tf.io.FixedLenFeature([], tf.string),
        'sibsp': tf.io.FixedLenFeature([], tf.int64),
        'parch': tf.io.FixedLenFeature([], tf.int64),
        'ticket': tf.io.FixedLenFeature([], tf.string),
        'fare': tf.io.FixedLenFeature([], tf.float32)
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

然后使用parse_single_example方法,解析单条数据

# Parse features, using the above template.
def parse_record(record):
    return tf.io.parse_single_example(record, features=features)
  • 1
  • 2
  • 3

主方法:

def read_tf_records():
    filenames = ["output.tfrecords"]
    data = tf.data.TFRecordDataset(filenames)
    data = data.map(parse_record)
    data = data.repeat()
    # Shuffle data.
    data = data.shuffle(buffer_size=1000)
    # Batch data (aggregate records together).
    data = data.batch(batch_size=4)
    # Prefetch batch (pre-load batch for faster consumption).
    data = data.prefetch(buffer_size=1)


    # Dequeue data and display.
    for record in data.take(1):
        print("record is: ", record)
        print("record[survived is: ", record['survived'])
        print(type(record['survived']))
        print()
        print(record['survived'].numpy())
        print(record['name'].numpy())
        print(record['fare'].numpy())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

主方法的输出为:

record is:  {'age': <tf.Tensor: shape=(4,), dtype=string, numpy=array([b'', b'9', b'20', b'32'], dtype=object)>, 'fare': <tf.Tensor: shape=(4,), dtype=float32, numpy=array([16.1   , 27.9   , 15.7417,  7.925 ], dtype=float32)>, 'name': <tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'Davison, Mrs. Thomas Henry (Mary E Finck)',
       b'Skoog, Miss. Mabel', b'Nakid, Mr. Sahid', b'Jussila, Mr. Eiriik'],
      dtype=object)>, 'parch': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 2, 1, 0])>, 'pclass': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([3, 3, 3, 3])>, 'sex': <tf.Tensor: shape=(4,), dtype=string, numpy=array([b'female', b'female', b'male', b'male'], dtype=object)>, 'sibsp': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 3, 1, 0])>, 'survived': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 1, 1])>, 'ticket': <tf.Tensor: shape=(4,), dtype=string, numpy=array([b'386525', b'347088', b'2653', b'STON/O 2. 3101286'], dtype=object)>}
record[survived is:  tf.Tensor([1 0 1 1], shape=(4,), dtype=int64)
<class 'tensorflow.python.framework.ops.EagerTensor'>

[1 0 1 1]
[b'Davison, Mrs. Thomas Henry (Mary E Finck)' b'Skoog, Miss. Mabel'
 b'Nakid, Mr. Sahid' b'Jussila, Mr. Eiriik']
[16.1    27.9    15.7417  7.925 ]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

使用上面的方式,就解析出来原有的数据!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/123773
推荐阅读
相关标签
  

闽ICP备14008679号