当前位置:   article > 正文

多头注意力和自注意力源码分析-读书笔记_多头注意力机制结构图

多头注意力机制结构图

1. 多头注意力

1.1 结构图

在实践中,当给定相同的查询,键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系。(短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同子空间表示可能是有益的。我们分两步思考

  • h 1 , . . . h h h_1,...h_h h1,...hh个注意力进行学习
  • h 1 , . . . , h h h_1,...,h_h h1,...,hh进行注意力汇聚输出
    在这里插入图片描述

1.2 相关公式

给定查询 q ∈ R d q q\in R^{d_q} qRdq,键 k ∈ R d k k\in R^{d_k} kRdk,和值 v ∈ R d v v\in R^{d_v} vRdv,每个注意力头 h i ( i = 1 , . . . , h ) h_i(i=1,...,h) hi(i=1,...,h)的计算方法为:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v (1) h_i=f(W_i^{(q)}q,W_i^{(k)}k,W_i^{(v)}v)\in R^{p_v}\tag1 hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv(1)
其中可学习的参数包括 W i ( q ) ∈ R p q × d q W_i^{(q)}\in R^{p_q \times d_q} Wi(q)Rpq×dq, W i ( k ) ∈ R p k × d k W_i^{(k)}\in R^{p_k \times d_k} Wi(k)Rpk×dk, W i ( v ) ∈ R p v × d v W_i^{(v)}\in R^{p_v \times d_v} Wi(v)Rpv×dv,以及代表注意力汇聚的函数 f,f可以使加性注意力和缩放点积注意力,多头注意力的输出需要经过另一线性变换,它对应着 h 个头连结后的结果。因此其学习参数是 W o ∈ R p o × h p v W_o\in R^{p_o \times hp_v} WoRpo×hpv
W o [ h 1 ⋮ h h ] ∈ R p o (2) W_o[h1hh]

\in R^{p_o}\tag2 Woh1hhRpo(2)

1.3 源码分析

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: MultiHeadAttention_test
# @Create time: 2022/2/25 9:11

import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
	"""
	作用:将输入的矩阵X按照特征维度进行分割为num_heads个
	"""
	# key_size=100;query_size=100,value_size=100,value_size=100
	# num_hiddens=100;num_head=5,dropout=0.5
	def __init__(self, key_size, query_size, value_size, num_hiddens,
				 num_heads, dropout, bias=False, **kwargs):
		super(MultiHeadAttention, self).__init__(**kwargs)
		self.num_heads = num_heads
		self.attention = d2l.DotProductAttention(dropout)
		# 100->100
		self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
		# 100->100
		self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
		# 100->100
		self.W_v = nn.Linear(value_size, num_hiddens, bias=False)
		# 100->100
		self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=False)

	def forward(self, queries, keys, values, valid_lens):
		# 输入 queies=(2,4,100);keys=(2,6,100);values=(2,6,100)
		# valid_lens=torch.tensor([3,2])
		# 输出 queries=(2,4,100) -> (2,4,5,20) -> (2,5,4,20) -> (10,4,20)
		# 输出 keys=(2,6,100) ->(2,6,5,20) -> (2,5,6,20) -> (10,6,20)
		# 输出 values=(2,6,100) -> (2,6,5,20) -> (2,5,6,20) -> (10,6,20)
		queries = transpose_qkv(self.W_q(queries), self.num_heads)
		keys = transpose_qkv(self.W_k(keys), self.num_heads)
		values = transpose_qkv(self.W_v(values), self.num_heads)

		if valid_lens is not None:
			valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
		# queries=(10,6,20);keys=(10,6,20);values(10,6,20)
		# output=(10,4,20)
		output = self.attention(queries, keys, values, valid_lens)
		# (10,4,20) -> (2,5,4,20) -> (2,4,5,20) -> (2,4,100)=output_concat
		output_concat = transpose_output(output, self.num_heads)
		# return (2,4,100) -> (2,4,100)
		return self.W_o(output_concat)


def transpose_qkv(X, num_heads):
	X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
	X = X.permute(0, 2, 1, 3)
	return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
	X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
	X = X.permute(0, 2, 1, 3)
	return X.reshape(X.shape[0], X.shape[1], -1)


num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
							   num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# x=(2,4,100);y=(2,6,100)
x = torch.ones((batch_size, num_queries, num_hiddens))
y = torch.ones((batch_size, num_kvpairs, num_hiddens))

# attention(x,y,y,valid_lens).shape=(2,4,100)
print(attention(x, y, y, valid_lens).shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
torch.Size([2, 4, 100])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

1.4 小结

  • 为了避免我们用 for loop循环,我们先将 queries,keys,values 按照 num_heads 打散,再进行点积注意力运算,再concat合并,最后输出,这样我们就可以不需要循环了,通过大矩阵的计算来避免循环的使用,提高了计算的效率

2. 自注意力

给定一个由词元组成的输入序列 x 1 , . . . , x n x_1,...,x_n x1,...,xn,其中任意 x i ∈ R d ( 1 ≤ i ≤ n ) x_i\in R^d(1\leq i \leq n) xiRd(1in).该序列的自注意力输出为一个长度相同的序列 y 1 , . . . , y n y_1,...,y_n y1,...,yn,其中
y i = f ( x i , ( x 1 , x 1 ) , . . . , ( x n , x n ) ) ∈ R d y_i=f(x_i,(x_1,x_1),...,(x_n,x_n))\in R^d yi=f(xi,(x1,x1),...,(xn,xn))Rd

2.1源码

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: self_attention_test
# @Create time: 2022/2/27 9:55
import torch
from torch import nn
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attetion = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
								  num_hiddens, num_heads, 0.5)
attetion.eval()
print(attetion)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
x = torch.ones((batch_size,num_queries,num_hiddens))
print(f"attetion(x,x,x,valid_lens).shape={attetion(x,x,x,valid_lens).shape}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

2.2 小结

自注意力机制运用了多头注意力机制,只不过区别在于自注意力机制的 queries,keys,values是相同的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/346690
推荐阅读
相关标签
  

闽ICP备14008679号