当前位置:   article > 正文

广西民族大学高级人工智能课程—头歌实践教学实践平台-LSTM

广西民族大学高级人工智能课程—头歌实践教学实践平台-LSTM

代码文件

  1. import torch
  2. def lstm(X, state, params):
  3. W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
  4. H, C = state
  5. # 遗忘门
  6. F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)
  7. # 输入门
  8. I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)
  9. C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)
  10. # 更新单元状态
  11. C = F * C + I * C_tilde
  12. # 输出门
  13. O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)
  14. H = O * torch.tanh(C)
  15. # 输出层
  16. Y = torch.mm(H, W_hq) + b_q
  17. return Y, (H, C)

题目描述

任务描述

本关任务:通过学习长短时记忆网络相关知识,编写实现长短时记忆网络。

相关知识

为了完成本关任务,你需要掌握:

  1. 长短时记忆网络;
  2. 门结构;
  3. 长短时记忆网络实现。
长短时记忆网络

传统的循环神经网络受限于梯度爆炸与梯度消失问题,使得网络随着输入序列的增长,抖动变得更为剧烈,导致无法学习 。长短时记忆网络( Long Short Term Memory Network, LSTM )便是为了解决此问题而被设计提出。其核心思想是通过添加一个网络内部状态c来记忆长期信息,这个新的状态我们称之为单元状态(Cell State),主要负责记忆长期信息。


图1 长短时记忆网络结构展开图

1为 LSTM 结构展开图,在某一时刻t,长短时记忆网络的神经元输入由三部分组成:当前网络的输入Xt​、上一时刻的输出st−1​以及上一时刻的单元状态ct−1​,神经元的输出为当前时刻的输出st​,当前时刻的单元状态为ct​。 如图2所示, LSTM 的核心是单元状态。单元状态像传送带一样,它贯穿整个网络却只有很少的分支,这样能保证信息不变的流过整个网络。后面会 LSTM 结构进行详细的说明。


图2 单元状态图

LSTM 能通过一种被称为门的结构对单元状态进行控制,选择性的决定让哪些信息通过。门的结构很简单,由一个 Sigmoid 层和一个点乘操作的组合而成。如图3所示:


图3 门示意图

其中黄色矩形表示 Sigmoid 层,红色圆圈代表点乘操作。

因为 Sigmoid 层的输出是01,这代表有多少信息能够流过Sigmoid 层。0表示都不能通过,1表示都能通过。 其神经元的结构如图4所示:


图4 LSTM 神经元结构示意图

其中的图标的含义如图5所示:


图5 图标示意图

Vector transfer 表示一个向量从一个节点的输出到其他节点的输入。Pointwise Operation 代表按位 Pointwise 的操作,例如向量的和。 Concatenate 表示向量的连接,Copy 表示内容被复制,然后分发到不同的位置。

门结构

一个 LSTM 里面包含三个门来控制单元状态,分别为:遗忘门、输入门和输出门。

遗忘门

LSTM 首先需要决定细胞状态需要留下那些信息,这个功能结构即遗忘门。 它主要决定上一时刻的输出ht−1​ct−1​状态是否保留到当前时刻的ct​当中。具体是通过一个 Sigmoid 层来实现。它通过查看ht−1​xt​信息来输出一个[0,1]之间的向量,该向量的值表示单元状态Ct−1​中哪些信息保留或丢弃。如图6所示:


图6 遗忘门示意图

它的输入为上一时刻的输出ht−1​与当前时刻的输入xt​,经过 Sigmoid 函数变换,得到内部当前时刻输出ft​。 具体公式表达如下 :

ft​=σ(Wf​[ht−1​,xt​]+bf​)

其中,Wf​ 表示遗忘门的权值矩阵,[ht−1​,xt​]表示两个向量纵向连接操作,bf​表示输入的偏置项。

输入门

通过遗忘门决定神经元中什么信息保留下来后,我们现在需要确定当前的输入xt​有多少信息需要保存到当前的单元状态ct​中,此功能结构为输入门。这里包含两个部分:第一,Sigmoid 层决定那些输入将要被更新。第二,一个 Tanh 层生成一个新的候选向量C~t​。这两部分的输出进行逐点相乘,从而对单元状态ct​进行更新。输入门结构如图7所示:


图7 输入门示意图

根据图7所示,输入门的计算公式如下:

it​=σ(Wi​[st−1​,xt​]+bi​)c~t​=tanh(Wc​[st−1​,xt​]+bc​)

在计算it​c~t​时,它们的权值矩阵是不同的,因此在训练的过程中需要单独训练。 通过遗忘门与输入门的计算后,我们可以对单元状态ct​进行更新操作。操作方法如图8所示:


图8 更新单元状态示意图

计算公式如下:

c=ft​∘ct−1​+it​∘c~t​

表示按元素逐乘操作。

输出门

更新完单元状态后需要根据ht−1​xt​来考虑如何将当前的信息进行输出,这部分功能由输出门完成。输出门主要来控制单元状态ct​有多少可以输出到长短时记忆网络的当前输出值ht​中,如图9所示:


图9 输出门示意图

该单元的输出主要依赖当前的神经元状态ct​,不只是单纯依赖单元状态,还需要进行一次信息过滤的处理,即由引入的 Sigmoid 层来完成。这一层将单元状态经过 Tanh 层处理后的数据进行元素相乘操作,将得到的ht​有选择地输出到下一时刻和对外输出。具体计算公式如下 :

ot​=σ(Wo​[st−1​,xt​]+bo​)st​=ot​∘tanh(ct​)

长短时记忆网络实现

LSTM 的实现步骤:

  1. 通过遗忘门,计算允许继续通过神经元的信息;

    F=sigmoid(Wxf​X+Whf​H+bf​)

  2. 通过输入门,计算当前输入中需要保留到单元状态的信息;

    IC~C​=sigmoid(Wxi​X+Whi​H+bi​)=tanh(Wxc​X+Whc​H+bc​)=F∘C+I∘C~​

  3. 通过输出门,计算需要输出的信息;

    OHnew​​=sigmoid(Wxo​X+Who​H+bo​)=O∘tanh(C)​

  4. 通过输出层计算输出。

    Y=Whq​Hnew​+bq​

编程要求

根据提示,在右侧编辑器 Begin-End 区间补充代码,编写实现 LSTM 的遗忘门、输入门、输出门。

测试说明

平台会对你编写的代码进行测试:

测试输入:无 预期输出: True


开始你的任务吧,祝你成功!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/532153
推荐阅读
相关标签
  

闽ICP备14008679号