赞
踩
近期抽空重整了一遍Transformer(论文下载)。
距离Transformer提出差不多有四年了,也算是一个老生常谈的话题,关于Transformer的讲解有相当多的线上资源可以参考,再不济详读一遍论文也能大致掌握,但是如果现在要求从零开始写出一个Transformer,可能这并不是很轻松的事情。笔者虽然之前也已经数次应用,但是主要还是基于Tensorflow和keras框架编写,然而现在Tensorflow有些问题,这将在本文的第三部分Tensorflow 实现与问题中详细说明。考虑到之后可能还是主要会在PyTorch的框架下进行开发,正好趁过渡期空闲可以花时间用PyTorch实现一个Transformer的小demo,一方面是熟悉PyTorch的开发,另一方面也是加深对Transformer的理解,毕竟将来大约是会经常需要使用,并且在其基础上进行改良的。
事实上很多事情都是如此,看起来容易,做起来就会发现有很多问题,本文在第一部分Transformer模型详解及注意点中将记录笔者在本次Transformer实现中做的一些值得注意的点;第二部分将展示PyTorch中Transformer模型的实现代码,以及如何使用该模型完成一个简单的seq2seq预测任务;第三部分同样会给出Tensorflow中Transformer模型的实现代码,以及目前Tensorflow的一些问题。
本文不再赘述Transformer的原理,这个已经有很多其他文章进行了详细说明,因此需要一些前置的了解知识,可以通过上面的论文下载 链接阅读原文。
上图是Transformer的结构图解, 当中大致包含如下几个元素:
- Position Encoding: 位置编码;
- Position-wise Feed-Forward Networks: 即图中的Feed Forward模块, 这个其实是一个非常简单的模块, 简单实现就是一个只包含一个隐层的神经网络;
- Multihead Attention 与 Scaled Dot-Product Attention: 注意力机制;
- Encoder 与 Decoder: 编码器与解码器(核心部件);
- 关于上述部件在本文3.1节中的transformer.py代码中都有相应的类与其对应, 并且笔者已经做了非常详细的注释(英文), 以下主要就实现上的细节做说明, 可结合本文3.1节中的transformer.py代码一起理解;
Position-wise Feed-Forward Networks正如上述是一个非常简单的三层神经网络: F F N ( x ) = max ( 0 , x W 1 + b 1 ) W 2 + b 2 {\rm FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2隐层中使用的是ReLU激活函数(即 max ( 0 , x ) \max(0,x) max(0,x)), 但是要确保的是该模块的输入与输出的维度是完全相同的;
关于Position Encoding的理解:
class BertEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) # (vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # (512, hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) # (2, hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
q
: 即查询矩阵
Q
Q
Q, 形状为(batch_size, n_head, len_q, d_q)
;k
: 即键矩阵
K
K
K, 形状为(batch_size, n_head, len_k, d_k)
;v
: 即值矩阵
V
V
V, 形状为(batch_size, n_head, len_v, d_v)
;mask
: 掩码矩阵, 形状应当为(batch_size, n_head, len_q, len_k)
, 不过其实只要(len_q, len_k)
即可, 因为另外两个维度可以直接用unsqueeze
或extend
来复制扩充;d_q
与d_k
的大小必须相等;len_k
与len_v
大小必须相等;d_q = d_k = d_v = d_model / n_head = 8
, 并且d_model = 512
, n_head = 8
;d_input
, d_output
;q
: 即查询矩阵
Q
Q
Q, 形状为(batch_size, len_q, d_q)
;k
: 即键矩阵
Q
Q
Q, 形状为(batch_size, len_k, d_k)
;v
: 即值矩阵
Q
Q
Q, 形状为(batch_size, len_v, d_v)
;mask
: 掩码矩阵, 同scaled dot-product attention中的描述;len_k
, len_v
应当被padding到等长, 即长度应为scaled dot-product attention中的len_k = len_v
;d_input
等于
d
m
o
d
e
l
d_{\rm model}
dmodeld_output
等于d_q
(以及d_k
和d_v
, 正如scaled dot-product attention中所提, 它们是相等的);d_input = d_output
, 这样可以便于进行残差连接的计算(因为残差连接需要将输入加到输出上再归一化后得到最终输出);Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。