当前位置:   article > 正文

TensorFlow2-高阶操作(二):张量分割【split(分割后:rank不变)】【unstack(分割后:rank-1)】_tensorflow如何实现张量分解

tensorflow如何实现张量分解

一、split

split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

将张量分割成子张量.

  • 如果 num_or_size_splits 是整数类型,num_split,则 value 沿维度 axis 分割成为 num_split 更小的张量.要求 num_split 均匀分配 value.shape[axis]。
  • 如果 num_or_size_splits 不是整数类型,则它被认为是一个张量 size_splits,然后将 value 分割成 len(size_splits) 块.第 i 部分的形状与 value 的大小相同,除了沿维度 axis 之外的大小 size_splits[i]。
import pandas as pd
import tensorflow as tf

x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
print("x = \n", pd.DataFrame(x.numpy()))
print("-" * 200)

# Split `x` into 3 tensors along dimension 1
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
print("s0 = \n", pd.DataFrame(s0.numpy()))
print("-" * 50)
print("s1 = \n", pd.DataFrame(s1.numpy()))
print("-" * 50)
print("s2 = \n", pd.DataFrame(s2.numpy()))
print("-" * 200)

# Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
t0, t1, t2 = tf.split(x, num_or_size_splits=[4, 15, 11], axis=1)
print("t0 = \n", pd.DataFrame(t0.numpy()))
print("-" * 50)
print("t1 = \n", pd.DataFrame(t1.numpy()))
print("-" * 50)
print("t2 = \n", pd.DataFrame(t2.numpy()))
print("-" * 200)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

打印结果:

x = 
         0         1         2     ...           27        28        29
0 -0.888679  0.882839  0.739282    ...    -0.688343 -0.930151 -0.875597
1 -0.153850 -0.319729 -0.098402    ...     0.489693 -0.170844 -0.091632
2  0.003379  0.187339  0.795501    ...     0.379071 -0.256689  0.564788
3 -0.372030  0.340384 -0.875375    ...    -0.214336  0.717279  0.092451
4 -0.495783  0.257741 -0.358638    ...    -0.921029 -0.830439  0.507138

[5 rows x 30 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
s0 = 
           0         1         2    ...            7         8         9
0 -0.888679  0.882839  0.739282    ...    -0.403924  0.196670 -0.098327
1 -0.153850 -0.319729 -0.098402    ...     0.418904  0.081062  0.173876
2  0.003379  0.187339  0.795501    ...     0.615282 -0.385442 -0.311836
3 -0.372030  0.340384 -0.875375    ...    -0.252203 -0.587342  0.321012
4 -0.495783  0.257741 -0.358638    ...     0.552696  0.620588  0.132702

[5 rows x 10 columns]
--------------------------------------------------
s1 = 
           0         1         2    ...            7         8         9
0  0.509016  0.740289 -0.964265    ...     0.459772 -0.697755 -0.540041
1  0.904286  0.986134 -0.409174    ...     0.187198 -0.445747  0.813097
2 -0.137152  0.934053 -0.751823    ...     0.309953  0.716927  0.848913
3  0.096014  0.069597  0.777320    ...    -0.907295 -0.384888  0.764411
4 -0.706331 -0.901017 -0.529774    ...    -0.301620  0.066731  0.770751

[5 rows x 10 columns]
--------------------------------------------------
s2 = 
           0         1         2    ...            7         8         9
0 -0.356173 -0.040504  0.150185    ...    -0.688343 -0.930151 -0.875597
1 -0.436071 -0.224807  0.383009    ...     0.489693 -0.170844 -0.091632
2  0.169518  0.384529 -0.600068    ...     0.379071 -0.256689  0.564788
3  0.038849  0.754196 -0.049200    ...    -0.214336  0.717279  0.092451
4  0.245371 -0.548065  0.338353    ...    -0.921029 -0.830439  0.507138

[5 rows x 10 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
t0 = 
           0         1         2         3
0 -0.888679  0.882839  0.739282  0.454827
1 -0.153850 -0.319729 -0.098402 -0.764573
2  0.003379  0.187339  0.795501 -0.467434
3 -0.372030  0.340384 -0.875375  0.350312
4 -0.495783  0.257741 -0.358638  0.301579
--------------------------------------------------
t1 = 
          0         1         2     ...           12        13        14
0 -0.484341 -0.429574  0.999090    ...     0.634394  0.459772 -0.697755
1  0.325134 -0.227807 -0.890493    ...     0.152983  0.187198 -0.445747
2 -0.074674 -0.037023  0.830544    ...    -0.993245  0.309953  0.716927
3  0.044287  0.245083 -0.858829    ...    -0.583070 -0.907295 -0.384888
4 -0.105187  0.293733  0.783647    ...     0.397994 -0.301620  0.066731

[5 rows x 15 columns]
--------------------------------------------------
t2 = 
          0         1         2     ...           8         9         10
0 -0.540041 -0.356173 -0.040504    ...    -0.688343 -0.930151 -0.875597
1  0.813097 -0.436071 -0.224807    ...     0.489693 -0.170844 -0.091632
2  0.848913  0.169518  0.384529    ...     0.379071 -0.256689  0.564788
3  0.764411  0.038849  0.754196    ...    -0.214336  0.717279  0.092451
4  0.770751  0.245371 -0.548065    ...    -0.921029 -0.830439  0.507138

[5 rows x 11 columns]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

二、unstack

将秩为 R 的张量的给定维度出栈为秩为 (R-1) 的张量.

通过沿 axis 维度将 num 张量从 value 中分离出来.如果没有指定 num(默认值),则从 value 的形状推断.如果 value.shape[axis] 不知道,则引发 ValueError.

例如,给定一个具有形状 (A, B, C, D) 的张量.

  • 如果 axis == 0,那么 output 中的第 i 个张量就是切片 value[i, :, :, :],并且 output 中的每个张量都具有形状 (B, C, D).(请注意,出栈的维度已经消失,不像split).
  • 如果 axis == 1,那么 output 中的第 i 个张量就是切片 value[:, i, :, :],并且 output 中的每个张量都具有形状 (A, C, D).
tf.unstack(value, num=None, axis=0, name='unstack')
  • 1
  • value: A rank R > 0 Tensor to be unstacked.
  • num: An int. The length of the dimension axis. Automatically inferred if None (the default).
  • axis: An int. The axis to unstack along. Defaults to the first dimension. Negative - values: wrap around, so the valid range is [-R, R).
  • name: A name for the operation (optional).
import tensorflow as tf

x = tf.reshape(tf.range(12), (3, 4))
print("x = \n", x)
print("-" * 200)

p, q, r = tf.unstack(x)
print("p = ", p)
print("-" * 50)
print("q = ", q)
print("-" * 50)
print("r = ", r)
print("-" * 200)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

打印结果:

x = 
 tf.Tensor(
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]], shape=(3, 4), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
p =  tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
--------------------------------------------------
q =  tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
--------------------------------------------------
r =  tf.Tensor([ 8  9 10 11], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/176608
推荐阅读
相关标签
  

闽ICP备14008679号