First, look at the docstring that describes relative_position as a whole:
The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid.
We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions.
All relative positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on.
A general explanation of the relative position encoding
The idea behind this design is quite intuitive: nearby positions (0~7) should be distinguished precisely, so each gets its own bucket; slightly farther positions (e.g. 8~11) do not need to be told apart as finely, so they can share one bucket; the farther away, the larger the shared range, until max_distance is reached and everything beyond it is clipped into the last bucket. Like NEZHA's relative position encoding, it encodes the relative offset i - j; a rough illustration of the bucketing rule is sketched right below.
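To make the "fine-grained nearby, coarse far away" idea concrete, here is a minimal standalone sketch of the bucketing rule for one direction (a simplified re-implementation written only for illustration, assuming the values seen below: num_buckets halved to 16, max_exact = 8, max_distance = 128; it is not the library code itself):

import math

def bucket_for_distance(d, num_buckets=16, max_exact=8, max_distance=128):
    """Map a non-negative distance d = |i - j| to a bucket id (one direction only)."""
    if d < max_exact:          # nearby positions: one bucket per distance
        return d
    # farther positions: logarithmically spaced buckets, clipped at num_buckets - 1
    val = max_exact + int(
        math.log(d / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return min(val, num_buckets - 1)

for d in [0, 3, 7, 8, 11, 12, 20, 50, 300]:
    print(d, bucket_for_distance(d))
# distances 0..7 each get their own bucket, 8..11 share bucket 8,
# 12 starts bucket 9, and anything >= max_distance is clipped to bucket 15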
First, look at the values of the initialization parameters:
self.has_relative_attention_bias = False
self.relative_attention_num_buckets = 32
self.d_model = 512
self.key_value_proj_dim = 64
self.inner_dim = 512
Then enter the forward call:
batch_size,seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
which gives
batch_size = 1, seq_length = 15, real_seq_length = 15
Next come the three linear (dense) projections:
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
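The shape(...) mentioned in the comment is just the reshape that splits d_model into heads. A minimal sketch of what that reshape does (a hypothetical standalone version for illustration, assuming n_heads = 8 and key_value_proj_dim = 64 from the parameters above; the real helper lives inside the attention module):

import torch

batch_size, seq_length, d_model = 1, 15, 512
n_heads, key_value_proj_dim = 8, 64          # 8 * 64 = 512 = inner_dim

hidden_states = torch.randn(batch_size, seq_length, d_model)
q = torch.nn.Linear(d_model, n_heads * key_value_proj_dim, bias=False)

def shape(states):
    # (batch, seq, d_model) -> (batch, n_heads, seq, dim_per_head)
    return states.view(batch_size, -1, n_heads, key_value_proj_dim).transpose(1, 2)

query_states = shape(q(hidden_states))
print(query_states.shape)   # torch.Size([1, 8, 15, 64])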
The corresponding shapes are
query_states =
torch.Size([1, 8, 15, 64])
......
What follows is essentially the same as in the NEZHA model: there is a position_bias term, and the attention weights are softmax(Q*K + bias)...
The self-attention formula in T5 differs slightly from the usual one. In T5 it is

softmax(Q K^{T} + position_bias) * V

Note that, unlike standard attention, there is no division by sqrt(d_k) here.
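Putting the formula directly into code, a minimal PyTorch sketch (illustrative only, with random tensors standing in for the projected states; not the library implementation):

import torch
import torch.nn.functional as F

batch, n_heads, seq_len, d_head = 1, 8, 15, 64
Q = torch.randn(batch, n_heads, seq_len, d_head)
K = torch.randn(batch, n_heads, seq_len, d_head)
V = torch.randn(batch, n_heads, seq_len, d_head)
position_bias = torch.randn(1, n_heads, seq_len, seq_len)    # broadcast over the batch

scores = torch.matmul(Q, K.transpose(3, 2)) + position_bias  # note: no / sqrt(d_head)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)                       # (1, 8, 15, 64)
print(output.shape)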
Next, compute the scores:
scores = torch.matmul(
query_states,key_states.transpose(3,2)
)
The resulting scores tensor has shape
scores.shape = (1, 8, 15, 15)
Next, compute the position offset term position_bias:
position_bias = self.compute_bias(real_seq_length,key_length)
Step into compute_bias and look at how it works.
(It feels like T5's compute_bias is somewhat similar to NEZHA's compute_bias???)
Note: T5's relative position encoding is still not the same as NEZHA's!
def compute_bias(self, query_length, key_length):
    """Compute binned relative position bias"""
    # query_length = 15, key_length = 15
    context_position = torch.arange(
        query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[:, None]
    memory_position = torch.arange(
        key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
    )[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not self.is_decoder),
        num_buckets=self.relative_attention_num_buckets,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values
The function first builds context_position and memory_position:
context_position = torch.arange(
query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[:, None]
memory_position = torch.arange(
key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]
which gives
context_position = tensor([[ 0], [ 1], [ 2], [ 3], [ 4], [ 5], [ 6], [ 7], [ 8], [ 9], [10], [11], [12], [13], [14]])
memory_position =
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])
Next, compute relative_position:
relative_position = memory_position-context_position
which gives the relative position matrix
relative_position = tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [ -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [ -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [ -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [ -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [ -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [ -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8], [ -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7], [ -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6], [ -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4], [-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3], [-12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2], [-13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1], [-14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0]])
Then compute the relative_position_bucket tensor:
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
Step into _relative_position_bucket. The first part of the function is

relative_buckets = 0
if bidirectional:
    num_buckets //= 2
    relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
    relative_position = torch.abs(relative_position)
else:
    relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
Here bidirectional means the token can attend both to its left and to its right, so negative and positive relative positions are bucketed symmetrically by distance, but the positive (right-hand) side is shifted into a separate half of the buckets by adding num_buckets (for example, a token 3 positions to the left falls in bucket 3, while a token 3 positions to the right falls in bucket 16 + 3 = 19, as the final bucket matrix below confirms).
In the bidirectional branch, num_buckets starts out as 32; after
num_buckets //= 2
it becomes
num_buckets = 16
Next, the entries of relative_position that are strictly positive each contribute num_buckets:
relative_buckets += (relative_position > 0).to(torch.long)*num_buckets
which gives
relative_buckets = tensor([[ 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
The next line keeps operating on relative_position itself; for now it has nothing to do with the relative_buckets computed above:
relative_position = torch.abs(relative_position)
which gives
relative_position = tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [ 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [ 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [ 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [ 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [ 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [ 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8], [ 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6], [ 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5], [10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4], [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3], [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2], [13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1], [14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]])
Then is_small is defined:
is_small = relative_position < max_exact
The resulting is_small is a boolean mask with a banded structure: it is True in the band around the diagonal where relative_position < max_exact (= 8), and False everywhere else.
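A quick illustrative check of that band structure (a small standalone snippet, not the library code):

import torch

max_exact = 8
# the |i - j| distance matrix for a sequence of length 15, as built above
relative_position = torch.abs(torch.arange(15)[None, :] - torch.arange(15)[:, None])
is_small = relative_position < max_exact

print(is_small[0])   # first row: True for distances 0..7, False from distance 8 onwards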
Next comes the slightly more involved part of the computation (the core of T5's relative position bucketing):
relative_postion_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
Let's unpack this formula piece by piece.
First, max_exact = 8, so replacing max_exact with 8 everywhere gives
relative_position_if_large = 8 + log(relative_position / 8) / log(max_distance / 8) * (num_buckets - 8)
Here relative_position is the |i - j| distance matrix shown above, and the maximum distance is
max_distance = 128
So, with the concrete numbers filled in, the expression (before the cast to long) is

8 + log([[0, 1, 2, ... 14], [1, 0, 1, ... 13], ..., [14, 13, ... 1, 0]] / 8) / log(128 / 8) * (16 - 8)

applied elementwise to the distance matrix.
Since log(128 / 8) = log(16) ≈ 2.7726 and 16 - 8 = 8, this simplifies (elementwise, before the cast to long) to approximately

8 + 2.885 * log([[0, 1, 2, ... 14], [1, 0, 1, ... 13], ..., [14, 13, ... 1, 0]] / 8)
First, look at the raw elementwise log of the distance matrix:
torch.log = tensor([[ -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849, 2.5649, 2.6391], [0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849, 2.5649], [0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979, 2.4849], [1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026, 2.3979], [1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972, 2.3026], [1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794, 2.1972], [1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459, 2.0794], [1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918, 1.9459], [2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094, 1.7918], [2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863, 1.6094], [2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986, 1.3863], [2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931, 1.0986], [2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000, 0.6931], [2.5649, 2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf, 0.0000], [2.6391, 2.5649, 2.4849, 2.3979, 2.3026, 2.1972, 2.0794, 1.9459, 1.7918, 1.6094, 1.3863, 1.0986, 0.6931, 0.0000, -inf]])
Next, look at the content of
torch.log(relative_position.float()/max_exact)
torch.log = tensor([[ -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106, 0.3206, 0.3299], [0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106, 0.3206], [0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997, 0.3106], [0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878, 0.2997], [0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747, 0.2878], [0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599, 0.2747], [0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432, 0.2599], [0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240, 0.2432], [0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012, 0.2240], [0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733, 0.2012], [0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373, 0.1733], [0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866, 0.1373], [0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000, 0.0866], [0.3206, 0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf, 0.0000], [0.3299, 0.3206, 0.3106, 0.2997, 0.2878, 0.2747, 0.2599, 0.2432, 0.2240, 0.2012, 0.1733, 0.1373, 0.0866, 0.0000, -inf]])
Then look at the divisor math.log(max_distance / max_exact):
math.log(max_distance / max_exact)
Here max_distance = 128 and max_exact = 8, so the result is
math.log(128 / 8) = math.log(16) ≈ 2.7726
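As a sanity check, plug one concrete entry through the whole formula, say the distance relative_position = 14 from the top-right corner of the matrix:

import math

max_exact, num_buckets, max_distance = 8, 16, 128
rp = 14
val_if_large = max_exact + int(
    math.log(rp / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
)
print(val_if_large)   # 9, which matches entry [0][14] of the torch.where matrix further below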
The last part clips relative_position_if_large:
relative_position_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)
A question??
I kept wondering why printing relative_position_if_large at this point always raised an error:
relative_position_bucket = self._relative_position_bucket(
410 relative_position, # shape (query_length, key_length)
411 bidirectional=(not self.is_decoder),
~/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py in _relative_position_bucket(relative_position, bidirectional, num_buckets, max_distance)
393 )
394 print('relative_position_if_large = ')
--> 395 print(relative_position_if_large)
396 relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
Let's set aside why it cannot be printed for now (it later turned out I had simply written the variable name differently from the source, which spells it relative_postion_if_large); instead, compute the same expression into a new variable and print that:
current_data = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
)
Looking at the output of this expression, you can see that for nearby positions (values around 1 or 2) the result changes quickly, while for farther positions (e.g. around the value 8) the same value persists over a long stretch and changes slowly.
Here num_buckets = 16 and max_exact = 8, so num_buckets - max_exact = 16 - 8 = 8.
The next step is the clipping call:
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
First, look at the torch.full_like call:
torch.full_like(relative_position_if_large, num_buckets - 1)
which gives
tensor([[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15],
[15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15]])
Then comes the torch.min call:
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
This operation changes nothing here: relative_position_if_large keeps the values computed above, because every entry is already smaller than 15, so the min is effectively a no-op.
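For longer sequences, though, the clip does matter: once a distance reaches max_distance, the log formula would overflow the bucket range and gets clamped into the last bucket. A small sketch with a hypothetical distance of 200 (beyond max_distance = 128):

import math

max_exact, num_buckets, max_distance = 8, 16, 128
d = 200   # hypothetical relative distance, longer than max_distance
val_if_large = max_exact + int(
    math.log(d / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
)
print(val_if_large)                        # 17, which would overflow the 16 buckets
print(min(val_if_large, num_buckets - 1))  # clipped to 15, the shared "far away" bucket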
Finally, the last operation:
relative_buckets += torch.where(is_small,relative_position,relative_position_if_large)
As a reminder, relative_buckets currently holds the result of the earlier update
relative_buckets += (relative_position > 0).to(torch.long)*num_buckets
namely:
relative_buckets = tensor([[ 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
Now evaluate the torch.where expression:
torch.where(is_small,relative_position,relative_position_if_large)
which gives
torch.where = tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9, 9], [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9, 9], [2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 9], [3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8], [4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8], [5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8], [6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7, 8], [7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7], [8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6], [8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5], [8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4], [8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3], [9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1, 2], [9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 1], [9, 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0]])
To unpack this: is_small marks the small distances (is_small = relative_position < max_exact), so wherever relative_position < max_exact the value of relative_position itself is kept, and everywhere else the value from relative_position_if_large is used, which produces exactly the torch.where matrix above.
Finally, add this to relative_buckets:
relative_buckets += torch.where(is_small,relative_position,relative_position_if_large)
giving the final bucket matrix
relative_buckets = tensor([[ 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25, 25], [ 1, 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25, 25], [ 2, 1, 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24, 25], [ 3, 2, 1, 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24, 24], [ 4, 3, 2, 1, 0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24], [ 5, 4, 3, 2, 1, 0, 17, 18, 19, 20, 21, 22, 23, 24, 24], [ 6, 5, 4, 3, 2, 1, 0, 17, 18, 19, 20, 21, 22, 23, 24], [ 7, 6, 5, 4, 3, 2, 1, 0, 17, 18, 19, 20, 21, 22, 23], [ 8, 7, 6, 5, 4, 3, 2, 1, 0, 17, 18, 19, 20, 21, 22], [ 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 17, 18, 19, 20, 21], [ 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 17, 18, 19, 20], [ 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 17, 18, 19], [ 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 17, 18], [ 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0, 17], [ 9, 9, 9, 8, 8, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0]])
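To complete the picture: compute_bias then feeds these bucket ids through self.relative_attention_bias (an nn.Embedding over the buckets, one value per head), permutes the result to (1, num_heads, query_length, key_length), and that tensor is the position_bias added to scores before the softmax. A minimal sketch of this last step, with shapes taken from this example (the bucket matrix is filled with placeholder random ids here, not the real values above):

import torch

num_buckets, n_heads, q_len, k_len = 32, 8, 15, 15
relative_attention_bias = torch.nn.Embedding(num_buckets, n_heads)

# stand-in for the (15, 15) bucket-id matrix computed above
relative_position_bucket = torch.randint(0, num_buckets, (q_len, k_len))

values = relative_attention_bias(relative_position_bucket)   # (15, 15, 8)
position_bias = values.permute([2, 0, 1]).unsqueeze(0)       # (1, 8, 15, 15)

scores = torch.randn(1, n_heads, q_len, k_len)
scores = scores + position_bias   # then softmax(scores) @ V, as in the formula above
print(position_bias.shape)

For comparison, the following is a Keras-side implementation of the same bucketing scheme (the RelativePositionEmbeddingT5 layer):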
class RelativePositionEmbeddingT5(RelativePositionEmbedding):
    """Google T5's relative position encoding
    From the paper: https://arxiv.org/abs/1910.10683
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        max_distance=128,
        bidirectional=True,
        embeddings_initializer='zeros',
        **kwargs
    ):
        super(RelativePositionEmbeddingT5, self).__init__(input_dim, output_dim, **kwargs)
        self.max_distance = max_distance
        self.bidirectional = bidirectional

    def compute_position_ids(self, inputs):
        """T5's relative position bucketing (translated directly from the official T5 source)"""
        q, v = inputs
        # compute the position differences
        q_idxs = K.arange(0, K.shape(q)[1], dtype='int32')
        q_idxs = K.expand_dims(q_idxs, 1)
        v_idxs = K.arange(0, K.shape(v)[1], dtype='int32')
        v_idxs = K.expand_dims(v_idxs, 0)
        pos_ids = v_idxs - q_idxs
        # post-processing
        num_buckets, max_distance = self.input_dim, self.max_distance
        ret = 0
        n = -pos_ids
        if self.bidirectional:
            num_buckets //= 2
            ret += K.cast(K.less(n, 0), 'int32') * num_buckets
            n = K.abs(n)
        else:
            n = K.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = K.less(n, max_exact)
        val_if_large = max_exact + K.cast(
            K.log(K.cast(n, K.floatx()) / max_exact)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact),
            'int32',
        )
        val_if_large = K.minimum(val_if_large, num_buckets - 1)
        ret += K.switch(is_small, n, val_if_large)
        return ret