赞
踩
之前一直以为对ndarray
的各种索引切片操作还算得上熟悉,但今天师弟问了我Informer
实现中ProbSparse Self-Attention
的一些Tensor索引操作,才发现有些操作还不太懂,而网上也缺乏相关的参考资料。因此在一系列探索下,写下了这篇博客。
构造示例数组x
,为一个三维tesnor:
import torch
x = torch.arange(16).reshape(2,2,4)
x
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]])
最简单的是直接传入标量,分别取对应维度的对应数据即可。需要主要的是:
表示取该维度全数据,x:y
表示取该维度上[x,y)
之间的数据。
示例1:
x[0, 0, 0], x[0, 0, :], x[0, 0, 0:2]
(tensor(0), tensor([0, 1, 2, 3]), tensor([0, 1]))
其次是在单维度上传入一维tensor数组,就是在对应的维度上依次获取到对应元素即可。
示例2:
x[1, 1, [0,2,1,0]]
tensor([12, 14, 13, 12])
在多个维度上传入一维tensor数组,类似于numpy中的花式索引,对应的tensor数组提供索引关系。
示例3:
x[0, [0,1], [2,3]]
获取到的元素为x[0,0,2]
和x[0,1,3]
,即[0,1]
提供dim 1的索引值,[2,3]
提供dim 2的索引值。
tensor([2, 7])
示例4:
x[0, [0,1,0], [2]]
这种做法会利用到广播机制,实际上的操作会变成x[0, [0,1,0], [2,2,2]]
tensor([2, 6, 2])
但如果写成这样就会报错:
x[0, [0,1,0], [2,1]]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-21-94117f6cd73a> in <module>
----> 1 x[0, [0,1,0], [2,1]]
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
在2的基础上,传入的不止是一维的tensor数组,而是2维甚至多维的。
示例5:
x[0, [[0,1]], [[2,3]]]
在dim1上的tensor数组d1为[[0,1]]
,在dim2上的tensor数组d2为[[2,3]]
,同样进行花式索引构成[x[0,0,2], x[0,1,3]]
。注意输出多了一层维度
去掉dim0来看,获取的元素位置就是[(0,2), (1,3)]
tensor([[2, 7]])
示例6:
x[0, [[0,1],[0,1]], [[0,1],[2,3]]]
在dim1上的tensor数组d1为[[0,1],[0,1]]
,在dim2上的tensor数组d2为[[0,1],[2,3]]
,分别进行花式索引构成[[x[0,0,0], x[0,1,1]],[x[0,0,2], x[0,1,3]]]
去掉dim0来看,获取的元素位置就是[(0,0), (1,1)], [(0,2), (1,3)]
tensor([[0, 5],
[2, 7]])
同样可以利用到类似于示例3的广播机制,如下所示:
x[0, [[0,1],[0,1]], [0,1]], x[0, [[0,1],[0,1]], [0]]
tensor([[0, 5], tensor([[0, 4],
[0, 5]]) [0, 4]])
以及更复杂的示例:
x[0, [[0],[1]], [[0,1],[2,3]]]
等效于x[0, [[0, 0], [1, 1]], [[0,1],[2,3]]]
tensor([[0, 1],
[6, 7]])
在informer的_prob_QK
函数中,有一段代码是为了从
Q
Q
Q中根据索引
M
t
o
p
M_{top}
Mtop得到
Q
r
e
d
u
c
e
Q_{reduce}
Qreduce,其中
Q
Q
Q的维度尺寸为 (B,H,L,E),
M
t
o
p
M_{top}
Mtop为 (B,H,X),得到的
Q
r
e
d
u
c
e
Q_{reduce}
Qreduce为 (B,H,X,E)。
乍一看,其实不是很好去实现这种功能,不能通过普通的Tensor索引去获取到对应元素,最朴素的想法是遍历B*H遍,然后分别获取对应的Q[b, h, x]
,然后再拼接起来。但这样的时间复杂度会达到O(b * h)的级别,怎么使用矩阵机制呢,官方代码给出的实现如下:
Q_reduce = Q[torch.arange(B)[:, None, None],torch.arange(H)[None, :, None],M_top, :]
解释:其中torch.arange(B)[:, None, None]
先生成一个[0,B)之间的一维数组,再扩充成3维,最后的维度为 (B, 1, 1)【后面称Mb】,同理torch.arange(H)[None, :, None]
最后维度为 (1, H, 1)【后面称Mh】,而
M
t
o
p
M_{top}
Mtop为 (B,H,X)。
首先会利用上面描述的广播机制,将Mb和Mh的维度扩充成(B,H,X),值得注意的是Mb是从(B,1,1)扩充的,这意味着只要确定的dim 0
,Mb中的所有值都是一样的,比如Mb[0]
里面就是一个(H,X)维的全0矩阵;同理Mh只要确定了dim 1
,那么剩余的都是一样的值。
利用广播机制后,维度全部扩充为 (B, H, X),再进行花式索引,分别获取得到对应的值。
下面写个示例验证一下:
import torch
Q = torch.arange(16).reshape(2, 2, 4)
M_top = torch.randint(4, (2,2,2))
(tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]]),
tensor([[[3, 0],
[2, 0]],
[[2, 0],
[1, 3]]]))
我们要想从Q【2,2,4】中根据M_top【2,2,2】来获取对应的值,得到一个Q_reduce【2,2,2】,实现如下:
Q[torch.arange(2)[:, None, None], torch.arange(2)[None,:,None], M_top]
tensor([[[ 3, 0],
[ 6, 4]],
[[10, 8],
[13, 15]]])
为了方便理解,这里显示Mb,Mh,M_top经过广播后的结果:
以上均为个人理解和实验推理,如果不对和待补充的地方,还请指正。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。