赞
踩
本章节接着上面注意力机制汇总(2),再进一步探索多头注意力机制的原理。点击此处跳转
w h e r e h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) where head_i = Attention(QW^Q_i, KW^K_i,VW^V_i) whereheadi=Attention(QWiQ,KWiK,VWiV)
由上面公式和图例可以看出,多头注意力机制是由N(N=8)个self-attention计算完成后,先经过concat拼凑到一起,然后经过WO的矩形完成线性变换,变化成与输入的token维度一致的输出。(WO和WQ,WK,WV矩阵一样,都是在模型训练阶段一同训练出来的权重矩阵)
我们用X模拟网络的输入,Z模拟网络的输出,多头注意力机制的流程如下:
上面的self-Attention, Multi-Head-Attention便是Transformer的灵魂、核心!
此处贴上transformer训练的过程图
import torch
torch.nn.Embedding(num_embeddings, embedding_dim)# 可以实现词嵌入,
# num_embeddings设置为输入X的词的个数+2, size of the dictionary of embedding
# embedding_dim则是想要将词映射到的维度,the size of each embedding vector
词嵌入之后紧接着就是位置编码,位置编码用以区分不同词以及同词不同特征之间的关系。代码中需要注意:X_只是初始化的矩阵,并不是输入进来的;完成位置编码之后会加一个dropout。另外,位置编码是最后加上去的,因此输入输出形状不变。
def positional_encoding(X, num_features, dropout_p=0.1, max_len=512) -> Tensor: r''' 给输入加入位置编码 参数: - num_features: 输入进来的维度 - dropout_p: dropout的概率,当其为非零时执行dropout - max_len: 句子的最大长度,默认512 形状: - 输入: [batch_size, seq_length, num_features] - 输出: [batch_size, seq_length, num_features] 例子: >>> X = torch.randn((2,4,10)) >>> X = positional_encoding(X, 10) >>> print(X.shape) >>> torch.Size([2, 4, 10]) ''' dropout = nn.Dropout(dropout_p) P = torch.zeros((1,max_len,num_features)) X_ = (torch.arange(max_len,dtype=torch.float32).reshape(-1,1) / torch.pow(10000, torch.arange(0,num_features,2,dtype=torch.float32) /num_features)) P[:,:,0::2] = torch.sin(X_) P[:,:,1::2] = torch.cos(X_) X = X + P[:,:X.shape[1],:].to(X.device) # 此处表面位置编码是直接数值相加的。所以输出的type没有变化 return dropout(X)
自注意力机制,在上一篇文章中讨论了很多,具体可以去查看
# 核心代码
# 计算Q*K的转置,在除上根号dk
attn_scores = torch.bmm(q, k.transpose(1, 2)) / self.scale
# 送入softmax进行归一化
attn_weights = F.softmax(attn_scores, dim=-1)
# 与V相乘得到新的输出
attn_output = torch.bmm(attn_weights, v)
首先经过位置编码,然后经过多头注意力机制,再次期间混杂着short-cut和dropout,接着经过LN归一化与2个Linear全连接层(中间包含一个relu激活函数),在经过short-cut、dropout、LN得到输出结果
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
src = positional_encoding(src, src.shape[-1]) # 位置编码
src2 = self.self_attn(src, src, src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
# LN
src = self.norm1(src)
# 全连接+relu+dropout+全连接
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
# LN
src = self.norm2(src)
return src
解码层的代码与编码层的类似:多头注意力与全连接层的组合,中间夹杂着一些归一化的方法。
解码层的代码与编码层的类似:多头注意力与全连接层的组合,中间夹杂着一些归一化的方法。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。