当前位置:   article > 正文

注意力机制汇总(3)----多头注意力机制_多头注意力机制公式

多头注意力机制公式

注意力机制汇总(3)---- 多头注意力机制(tranformer)

本章节接着上面注意力机制汇总(2),再进一步探索多头注意力机制的原理。点击此处跳转


在上一章我们了解到self-attention的公式,是有Q,K点乘然后除以d k 的1/2次方,再经过softmax后,乘上V得到经过注意力机制之后的输出。本章节将详细介绍由 N个self-attention组成的 多头注意力机制(Multi-Head-Attention),其公式如下所示:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) W O MultiHead(Q,K,V) = Concat(head1, ...,head_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) where head_i = Attention(QW^Q_i, KW^K_i,VW^V_i) whereheadi=Attention(QWiQ,KWiK,VWiV)

在这里插入图片描述

由上面公式和图例可以看出,多头注意力机制是由N(N=8)个self-attention计算完成后,先经过concat拼凑到一起,然后经过WO的矩形完成线性变换,变化成与输入的token维度一致的输出。(WO和WQ,WK,WV矩阵一样,都是在模型训练阶段一同训练出来的权重矩阵

我们用X模拟网络的输入,Z模拟网络的输出,多头注意力机制的流程如下:

  1. 通过WQ和WK矩阵,将输入X线性变换为Qi(Query)Ki(Key);----->此处线性变换=矩阵乘法,i:(1,8)
  2. 然后Qi与KiT做矩阵乘法,再经过softmax(QKT/dk1/2),得到转化后的权重系数,称为Yi
  3. 然后将WIV和输入X经过线性变换得到Vi(Vaule),接着让Vi 和Yi 做矩阵乘法得到Zi 矩阵
  4. 重复上述1-3过程8次。即i=18**。将得到的8个Zi~,concat到 一起得到Zconcat **
  5. 最后将Zconcat WO相乘得到输出Z。----> 此处输出Z可以看作输入X的经过一次Multi-Head-Attention后的变形, 此外WO除了让输出的维度一致之外,主要是具备将随机concat的Zi 还原成特定组合的Z,。

上面的self-Attention, Multi-Head-Attention便是Transformer的灵魂、核心!

此处贴上transformer训练的过程图

  1. 首先通过<起始>预测出The然基于已预测出的2个继续预测出computer
  2. 然后在计算机当中并行计算。
  3. 使用标签当做训练时的输入,来减小训练时产生的误差(此处加入mask:盖住部分区域来模拟真实输入)。

在这里插入图片描述

图例

img

代码解读

1. 词嵌入

import torch

torch.nn.Embedding(num_embeddings, embedding_dim)# 可以实现词嵌入, 
# num_embeddings设置为输入X的词的个数+2, size of the dictionary of embedding
# embedding_dim则是想要将词映射到的维度,the size of each embedding vector
  • 1
  • 2
  • 3
  • 4
  • 5

2. 位置编码

词嵌入之后紧接着就是位置编码,位置编码用以区分不同词以及同词不同特征之间的关系。代码中需要注意:X_只是初始化的矩阵,并不是输入进来的;完成位置编码之后会加一个dropout。另外,位置编码是最后加上去的,因此输入输出形状不变。

def positional_encoding(X, num_features, dropout_p=0.1, max_len=512) -> Tensor:
    r'''
        给输入加入位置编码
    参数:
        - num_features: 输入进来的维度
        - dropout_p: dropout的概率,当其为非零时执行dropout
        - max_len: 句子的最大长度,默认512
    
    形状:
        - 输入: [batch_size, seq_length, num_features]
        - 输出: [batch_size, seq_length, num_features]

    例子:
        >>> X = torch.randn((2,4,10))
        >>> X = positional_encoding(X, 10)
        >>> print(X.shape)
        >>> torch.Size([2, 4, 10])
    '''

    dropout = nn.Dropout(dropout_p)
    P = torch.zeros((1,max_len,num_features))
    X_ = (torch.arange(max_len,dtype=torch.float32).reshape(-1,1) / 
          torch.pow(10000, torch.arange(0,num_features,2,dtype=torch.float32) /num_features))
    P[:,:,0::2] = torch.sin(X_)
    P[:,:,1::2] = torch.cos(X_)
    X = X + P[:,:X.shape[1],:].to(X.device)  # 此处表面位置编码是直接数值相加的。所以输出的type没有变化
    return dropout(X)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

3. self-attention

自注意力机制,在上一篇文章中讨论了很多,具体可以去查看

# 核心代码
	# 计算Q*K的转置,在除上根号dk
	attn_scores = torch.bmm(q, k.transpose(1, 2)) / self.scale
    # 送入softmax进行归一化
	attn_weights = F.softmax(attn_scores, dim=-1)
    # 与V相乘得到新的输出
	attn_output = torch.bmm(attn_weights, v)
	
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

4. Encode编码层(多头此处忽略直接到编码层)

首先经过位置编码,然后经过多头注意力机制,再次期间混杂着short-cut和dropout,接着经过LN归一化与2个Linear全连接层(中间包含一个relu激活函数),在经过short-cut、dropout、LN得到输出结果

def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
        src = positional_encoding(src, src.shape[-1])  # 位置编码
        src2 = self.self_attn(src, src, src, attn_mask=src_mask, 
        key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        # LN
        src = self.norm1(src)
        # 全连接+relu+dropout+全连接
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        # LN
        src = self.norm2(src)
        return src
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

5. Decode解码层

解码层的代码与编码层的类似:多头注意力与全连接层的组合,中间夹杂着一些归一化的方法。

5. Decode解码层

解码层的代码与编码层的类似:多头注意力与全连接层的组合,中间夹杂着一些归一化的方法。
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/780279
推荐阅读
相关标签
  

闽ICP备14008679号