赞
踩
- import torch
-
- def lstm(X, state, params):
- W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
- H, C = state
-
- # 遗忘门
- F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)
-
- # 输入门
- I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)
- C_tilde = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)
-
- # 更新单元状态
- C = F * C + I * C_tilde
-
- # 输出门
- O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)
- H = O * torch.tanh(C)
-
- # 输出层
- Y = torch.mm(H, W_hq) + b_q
-
- return Y, (H, C)
本关任务:通过学习长短时记忆网络相关知识,编写实现长短时记忆网络。
为了完成本关任务,你需要掌握:
传统的循环神经网络受限于梯度爆炸与梯度消失问题,使得网络随着输入序列的增长,抖动变得更为剧烈,导致无法学习 。长短时记忆网络( Long Short Term Memory Network, LSTM )便是为了解决此问题而被设计提出。其核心思想是通过添加一个网络内部状态c
来记忆长期信息,这个新的状态我们称之为单元状态(Cell State),主要负责记忆长期信息。
图1 长短时记忆网络结构展开图
图1
为 LSTM 结构展开图,在某一时刻t
,长短时记忆网络的神经元输入由三部分组成:当前网络的输入Xt
、上一时刻的输出st−1
以及上一时刻的单元状态ct−1
,神经元的输出为当前时刻的输出st
,当前时刻的单元状态为ct
。 如图2
所示, LSTM 的核心是单元状态。单元状态像传送带一样,它贯穿整个网络却只有很少的分支,这样能保证信息不变的流过整个网络。后面会 LSTM 结构进行详细的说明。
图2 单元状态图
LSTM 能通过一种被称为门的结构对单元状态进行控制,选择性的决定让哪些信息通过。门的结构很简单,由一个 Sigmoid 层和一个点乘操作的组合而成。如图3
所示:
图3 门示意图
其中黄色矩形表示 Sigmoid 层,红色圆圈代表点乘操作。
因为 Sigmoid 层的输出是0
或1
,这代表有多少信息能够流过Sigmoid 层。0
表示都不能通过,1
表示都能通过。 其神经元的结构如图4
所示:
图4 LSTM 神经元结构示意图
其中的图标的含义如图5
所示:
图5 图标示意图
Vector transfer 表示一个向量从一个节点的输出到其他节点的输入。Pointwise Operation 代表按位 Pointwise 的操作,例如向量的和。 Concatenate 表示向量的连接,Copy 表示内容被复制,然后分发到不同的位置。
一个 LSTM 里面包含三个门来控制单元状态,分别为:遗忘门、输入门和输出门。
LSTM 首先需要决定细胞状态需要留下那些信息,这个功能结构即遗忘门。 它主要决定上一时刻的输出ht−1
与ct−1
状态是否保留到当前时刻的ct
当中。具体是通过一个 Sigmoid 层来实现。它通过查看ht−1
和xt
信息来输出一个[0,1]
之间的向量,该向量的值表示单元状态Ct−1
中哪些信息保留或丢弃。如图6
所示:
图6 遗忘门示意图
它的输入为上一时刻的输出ht−1
与当前时刻的输入xt
,经过 Sigmoid 函数变换,得到内部当前时刻输出ft
。 具体公式表达如下 :
ft=σ(Wf[ht−1,xt]+bf)
其中,Wf 表示遗忘门的权值矩阵,[ht−1,xt]表示两个向量纵向连接操作,bf
表示输入的偏置项。
通过遗忘门决定神经元中什么信息保留下来后,我们现在需要确定当前的输入xt
有多少信息需要保存到当前的单元状态ct
中,此功能结构为输入门。这里包含两个部分:第一,Sigmoid 层决定那些输入将要被更新。第二,一个 Tanh 层生成一个新的候选向量C~t
。这两部分的输出进行逐点相乘,从而对单元状态ct
进行更新。输入门结构如图7
所示:
图7 输入门示意图
根据图7
所示,输入门的计算公式如下:
it=σ(Wi[st−1,xt]+bi)c~t=tanh(Wc[st−1,xt]+bc)
在计算it
与c~t
时,它们的权值矩阵是不同的,因此在训练的过程中需要单独训练。 通过遗忘门与输入门的计算后,我们可以对单元状态ct
进行更新操作。操作方法如图8
所示:
图8 更新单元状态示意图
计算公式如下:
c=ft∘ct−1+it∘c~t
∘
表示按元素逐乘操作。
更新完单元状态后需要根据ht−1
和xt
来考虑如何将当前的信息进行输出,这部分功能由输出门完成。输出门主要来控制单元状态ct
有多少可以输出到长短时记忆网络的当前输出值ht
中,如图9
所示:
图9 输出门示意图
该单元的输出主要依赖当前的神经元状态ct
,不只是单纯依赖单元状态,还需要进行一次信息过滤的处理,即由引入的 Sigmoid 层来完成。这一层将单元状态经过 Tanh 层处理后的数据进行元素相乘操作,将得到的ht
有选择地输出到下一时刻和对外输出。具体计算公式如下 :
ot=σ(Wo[st−1,xt]+bo)st=ot∘tanh(ct)
LSTM 的实现步骤:
F=sigmoid(WxfX+WhfH+bf)
IC~C=sigmoid(WxiX+WhiH+bi)=tanh(WxcX+WhcH+bc)=F∘C+I∘C~
OHnew=sigmoid(WxoX+WhoH+bo)=O∘tanh(C)
Y=WhqHnew+bq
根据提示,在右侧编辑器 Begin-End 区间补充代码,编写实现 LSTM 的遗忘门、输入门、输出门。
平台会对你编写的代码进行测试:
测试输入:无 预期输出: True
开始你的任务吧,祝你成功!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。