当前位置:   article > 正文

keras: tf.data.Dataset.from_tensor_slices()_tf.data.dataset元组

tf.data.dataset元组


机制

  • 作用:创建一个数据集tf.data.Dataset,将数据inputs和标签targets联立在一起。
  • 要求:所有输入张量的第一个维度必须相同。不然inputs和targets不对应。
  • 机制:沿着它们的第一个维度切片。
    意思是说,第一个维度表示有n个样本,将inputs[i]targets[i]组合到一起,从而Dateset有n个组合。

数组

import numpy as np
import tensorflow as tf

# 一维
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(list(dataset.as_numpy_iterator()))
# [1, 2, 3]

# 二维
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
print(list(dataset.as_numpy_iterator()))
# [array([1, 2]), array([3, 4])]

# 三维
dataset = tf.data.Dataset.from_tensor_slices(
    [[[1, 2], [3, 4]], 
     [[5, 6], [7, 8]]])
print(list(dataset.as_numpy_iterator()))
# [array([[1, 2],
#        [3, 4]]), 
#  array([[5, 6],
#        [7, 8]])]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • [1, 2, 3]是一维:切完还是一样。
  • [[1, 2], [3, 4]]是二维:a[0]是[1,2],a[1]是[3,4],结果两行组合[array([1, 2]), array([3, 4])]
  • [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]是三维:a[0]是[[1, 2], [3, 4]],a[1]是[[5, 6], [7, 8]],结果两行组合[array([[1, 2],[3, 4]]), array([[5, 6],[7, 8]])]

元组

# 元组
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
print(list(dataset.as_numpy_iterator()))
# [(1, 3, 5), (2, 4, 6)]
  • 1
  • 2
  • 3
  • 4

元组的拆维,不是元组[0]为第一维,而是元组内的数组[0]为第一维,即([1, 2], [3, 4], [5, 6])→[0]:1,3,5;[1]:2,4,6

字典

# 字典
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
print(list(dataset.as_numpy_iterator()))
# [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
  • 1
  • 2
  • 3
  • 4

同元组。

实战

组合inputs和target

  • inputs(3,2), target(3,1)
    PS:target(3)也行。
features = [[1, 3], [2, 1], [3, 3]]    # shape(3,2),3个样本,2个特征
labels = [['A'], ['B'], ['A']]         # shape(3,1),3个样本,1个标签
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
print(list(dataset.as_numpy_iterator()))
# [(array([1, 3], dtype=int32), array([b'A'], dtype=object)), 
#  (array([2, 1], dtype=int32), array([b'B'], dtype=object)), 
#  (array([3, 3], dtype=int32), array([b'A'], dtype=object))]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 复杂一点:图片inputs(3,2,2), target(3,2)
features = [[[1, 3], [2, 3]],
            [[2, 1], [1, 2]],
            [[3, 3], [3, 2]]]  # shape=(3, 2, 2),3个样本,特征是2*2(图片)
labels = [['A', 'A'],
          ['B', 'B'],
          ['A', 'B']])          # shape=(3, 2),3个样本,2个标签
dataset = tf.data.Dataset.from_tensor_slices((features, labels ))
print(list(dataset.as_numpy_iterator()))
# [(array([[1, 3],
#        [2, 3]], dtype=int32), array([b'A', b'A'], dtype=object)), (array([[2, 1],
#        [1, 2]], dtype=int32), array([b'B', b'B'], dtype=object)), (array([[3, 3],
#        [3, 2]], dtype=int32), array([b'A', b'B'], dtype=object))]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 别的写法:还能用tf.data.Dataset.zip()
features_dataset = tf.data.Dataset.from_tensor_slices(features)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))
print(list(dataset.as_numpy_iterator()))
# [(array([1, 3]), b'A'), (array([2, 1]), b'B'), (array([3, 3]), b'A')]
  • 1
  • 2
  • 3
  • 4
  • 5
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/293112
推荐阅读
相关标签
  

闽ICP备14008679号