In practice, given the same set of queries, keys, and values, we would like the model to learn different behaviors from the same attention mechanism and then combine these behaviors as knowledge, capturing dependencies over various ranges within a sequence (both short-range and long-range dependencies). It may therefore be beneficial to allow the attention mechanism to jointly use different subspace representations of the queries, keys, and values. We reason about this in two steps.
Given a query $q \in R^{d_q}$, a key $k \in R^{d_k}$, and a value $v \in R^{d_v}$, each attention head $h_i$ ($i = 1, \ldots, h$) is computed as
$$h_i = f(W_i^{(q)} q, W_i^{(k)} k, W_i^{(v)} v) \in R^{p_v} \tag{1}$$
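To make formula (1) concrete before the full implementation below, here is a minimal sketch of a single head, assuming $f$ is scaled dot-product attention and using hypothetical sizes $d_q = d_k = d_v = 100$ and $p_q = p_k = p_v = 20$ (these numbers are illustrative, chosen to match the example later in this post, not part of the formula itself):

import torch
from torch import nn

# Hypothetical sizes: d_q = d_k = d_v = 100, p_q = p_k = p_v = 20
d, p = 100, 20
W_q_i = nn.Linear(d, p, bias=False)  # plays the role of W_i^{(q)}
W_k_i = nn.Linear(d, p, bias=False)  # plays the role of W_i^{(k)}
W_v_i = nn.Linear(d, p, bias=False)  # plays the role of W_i^{(v)}

q = torch.randn(1, 4, d)  # 4 queries
k = torch.randn(1, 6, d)  # 6 keys
v = torch.randn(1, 6, d)  # 6 values

# f as scaled dot-product attention: softmax(Q K^T / sqrt(p)) V
Q, K, V = W_q_i(q), W_k_i(k), W_v_i(v)
scores = Q @ K.transpose(1, 2) / (p ** 0.5)  # (1, 4, 6)
h_i = torch.softmax(scores, dim=-1) @ V      # (1, 4, 20): each output lies in R^{p_v}
print(h_i.shape)                             # torch.Size([1, 4, 20])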
where the learnable parameters are $W_i^{(q)} \in R^{p_q \times d_q}$, $W_i^{(k)} \in R^{p_k \times d_k}$, and $W_i^{(v)} \in R^{p_v \times d_v}$, and $f$ is the attention pooling function, which can be either additive attention or scaled dot-product attention. The output of multi-head attention is then passed through another linear transformation, applied to the concatenation of the $h$ heads; its learnable parameter is $W_o \in R^{p_o \times h p_v}$:
$$W_o \begin{bmatrix} h_1 \\ \vdots \\ h_h \end{bmatrix} \in R^{p_o} \tag{2}$$
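As a worked shape example for (2), using the sizes assumed by the code below (num_hiddens = 100, num_heads = 5, so each head has $p_v = 100/5 = 20$ and the output size is $p_o = 100$):

$$
\underbrace{W_o}_{100 \times 100}
\underbrace{\begin{bmatrix} h_1 \\ \vdots \\ h_5 \end{bmatrix}}_{5 \cdot 20 = 100}
\in R^{100}
$$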
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: MultiHeadAttention_test
# @Create time: 2022/2/25 9:11
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    """Split the input tensor X along the feature dimension into num_heads heads."""

    # key_size=100; query_size=100; value_size=100
    # num_hiddens=100; num_heads=5; dropout=0.5
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # 100 -> 100
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Inputs: queries=(2,4,100); keys=(2,6,100); values=(2,6,100)
        # valid_lens=torch.tensor([3,2])
        # queries: (2,4,100) -> (2,4,5,20) -> (2,5,4,20) -> (10,4,20)
        # keys:    (2,6,100) -> (2,6,5,20) -> (2,5,6,20) -> (10,6,20)
        # values:  (2,6,100) -> (2,6,5,20) -> (2,5,6,20) -> (10,6,20)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat each valid length num_heads times, once per head
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads, dim=0)
        # queries=(10,4,20); keys=(10,6,20); values=(10,6,20)
        # output=(10,4,20)
        output = self.attention(queries, keys, values, valid_lens)
        # (10,4,20) -> (2,5,4,20) -> (2,4,5,20) -> (2,4,100) = output_concat
        output_concat = transpose_output(output, self.num_heads)
        # (2,4,100) -> (2,4,100)
        return self.W_o(output_concat)


def transpose_qkv(X, num_heads):
    """Reshape (batch, seq, num_hiddens) into (batch*num_heads, seq, num_hiddens/num_heads)."""
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse the transformation of transpose_qkv."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# x=(2,4,100); y=(2,6,100)
x = torch.ones((batch_size, num_queries, num_hiddens))
y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# attention(x, y, y, valid_lens).shape = (2, 4, 100)
print(attention(x, y, y, valid_lens).shape)
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
torch.Size([2, 4, 100])
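As a sanity check on the head-splitting bookkeeping, the following short sketch (an addition, assuming torch and the transpose_qkv/transpose_output functions defined above are in scope) verifies that transpose_output undoes transpose_qkv:

X = torch.randn(2, 4, 100)                # (batch, seq, num_hiddens)
X_split = transpose_qkv(X, num_heads=5)   # (2*5, 4, 100/5) = (10, 4, 20)
X_back = transpose_output(X_split, num_heads=5)
print(X_split.shape)                      # torch.Size([10, 4, 20])
print(torch.equal(X, X_back))             # True: the two transforms are inverses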
Given an input sequence of tokens $x_1, \ldots, x_n$, where each $x_i \in R^d$ ($1 \leq i \leq n$), the self-attention of this sequence outputs a sequence of the same length, $y_1, \ldots, y_n$, where
$$y_i = f(x_i, (x_1, x_1), \ldots, (x_n, x_n)) \in R^d$$
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: self_attention_test
# @Create time: 2022/2/27 9:55
import torch
from torch import nn
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                    num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
x = torch.ones((batch_size, num_queries, num_hiddens))
# Self-attention: queries, keys, and values are all the same tensor x
print(f"attention(x, x, x, valid_lens).shape={attention(x, x, x, valid_lens).shape}")
Self-attention reuses the multi-head attention mechanism; the only difference is that in self-attention the queries, keys, and values are all the same.
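To make the contrast explicit, the following sketch (an illustrative addition reusing the shapes from the examples above) calls the same multi-head attention module once in cross-attention style and once in self-attention style; only the arguments differ:

import torch
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

valid_lens = torch.tensor([3, 2])
x = torch.ones((2, 4, num_hiddens))  # query sequence
y = torch.ones((2, 6, num_hiddens))  # a separate key/value sequence

# Cross-attention: queries come from x, keys and values from y
print(attention(x, y, y, valid_lens).shape)  # torch.Size([2, 4, 100])
# Self-attention: queries, keys, and values are all x
print(attention(x, x, x, valid_lens).shape)  # torch.Size([2, 4, 100])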