Recently I needed to multiply tensors with more than two dimensions. From resources on the internet I found the trick of first using tf.reshape() to drop the tensors to 2-D, doing the matmul there, and then reshaping back. Below I verify that this trick is reliable.
# Test multi-dimensional matrix multiplication. The problem comes from the
# matrix operations in a multi-head attention model.
# Edited 2019-7-14
import tensorflow as tf
import numpy as np

# Define two 3-D tensors.
# k.shape = [batch_size, seq_length, embedded_size]
# w.shape = [embedded_size, d_k, h_num]
k = tf.Variable(tf.random_uniform([3, 4, 5]))
w = tf.Variable(tf.random_uniform([5, 6, 7]))

# Multiply k by w; the target shape is [batch_size, seq_length, d_k, h_num].
# Via reshape: drop both tensors to 2-D, do the matmul, then reshape back.
k_2d = tf.reshape(k, [-1, 5])
w_2d = tf.reshape(w, [5, -1])
r_2d = tf.matmul(k_2d, w_2d)
r_4d = tf.reshape(r_2d, [-1, 4, 6, 7])

# Run the graph, then compare one slice against a plain NumPy dot product.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    r_untested = sess.run(r_4d)
    k_3d, w_3d = sess.run([k, w])

print(np.dot(k_3d[0, :, :], w_3d[:, :, 0]))
# array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 , 0.65192497],
#        [1.2962847 , 0.63438064, 1.7439795 , 1.2534602 , 0.8585079 , 0.9535629 ],
#        [1.0780972 , 1.466816  , 1.623834  , 1.4493611 , 0.9913111 , 1.1141219 ],
#        [0.6155605 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317205 , 0.8374078 ]],
#       dtype=float32)
print(r_untested[0, :, :, 0])
# array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 , 0.65192497],
#        [1.2962846 , 0.6343807 , 1.7439795 , 1.2534602 , 0.8585079 , 0.9535629 ],
#        [1.0780972 , 1.4668161 , 1.623834  , 1.449361  , 0.9913111 , 1.1141219 ],
#        [0.6155604 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317204 , 0.8374078 ]],
#       dtype=float32)
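For the same contraction, TensorFlow also provides tf.tensordot and tf.einsum, which avoid the manual reshapes entirely. A minimal sketch, assuming a TF 1.x version in which both ops are available (the shapes mirror the example above):

import tensorflow as tf

k = tf.Variable(tf.random_uniform([3, 4, 5]))  # [batch_size, seq_length, embedded_size]
w = tf.Variable(tf.random_uniform([5, 6, 7]))  # [embedded_size, d_k, h_num]

# Contract the shared embedded_size axis: axis 2 of k against axis 0 of w.
# The result has shape [3, 4, 6, 7], i.e. [batch_size, seq_length, d_k, h_num].
r_td = tf.tensordot(k, w, axes=[[2], [0]])

# Equivalent einsum form: r[b,s,d,h] = sum over e of k[b,s,e] * w[e,d,h].
r_es = tf.einsum('bse,edh->bsdh', k, w)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([r_td, r_es])
    print(a.shape, b.shape)  # (3, 4, 6, 7) (3, 4, 6, 7)

Either op expresses the intent directly, which is exactly the computation the reshape trick performs.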
The two results agree (the last decimal places differ only by float32 rounding), so the reshape trick is safe to use.
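To go beyond eyeballing a single slice, the whole 4-D result can be checked at once. A self-contained NumPy sketch, with shapes matching the TensorFlow example above (the seed and np.allclose tolerances are arbitrary choices for illustration):

import numpy as np

rng = np.random.RandomState(0)
k_3d = rng.rand(3, 4, 5).astype(np.float32)  # [batch_size, seq_length, embedded_size]
w_3d = rng.rand(5, 6, 7).astype(np.float32)  # [embedded_size, d_k, h_num]

# The reshape trick, replayed in NumPy: flatten to 2-D, matmul, reshape back.
r_trick = (k_3d.reshape(-1, 5) @ w_3d.reshape(5, -1)).reshape(3, 4, 6, 7)

# Reference: contract the shared embedded_size axis directly.
r_ref = np.tensordot(k_3d, w_3d, axes=([2], [0]))

print(np.allclose(r_trick, r_ref))  # True: the trick matches on every slice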