[Deep Learning] An Introduction to TensorFlow Basics

TensorFlow

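Model

What tensors and variables have in common: both have three attributes, namely a shape, a type (dtype), and a value.

How they differ: variables are tracked by TensorFlow's automatic differentiation mechanism, so gradients can be taken with respect to them, and they are commonly used as the parameters of machine learning models.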
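
A minimal sketch of this difference, assuming TensorFlow 2.x with eager execution. The names `w` and `c` are illustrative and do not come from the article. A `tf.Variable` is watched by `tf.GradientTape` automatically, whereas a constant tensor yields no gradient unless it is watched explicitly with `tape.watch`.

```python
import tensorflow as tf

w = tf.Variable(3.0)   # variable: tracked automatically by the tape
c = tf.constant(2.0)   # plain tensor: not tracked unless tape.watch(c) is called

# Both expose the same three attributes: shape, dtype, and value.
print(w.shape, w.dtype, w.numpy())
print(c.shape, c.dtype, c.numpy())

with tf.GradientTape() as tape:
    y = c * w * w      # y = c * w^2

dy_dw, dy_dc = tape.gradient(y, [w, c])
print(dy_dw)  # gradient with respect to the variable: 2 * c * w
print(dy_dc)  # None, because the constant was never watched
```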

tfrecord

A data format defined by TensorFlow: a binary file format used to store and read data such as images and text. A tfrecord file contains tf.train.Example protobuf messages. It is designed for use with TensorFlow and is used throughout higher-level APIs such as TFX.

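Basic structure and data types

The data structure of tf.train.Example is a dictionary called Features, which maps feature names to Feature values. Its internal structure can be read from the proto file: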

message Example {
  Features features = 1;
}

message Features {
  map<string, Feature> feature = 1;
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

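A Feature holds one of three data types: Int64, Bytes, or Float. Int64 stores bool, enum, int32, uint32, int64, and uint64 values; Bytes stores strings and binary data; Float stores float (float32) and double (float64) values. Because FloatList is a list of 32-bit floats, double values are narrowed to float32 when stored.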
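
A short sketch of the type mapping, assuming TensorFlow 2.x. The sample values are illustrative. It shows that a Python float (a 64-bit double) placed in a `FloatList` comes back rounded to float32, and that text must be encoded to bytes before it goes into a `BytesList`.

```python
import tensorflow as tf

# A Python float is a double; FloatList keeps only float32 precision.
float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[0.999]))
stored = float_feature.float_list.value[0]
print(stored == 0.999)             # False: the value was rounded to float32
print(abs(stored - 0.999) < 1e-6)  # True: still close to the original

# Text is stored as bytes, so encode it first.
bytes_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=["goal".encode("utf-8")]))
print(bytes_feature.bytes_list.value[0])

# WhichOneof reports which of the three kinds a Feature holds.
print(float_feature.WhichOneof("kind"), bytes_feature.WhichOneof("kind"))
```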

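The file format is simply the protobuf binary serialization of data arranged in this dictionary structure; each serialized record is a binary string.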

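def serialize_example(f1, f2, f3, f4):
    # Map each feature name to a tf.train.Feature of the matching type.
    fts = {
        "feature0": _int64_feature(f1),
        "feature1": _int64_feature(f2),
        "feature2": _bytes_feature(f3),
        "feature3": _float_feature(f4),
    }
    m = tf.train.Example(features=tf.train.Features(feature=fts))
    # Serialize the Example message to a binary string.
    return m.SerializeToString()

# Round trip: serialize, then decode the bytes back into a tf.train.Example.
ps = serialize_example(3, True, b"goal", 0.999)
ex_proto = tf.train.Example.FromString(ps)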

tf.train.Feature is the value type that tf.train.Example accepts, so each raw value has to be wrapped in a tf.train.Feature first.

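import tensorflow as tf

def _bytes_feature(x):
    # Wrap a string / bytes value in a bytes_list Feature.
    if isinstance(x, type(tf.constant(0))):
        x = x.numpy()  # BytesList cannot unpack a string from an EagerTensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x]))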
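
The article's `serialize_example` also calls `_int64_feature` and `_float_feature`, which are not defined in the article. A minimal sketch of what they could look like, assuming they follow the same wrapping pattern as `_bytes_feature`:

```python
import tensorflow as tf

def _float_feature(x):
    """Wrap a float or double in a float_list Feature (stored as float32)."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[x]))

def _int64_feature(x):
    """Wrap a bool, enum, int, or uint in an int64_list Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[x]))

int_feature = _int64_feature(3)
flt_feature = _float_feature(0.5)
print(int_feature)
print(flt_feature)
```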
Reading and writing tfrecord files
  1. Writing a file
# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
    for i in range(n_observations):
        example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
        writer.write(example)
  2. Reading a file
fn = "./Waymo.tfrecord"
raw_dataset = tf.data.TFRecordDataset(fn)
# Feature schema: name -> expected shape, dtype, and default value
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
    # Parse the input `tf.train.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)

parsed_dataset = raw_dataset.map(_parse_function)
for parsed_record in parsed_dataset.take(10):
    print(repr(parsed_record))
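
The write and read snippets above rely on variables defined elsewhere (`filename`, `n_observations`, `feature0` and so on). A self-contained round trip could look like the following sketch, assuming TensorFlow 2.x. The feature names `id` and `name` and the temporary file path are illustrative only.

```python
import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "roundtrip.tfrecord")

# Write three records, each with one int64 feature and one bytes feature.
with tf.io.TFRecordWriter(path) as writer:
    for i in range(3):
        example = tf.train.Example(features=tf.train.Features(feature={
            "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
            "name": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[b"rec%d" % i])),
        }))
        writer.write(example.SerializeToString())

# The schema used for reading must match what was written.
schema = {
    "id": tf.io.FixedLenFeature([], tf.int64),
    "name": tf.io.FixedLenFeature([], tf.string),
}
dataset = tf.data.TFRecordDataset(path).map(
    lambda raw: tf.io.parse_single_example(raw, schema))

ids = [int(record["id"]) for record in dataset]
names = [record["name"].numpy() for record in dataset]
print(ids, names)
```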
Waymo Open Dataset

The dataset uses the tfrecord data format. For the Dataset structure, refer to
https://github.com/waymo-research/waymo-open-dataset/blob/master/waymo_open_dataset/dataset.proto
It is read with the Python library waymo-open-dataset.

# The package must match the installed TensorFlow version, e.g. tf 2.3.0
pip3 install waymo-open-dataset-tf-2-3-0 --user
import tensorflow as tf
from waymo_open_dataset import dataset_pb2 as open_dataset

fn = [
    "/data/Waymo_training_segment-10023947602400723454_1120_000_1140_000_with_camera_labels.tfrecord"
]
dataset = tf.data.TFRecordDataset(fn)
for data in dataset.take(1000):
    # Each record is a serialized Frame message, not a tf.train.Example.
    frame = open_dataset.Frame()
    frame.ParseFromString(bytearray(data.numpy()))
    # plt.figure(figsize=(25, 20))
    # for index, image in enumerate(frame.images):
    #   show_camera_image(image, frame.camera_labels, [3, 3, index+1])
    # plt.show()
    ts = frame.timestamp_micros
    st_img = frame.images[0]
    for labels in frame.camera_labels:
        # Keep only the labels that belong to the selected camera image.
        if labels.name == st_img.name:
            for label in labels.labels:
                # Convert the centre-based box to a top-left corner plus size.
                x = int(label.box.center_x - 0.5 * label.box.length)
                y = int(label.box.center_y - 0.5 * label.box.width)
                width = int(label.box.length)
                height = int(label.box.width)
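
The box arithmetic in the inner loop above can be isolated into a small helper. This is a sketch in plain Python; the function name `box_to_xywh` is hypothetical and is not part of the waymo-open-dataset library. As in the loop above, `length` is taken as the horizontal extent and `width` as the vertical extent of the 2D camera box.

```python
def box_to_xywh(center_x, center_y, length, width):
    """Convert a centre-based 2D box to integer (x, y, w, h), top-left origin."""
    x = int(center_x - 0.5 * length)  # left edge
    y = int(center_y - 0.5 * width)   # top edge
    return x, y, int(length), int(width)

print(box_to_xywh(100.0, 50.0, 40.0, 20.0))  # (80, 40, 40, 20)
```

Inside the loop it would be called as `box_to_xywh(label.box.center_x, label.box.center_y, label.box.length, label.box.width)`.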

Reinventing the wheel: reading the dataset with tf.io.

Open questions

https://stackoverflow.com/questions/61166864/tensorflow-python-framework-ops-eagertensor-object-has-no-attribute-in-graph

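How to determine the dictionary structure (the parsing schema) for Waymo Open Dataset files. For comparison, an image tfrecord is parsed like this: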

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
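
When the dictionary structure of a tfrecord file is unknown, one option is to decode the first record and list its feature names and kinds. The sketch below assumes the file holds tf.train.Example records; the helper name `describe_first_record` and the demo file are hypothetical. It does not apply directly to Waymo files, whose records are serialized Frame messages described by dataset.proto.

```python
import os
import tempfile
import tensorflow as tf

# Write one small record so the sketch is self-contained; with real data,
# point `path` at an existing tfrecord file instead.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
    "image_raw": tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b"\x00\x01"])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

def describe_first_record(tfrecord_path):
    """Return {feature name: (kind, number of values)} for the first record."""
    for raw in tf.data.TFRecordDataset(tfrecord_path).take(1):
        parsed = tf.train.Example.FromString(raw.numpy())
        schema = {}
        for name, feature in parsed.features.feature.items():
            # kind is 'bytes_list', 'float_list', or 'int64_list'
            kind = feature.WhichOneof("kind")
            schema[name] = (kind, len(getattr(feature, kind).value))
        return schema
    return {}  # the file contained no records

schema = describe_first_record(path)
print(schema)
```

The reported kinds map onto the dtypes passed to `tf.io.FixedLenFeature`: `int64_list` to `tf.int64`, `float_list` to `tf.float32`, and `bytes_list` to `tf.string`.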