当前位置:   article > 正文

10分钟理解RNN、LSTM、Transformer结构原理!_transformer lstm

transformer lstm

一、RNN

RNN 即循环神经网路,是NLP、语言识别等时间序列数据处理的基本网络框架。与图像数据不同,时间序列数据是指在不同时刻采集到的数据,这类数据的状态一般与时间有关。对于一句话,通过单个单词是难以理解整体意思的,只有通过处理这些词连接起来的整个序列,才能更好地理解整体信息。

1.1 RNN基本架构

如下图所示,是RNN的一个基本架构图。输入是一句话"I dislike the boring movie",通过中间的隐藏层H,最终通过一定计算得到输出O。我们可以将纵向的每一个子模块(红色框所示)看作是一个全连接网络,那么沿着时间维度,一共有5个全连接结构,这些全连接网络的参数是共享的。

![在这里插入图片描述](https://img-blog.csdnimg.cn/14c15e7edb364e818b705537d4959246.png
与之不同的是,中间隐藏层的状态不仅受当前时刻的输入影响,还与上一时刻的隐藏层节点有关。用公式如下:
在这里插入图片描述
进一步,公式展开如下:
在这里插入图片描述
通过下面这张图(情感分类案例),可以更直观地看出隐藏层中状态变化(图中蓝色框所示):
在这里插入图片描述

1.2 RNN经典的三种结构

根据RNN输入输出的不同,常分为三种结构:Vector-to- Sequence(1对多)、Sequence-to-Vector(多对1)、Encoder-Decoder(多对多)

1.2.1 vector-to-sequence结构

有时我们要处理的问题输入是一个单独的值,输出是一个序列,比如从图像生成文本。此时,有两种主要建模方式:

  • 方式1:可只在其中的某一个序列进行计算,比如序列第一个进行输入计算,其建模方式如下:
    在这里插入图片描述

  • 方式二:把输入信息X作为每个阶段的输入,其建模方式如下:
    在这里插入图片描述

1.2.2 sequence-to-vector结构

当我们要处理的问题输入是一个序列,输出是一个单独的值(比如感情分类、文本分类问题),此时通常在最后的一个序列上进行输出变换,其建模如下所示:
在这里插入图片描述

1.2.3 Encoder-Decoder结构

原始的sequence-to-sequence结构的RNN要求序列等长,然而我们遇到的大部分问题序列都是不等长的,如机器翻译中,源语言和目标语言的句子往往并没有相同的长度。​ 其结构由编码器和解码器组成:

  • 编码器:将输入数据编码成一个上下文向量 c c c,这部分称为Encoder,得到 c c c有多种方式,最简单的方法就是把Encoder的最后一个隐状态赋值给 c c c,还可以对最后的隐状态做一个变换得到 c c c,也可以对所有的隐状态做变换。其示意如下所示:
    在这里插入图片描述
  • 解码器:用另一个RNN网络(我们将其称为Decoder)对其进行解码。解码的方式一般有两种:
    • 将步骤一中的 c ​ c​ c作为初始状态输入到Decoder :
      在这里插入图片描述
    • c c c作为Decoder的每一步输入,示意图如下所示:
      在这里插入图片描述

1.3 RNN常用领域

根据其结构的不同,所使用的场景也不同:

  • Vector-to- Sequence(1对多):常用于图像生成文字、图像生成语音或音乐等领域
  • Sequence-to-vector(多对1):常用于文本分类、情感识别、视频分类等领域
  • Encoder-Decoder(多对多):使用场景很广,包括机器翻译、文本摘要、阅读理解、语言识别等领域

1.4 RNN的优缺点

优点:

  • 隐藏层中,t 时刻的状态与 t 时刻输入和 t-1 时刻状态共同决定,这样有助于建立词语上下文之间的联系。
  • RNN中各个全连接网络共享一套参数,大大减少了网络的参数量,使得网络训练起来更加高效。

缺点:

  • 在上述通用的Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征 c ​ c​ c再解码,因此, c ​ c​ c中必须包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。如机器翻译问题,当要翻译的句子较长时,一个 c ​ c​ c可能存不下那么多信息,就会造成翻译精度的下降。
  • 由于RNN特有的memory会影响后期其他的RNN的特点,梯度时大时小,learning rate没法个性化的调整,导致RNN在train的过程中,Loss会震荡起伏。(可以设置临界值,当梯度大于某个临界值,直接截断,用这个临界值作为梯度的大小,防止大幅震荡

1.5 RNN中为什么会出现梯度消失

sigmoid函数的函数及导数图如下所示:
在这里插入图片描述

  • 从上图观察可知,sigmoid函数的导数范围是(0,0.25],tanh函数的导数范围是(0,1],他们的导数最大都不大于1。
  • 如果取tanh或sigmoid函数作为激活函数的话,随着时间序列的不断深入,激活函数导数的累乘会导致结果越乘越小,直到接近于0,这就是“梯度消失“现象。
  • 实际使用中,会优先选择tanh函数,原因是tanh函数相对于sigmoid函数来说梯度较大,收敛速度更快且引起梯度消失更慢

梯度消失是由于无限的利用历史数据而造成,但是RNN的特点本来就是能利用历史数据获取更多的可利用信息,解决RNN中的梯度消失方法主要有:

  • 选取更好的激活函数,如Relu激活函数。ReLU函数的左侧导数为0,右侧导数恒为1,这就避免了“梯度消失“的发生。但恒为1的导数容易导致“梯度爆炸“,但设定合适的阈值可以解决这个问题。
  • 加入BN层,其优点包括可加速收敛、控制过拟合,可以少用或不用Dropout和正则、降低网络对初始化权重不敏感,且能允许使用较大的学习率等。
  • 改变传播结构,比如下面的LSTM结构

二、LSTM

LSTM即长短期记忆网络,是为了解决长依赖问题而设计的一种特殊的RNN网络。所谓的长依赖,就是因为计算距离较远的节点之间的联系会涉及到雅可比矩阵的多次相乘,造成的梯度消失现象。LSTM也可称为门限RNN,它通过在不同时刻改变系数,控制网络忘记当前已经积累的信息,从而解决这一问题。

2.1 LSTM与RNN差异

  • RNN 都具有一种重复神经网络模块的链式的形式,在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层(当然,也可以是sigmoid激活函数),如下:
    在这里插入图片描述
  • 在LSTM中的重复模块中,不同于RNN,这里有四个神经网络层,以一种非常特殊的方式进行交互:
    在这里插入图片描述

2.2 LSTM核心思想图解

LSTM 通过精心设计“门”结构来删除或者增加信息到细胞状态。门是一种让信息选择式通过的方法,他包含一个 sigmoid 神经网络层和一个 pointwise 乘法操作。示意图如下:
在这里插入图片描述
LSTM 拥有三个门,分别是:忘记层门,输入层门和输出层门,来保护和控制细胞状态。

2.2.1 忘记层门
  • 作用对象:细胞状态。
  • ​作用:将细胞状态中的信息选择性的遗忘。
  • 操作步骤:该门会读取 h t − 1 h_{t-1} ht1 x t x_t xt,输出一个在 0 到 1 之间的数值给每个在细胞状态 C t − 1 ​ C_{t-1}​ Ct1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。示意图如下:
    在这里插入图片描述
2.2.2 输入层门
  • 作用对象:细胞状态
  • 作用:将新的信息选择性的记录到细胞状态中。
  • 操作步骤:​
    • 步骤一,sigmoid 层称 “输入门层” 决定什么值我们将要更新。
    • 步骤二,tanh 层创建一个新的候选值向量 C ~ t \tilde{C}_t C~t加入到状态中。其示意图如下:
      在这里插入图片描述
    • 步骤三:将 c t − 1 c_{t-1} ct1更新为 c t c_{t} ct。将旧状态与 f t f_t ft相乘,丢弃掉我们确定需要丢弃的信息。接着加上 i t ∗ C ~ t i_t * \tilde{C}_t itC~t得到新的候选值,根据我们决定更新每个状态的程度进行变化。其示意图如下:
      在这里插入图片描述
2.2.3 输出层门
  • 作用对象:隐层 h t h_t ht
  • 作用:确定输出什么值。
  • 操作步骤:
    • 步骤一:通过sigmoid 层来确定细胞状态的哪个部分将输出。
    • 步骤二:把细胞状态 c t c_{t} ct通过 tanh 进行处理,并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

其示意图如下所示:
在这里插入图片描述

2.3 LSTM应用场景

三、Transformer

Transformer是一种用于自然语言处理(NLP)和其他序列到序列(sequence-to-sequence)任务的模型架构,由Google在2017年提出。它在机器翻译任务中取得了重大突破,并在NLP领域得到广泛应用。

传统的序列模型,如循环神经网络(RNN),在处理长序列时存在梯度消失和计算效率低的问题。而Transformer采用了注意力机制(Attention Mechanism)来建立全局上下文关系,有效解决了这些问题。

3.1 Transformer的核心

自注意力机制是Transformer的核心。它允许模型在生成每个词语的表示时,根据整个输入序列中其他词语的上下文信息进行加权计算。因此,模型可以更好地捕捉词语之间的依赖关系和长距离依赖。注意力机制的计算效率通过使用矩阵运算和并行计算得到提高。

除了注意力机制,Transformer还引入了残差连接(Residual Connections)和层归一化(Layer Normalization)等技术,有助于模型收敛和训练稳定性。

Transformer的应用包括机器翻译、文本生成、问答系统、语言模型等。它不仅在性能上超过了传统的序列模型,而且具有并行计算的优势,能够高效地处理长序列。

Transformer的成功引发了一系列基于Transformer的模型的发展,如BERT、GPT等。这些模型在各个NLP任务上取得了重大突破,并成为了自然语言处理领域的重要里程碑之一。

3.2 Transformer主体架构

以机器翻译任务为例,下面这句话中的it指代The animal,这是需要理解上下文信息才能得出的,而Transformer就能很好地帮我们做到这一点,具体如何做到?还需要下面我们区去探究它的注意力机制在这里插入图片描述
在这里插入图片描述

3.2.1 整体结构

Transformer模型由编码器(Encoder)和解码器(Decoder)组成。

  • 编码器将输入序列中的每个词嵌入向量化,并通过多个自注意力层(Self-Attention)和前馈神经网络层进行信息提取和特征表示。
  • 解码器则在编码器的基础上进一步使用自注意力层来生成目标序列。
    在这里插入图片描述
    在实际的使用过程中,会使用多个编码器和解码器串联的形式,如下图所示:
    在这里插入图片描述
3.2.2 编码器

在Transformer模型中,编码器主要包含以下结构:

  • 位置编码(Positional Encoding):由于Transformer模型没有使用卷积或循环结构,无法通过位置信息来捕捉序列中元素的顺序关系。为了解决这个问题,位置编码被添加到输入序列中,以提供每个元素相对于其他元素的位置信息。

  • 自注意力(Self-Attention):自注意力机制是Transformer模型的核心组成部分。编码器中的每个注意力头都会为输入序列中的每个元素计算一个注意力权重,该权重表示该元素与其他元素的相关性。通过自注意力机制,编码器可以在不同层次上学习序列中元素之间的依赖关系。

  • 多头注意力(Multi-Head Attention):为了捕捉不同关注点的信息,编码器通常会使用多个注意力头。每个注意力头都会独立学习不同的注意力权重,并生成一个注意力值的加权和。多头注意力可以使模型更好地捕捉不同特征之间的关联。

  • 前馈神经网络(Feed-forward Neural Network):编码器中的每个注意力子层后面通常跟着一个前馈神经网络。前馈神经网络是一个全连接的前向传播网络,用于将注意力子层的输出进行非线性变换和映射。通过前馈神经网络,编码器可以引入更多的非线性性和表达能力。

  • 残差连接(Residual Connections)和层归一化(Layer Normalization):为了稳定训练并加快信息传递,编码器中引入了残差连接和层归一化。残差连接允许信息在不同层之间直接跳过,有助于避免梯度消失或爆炸的问题。而层归一化则用于对注意力子层和前馈神经网络进行归一化,提高模型的训练稳定性。

通过以上结果的组合和堆叠,编码器能够对输入序列进行多层次的特征提取和表示学习,从而为下游任务(如机器翻译、文本分类等)提供更准确和丰富的表示。

在这里插入图片描述

(1)输入部分

通过词嵌入位置编码的方式,让单词信息转换成向量,同时让模型能轻松学到相对位置信息。

在这里插入图片描述
位置编码公式:
在这里插入图片描述
在这里插入图片描述

(2)注意力机制结构

如图所示,是单个注意力机制和多头注意力机制的结构示意图:
在这里插入图片描述
单个注意力机制的计算过程可以描述成公式:Q和K进行点乘计算向量相似性;然后采用softmax转换为概率分布;最后将概率分布和V进行加权求和。整体公式如下:
在这里插入图片描述
通过下面图片,可以更直观地理解:

在这里插入图片描述

(3)注意力机制推导过程

具体推到过程:我们以翻译”Thinking Machines“这两个单词为例

  • 首先,将单词进行词嵌入,和位置编码向量相加,得到自注意力层输入X
  • 初始化三个权重核 W Q W^Q WQ W K W^K WK W V W^V WV,分别对X进行矩阵相乘,得到查询向量Q、键向量K、值向量V
  • 通过Ateention(Q,K,V) 公式计算注意力,获得单词的注意力向量,该单词向量反映了与上下文单词的加权求结果
  • 由于Transformer采用多头注意力结构,所以需要将每一个注意力机制输出的向量,进行concat操作,然后通过全连接层,获得最后的输出。
    在这里插入图片描述

注:由于Q,K,V均来自同一个输入X计算得到的,故通常称为自注意力层

在这里插入图片描述
Transformer采用多头注意力机制的原因主要是:消除 W Q W^Q WQ W K W^K WK W V W^V WV初始矩阵值的影响;也有另一种说话,类似于CNN,增强表达空间。
在这里插入图片描述

3.2.3 解码器

在Transformer模型中,解码器(Decoder)是负责从编码器的输出中生成目标序列的部分。解码器主要包含以下结构:

  • 自注意力(Self-Attention)层:解码器的每个注意力头都会为目标序列中的每个位置计算一个注意力权重,该权重表示该位置与其他位置的相关性。通过自注意力机制,解码器可以在不同层次上学习目标序列中不同位置之间的依赖关系。

  • 编码-解码注意力(Encoder-Decoder Attention)层:在解码器中,为了获取与编码器输出相关的上下文信息,编码-解码注意力层被引入。该层会计算目标序列位置和编码器输出之间的注意力权重,以捕捉编码器输出对于当前目标位置的相关性。

  • 前馈神经网络(Feed-forward Neural Network):解码器中的每个注意力子层后面通常跟着一个前馈神经网络。前馈神经网络是一个全连接的前向传播网络,用于将注意力子层的输出进行非线性变换和映射。

  • 残差连接(Residual Connections)和层归一化(Layer Normalization):与编码器类似,解码器中也引入了残差连接和层归一化。残差连接允许信息在不同层之间直接跳过,有助于避免梯度消失或爆炸的问题。层归一化则用于对注意力子层和前馈神经网络进行归一化。

在解码器中,通常会使用多个解码器层堆叠在一起,每个层都具有相同的结构和参数。这样的堆叠使得解码器能够逐步生成目标序列,并逐渐获取更多的上下文信息和语义表示。

通过以上结构的组合和堆叠,解码器能够从编码器的输出中解码并生成目标序列,如机器翻译任务中将源语言句子翻译成目标语言句子。解码器的设计旨在允许在生成序列时引入上下文信息和全局依赖关系,以提高生成的序列质量和一致性

3.2.4 输出层

解码器的输出是向量,最终需要得到翻译的解决过,整个过程:首先会通过Linear层(即全连接网络,输出节点数可认为是词库的单词总数,比如2w个),然后经过softmax转换成概率输出,最后输出概率最大的索引值对应的词,即为翻译的结果。

在这里插入图片描述

3.3 基于Transformer的常用模型

基于Transformer的模型有很多种,以下是一些常见的基于Transformer的模型:

  • BERT(Bidirectional Encoder Representations from Transformers):BERT是一种基于Transformer的预训练语言模型,它在大规模的无标注数据上进行预训练,然后可以用于各种下游任务,如文本分类、命名实体识别、句子相似度等。

  • GPT(Generative Pre-trained Transformer):GPT是一种基于Transformer的生成式预训练模型,它通过在大规模的文本数据上进行预训练来学习上下文信息,然后可以用于生成文本、机器翻译等任务。

  • Transformer-XL:Transformer-XL是一种用于语言建模的扩展Transformer模型,它通过引入循环机制来解决长文本序列建模中的问题,并能够在长文本中捕捉更长距离的依赖关系。

  • T5(Text-to-Text Transfer Transformer):T5是一种多任务学习的Transformer模型,它可以统一各种自然语言处理任务,将输入和输出都转化为通用的文本形式,从而可以通过微调来适应不同任务。

  • XLNet:XLNet是一种基于Transformer的自回归语言模型,它通过学习排列不变性来解决自回归模型中的顺序偏置问题,并在多项下游任务上取得了优秀的性能。

这些模型都是基于Transformer架构的改进和扩展,通过充分利用Transformer的自注意力机制和多头注意力机制,能够有效地学习序列中的依赖关系和语义表示,从而在自然语言处理任务中取得了显著的性能提升。


由于水平有限,博客中难免会有一些错误,有纰漏之处恳请各位大佬不吝赐教!

在这里插入图片描述

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号