赞
踩
本人也是在学习过程中,以下是经过学习对transformer模型的一些理解和一些常见问题的解释,如有错误之处,还请批评指正
本文主要根据数据的输入、计算和输出进行阐述
图1 transformer架构
输入
编码器:输入文本序列并引入到嵌入层进行编码,并于位置信息进行相加(维度必须一致)
解码器:
分为训练和预测
训练时,将encoder的值,target,位置信息作为输入(teacher-forcing)
预测时,encoder的值(kv),i – 1的decoder输出作为(query)
什么是teacher-forcing,以及为什么要使用:
但teacher-forcing也存在一定的问题,使得模型在预测时效果较差,为此可以引入概率p,使得模型以概率p决定是否使用teacher-forcing
encoder层
由于计算机无法直接理解我们的输入(文本序列、图像等),因此我们需要将其进行编码,one-hot便是一种常见的编码方式。
举个例子:
对于一串文字,我们可以应用one-hot进行编码:小普喜欢星海湾的朋友
经过one-hot编码,变成
这样的表达很容易计算,但太过于冗余,因此,嵌入层(embeding)的作用就是将其降维,举个例子:
可以看到,矩阵经过embeding后,2*6的大小得到了压缩。
接着引入了位置编码,它的功能主要解决了self-attention不能关注位置信息的缺陷,在这里我们通过直接将位置信息与嵌入层进行相加,从而得到了含有位置信息的特征,至于为什么是相加而不是连接,我查阅资料发现,连接虽然保留了位置信息和原始信息的独立性,但最终的结果并没有比直接相加效果好多少,反而增加了计算量,因此直接进行相加操作。
而位置编码主要使用如下计算方式:
p代表位置矩阵的某个元素,可以这样理解:
对于输入(n, d)的文本序列,其中n是词元数量,d是文本的嵌入层后的维度,位置编码使用相同形状的位置嵌入矩阵,矩阵第i行、第2j列和2j+1列上的元素为pi,2j(2j+1)
接着是一个多头注意力层,有关注意力的概念,这里不做详细解释,网上有很多优秀的文章,主要工作如下:
引入了query,key,value的概念,对于每个输入X,他的attention计算如下:
多头注意力中的多头,体现在我们可以使用多个Q,K,V,他们分别由不同的权重矩阵Wq,Wk,Wv得到,从而实现不同的注意力机制
其中Q = Wq * embedding后的向量,以此类推
接着是一个加&规范化
首先是加,它相当于是一个resnet模型,对原始嵌入层后的数据与处理后的数据进行相加,从而保留了原始的数据语义,有关restnet的机制,大家可以参考其他大佬的文章
引入resnet的原因:
防止出现网络退化导致性能下降
通过h(x) = f(x) + x,只需让网络学习到f(x)等于0即可
规范化,就是对特征进行正态分布标准化,这里transformer使用的是layer_norm,因为对于不同的批量语句来说,它的信息更多包含在当前的序列当中,因此使用layer_norm效果更好。
batch_norm和layer_norm的区别:
Batch_norm和layer_norm
Transformer采用LN。
假设特征是(N,C,H,W)
N是批量大小,C是通道数,H是矩阵的高度,W是矩阵的宽度,对于刚入门的同学来说,可以这样理解:
我们的特征是一个四维矩阵,它由N个通道数为C的形状为(H, W)的矩阵组成,通道数可以用图片的RGB举例,假设我们有N张分辨率大小为(H,W)的图片输入,那么特征矩阵形状就是(N, 3, H, W)
BN是对相同通道的不同样本进行归一化,即
在这里固定i, j和c,进行计算
LN是在每个矩阵(h, w)上进行计算,即
前馈网络:
该部分进行首先是一个线性层,接着进行relu,然后又是一个线性层,如下:
decoder层:
上面已经阐述,输入部分不在赘述
掩蔽多头注意力:
这里的“掩蔽”主要体现在预测当中,主要用于处理序列数据,并确保在自注意力机制中,模型在预测序列的每个位置时只能依赖于该位置之前的信息,而不能依赖之后的信息。这对于序列生成任务,如语言建模或文本生成,是非常重要的。
在普通的多头自注意力机制中,每个位置的输出是由所有序列位置的加权和组成的,权重由注意力分布确定。但是,在训练过程中,我们需要确保在生成序列时,每个位置只能依赖于之前的信息,以避免信息泄漏。为此,引入了掩蔽(masking)的概念。
接下来这个多头注意力机制有所不同,这里我们的输入有encoder部分得到的K和V矩阵,以及前一个时间步得到的query,共同输入到该模块。
接着是和encoder相同的两个模块,最后通过一个全连接层,输出我们想要的结果。
请注意,在论文attention is all you need原文中,图片中的n=6,也就是堆叠了6次虚线中的模块。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。