Self-attention can be viewed as a feature-extraction layer: given input features $a^{1}, a^{2}, \cdots, a^{n}$, the self-attention layer fuses all of the inputs and produces new features $b^{1}, b^{2}, \cdots, b^{n}$. Concretely:
Let the input features be $I$. Multiplying $I$ by three matrices $W^{q}$, $W^{k}$ and $W^{v}$ gives the three matrices $Q$ (query), $K$ (key) and $V$ (value). Next, the product of $Q$ and $K$ yields the attention matrix $A$, which is normalized to obtain $\hat{A}$. Finally, multiplying the normalized attention matrix $\hat{A}$ by $V$ gives the output features $O$.
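Written out as formulas (a consolidated restatement of the steps above; the softmax normalization and the transpose on $K$ follow the standard formulation and match the code further down):

$Q = IW^{q}, \quad K = IW^{k}, \quad V = IW^{v}$

$A = QK^{\top}, \quad \hat{A} = \mathrm{softmax}(A), \quad O = \hat{A}V$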
In the self-attention described above, each input feature $a^{i}$ is multiplied by the matrices $W^{q}$, $W^{k}$ and $W^{v}$ to produce one vector each: $q^{i}$, $k^{i}$ and $v^{i}$. This is single-head self-attention. If each of the vectors $q^{i}$, $k^{i}$ and $v^{i}$ is instead split into $n$ parts, we obtain $n$-head self-attention. Multi-head self-attention is generally agreed to work better than single-head, because it can capture information along more dimensions. The per-head computation can be written as follows:
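(A sketch of the standard multi-head formulation; $d_k$ is the per-head dimension, called attention_head_size below, and $h$ is the number of heads. The scaling by $\sqrt{d_k}$ matches the division by $\sqrt{\text{attention\_head\_size}}$ in the code.)

$\mathrm{head}_i = \mathrm{softmax}\!\left(\dfrac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i, \qquad O = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)$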
Let the hyperparameter num_attention_heads be the number of attention heads; from it, the dimension of each head, attention_head_size, is computed.
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = hidden_size
Define the three matrices $W^{q}$, $W^{k}$ and $W^{v}$.
self.query = nn.Linear(input_size, self.all_head_size)
self.key = nn.Linear(input_size, self.all_head_size)
self.value = nn.Linear(input_size, self.all_head_size)
The computation now proceeds step by step; pay attention to how the tensor dimensions change along the way.
Multiply the input features by the three matrices $W^{q}$, $W^{k}$ and $W^{v}$. At this point the output tensors have not yet been split into multiple heads. The dimensions change from input_tensor (batch, n, input_size) to mixed_query_layer (batch, n, all_head_size).
mixed_query_layer = self.query(input_tensor)
mixed_key_layer = self.key(input_tensor)
mixed_value_layer = self.value(input_tensor)
Split into num_attention_heads heads and permute the dimensions. The dimensions change from mixed_query_layer (batch, n, all_head_size) to query_layer (batch, num_attention_heads, n, attention_head_size).
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
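As a quick sanity check of this reshape (a minimal standalone sketch; the sizes below are illustrative and not taken from the original text):

import torch

batch, n, num_heads, head_size = 2, 10, 4, 16       # illustrative values
x = torch.randn(batch, n, num_heads * head_size)     # (batch, n, all_head_size)
x = x.view(batch, n, num_heads, head_size)           # split the last dimension into heads
x = x.permute(0, 2, 1, 3)                            # -> (batch, num_heads, n, head_size)
print(x.shape)                                       # torch.Size([2, 4, 10, 16])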
Multiply $Q$ by $K$ to obtain the attention matrix, and divide by the square root of the vector dimension so that the attention scores do not grow with the dimension. The dimensions change from query_layer (batch, num_attention_heads, n, attention_head_size) to attention_scores (batch, num_attention_heads, n, n).
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
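Before the softmax, a padding mask is often added to the scores (the complete implementation below keeps this step as a commented-out line). A minimal sketch, where valid is a hypothetical (batch, n) boolean tensor marking real tokens:

attention_mask = (1.0 - valid.float()) * -10000.0    # 0.0 for tokens to keep, -10000.0 for padding
attention_mask = attention_mask[:, None, None, :]     # (batch, 1, 1, n), broadcasts over heads and queries
attention_scores = attention_scores + attention_mask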
Normalize the attention matrix. The dimensions change from attention_scores (batch, num_attention_heads, n, n) to attention_probs (batch, num_attention_heads, n, n).
attention_probs = nn.Softmax(dim=-1)(attention_scores)
Multiply the attention matrix by $V$. The dimensions change from attention_probs (batch, num_attention_heads, n, n) times value_layer (batch, num_attention_heads, n, attention_head_size) to context_layer (batch, num_attention_heads, n, attention_head_size).
context_layer = torch.matmul(attention_probs, value_layer)
Permute the dimensions of context_layer so that the results from all heads can be concatenated afterwards. The contiguous() call makes the tensor's memory contiguous, which is required for the subsequent view(). The dimensions change from context_layer (batch, num_attention_heads, n, attention_head_size) to context_layer (batch, n, num_attention_heads, attention_head_size).
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
Concatenate the results from all heads. The dimensions change from context_layer (batch, n, num_attention_heads, attention_head_size) to context_layer (batch, n, all_head_size).
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
Putting all of the above together, the complete LayerNorm and SelfAttention modules are as follows (note that the residual connection hidden_states + input_tensor requires input_size to equal hidden_size):

import math

import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size, hidden_size,
                 hidden_dropout_prob, attention_probs_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = hidden_size

        self.query = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.all_head_size)

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        # After self-attention, apply a feed-forward linear layer, dropout and LayerNorm.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, n, all_head_size) -> (batch, num_attention_heads, n, attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_tensor):
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # An attention mask (precomputed for all layers in BertModel's forward()) could be
        # added here; scores are [batch_size, heads, seq_len, seq_len], the mask [batch_size, 1, 1, seq_len].
        # attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states
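A minimal usage sketch (the configuration values below are illustrative assumptions, not taken from the original text):

attn = SelfAttention(num_attention_heads=4,
                     input_size=64,
                     hidden_size=64,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1)

x = torch.randn(2, 10, 64)    # (batch, n, input_size)
out = attn(x)
print(out.shape)              # torch.Size([2, 10, 64])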