当前位置:   article > 正文

tensorflow高维度张量相乘_tensorflow 不同维度的张量点乘

tensorflow 不同维度的张量点乘

通过tf.reshape()对高维度张量降维,验证高维度张量相乘结果

最近遇到了需要将高于2维度的张量相乘的需求,通过互联网资源查到了先用tf.reshape()降到2维再运算的骚操作。下面验证这种操作的可靠性。

#测试多维矩阵乘法。问题来自于mul-attention模型的矩阵运算
#2019-7-14编辑
import tensorflow as tf
import numpy as np

#定义两个三维矩阵
#k.shape = [batch_size, seq_length, embedded_size]
#w.shape = [embedded_size, d_k, h_num]
k = tf.Variable(tf.random_uniform([3, 4, 5]))
w = tf.Variable(tf.random_uniform([5, 6, 7]))

#实现k与w的相乘,目标维度为[batch_size, seq_length, d_k, h_num]
#通过reshape的方式,将矩阵降到2维,实现矩阵乘法,再通过reshape还原
k_2d = tf.reshape(k, [-1, 5])
w_2d = tf.reshape(w, [5, -1])
r_2d = tf.matmul(k_2d, w_2d)
r_4d = tf.reshape(r_2d, [-1, 4, 6, 7])

#运算结果
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    r_untested = sess.run(r_4d)
    k_3d, w_3d = sess.run([k, w])

print(np.dot(k_3d[0,:,:],w_3d[:,:,0]))
#array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 ,
        0.65192497],
       [1.2962847 , 0.63438064, 1.7439795 , 1.2534602 , 0.8585079 ,
        0.9535629 ],
       [1.0780972 , 1.466816  , 1.623834  , 1.4493611 , 0.9913111 ,
        1.1141219 ],
       [0.6155605 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317205 ,
        0.8374078 ]], dtype=float32)
print(r_untested[0,:,:,0])
#array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 ,
        0.65192497],
       [1.2962846 , 0.6343807 , 1.7439795 , 1.2534602 , 0.8585079 ,
        0.9535629 ],
       [1.0780972 , 1.4668161 , 1.623834  , 1.449361  , 0.9913111 ,
        1.1141219 ],
       [0.6155604 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317204 ,
        0.8374078 ]], dtype=float32)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

最终发现结果相同,大胆的用吧

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/176461
推荐阅读
相关标签
  

闽ICP备14008679号