当前位置:   article > 正文

循环神经网络 (处理时序型数据)_处理时序数据的神经网络

处理时序数据的神经网络

1. RNN 存在的理由

循环神经网络(Recurrent Neural Network, RNN)是一种处理时序型输入的神经网络。它被广泛应用在语音识别、机器翻译、人名识别、文本生成等任务上。这些任务要处理的都是时序型的数据,一言以蔽之,这些时序型数据有着 输入是不定长度的输入的上下文是由关联的 的特征。

经典的深度神经网络(NN)通过构建多层全连接的神经网络来对固定的多维输入进行预测;卷积神经网络(CNN)通过设计不同的卷积和池化层组合来对图像(网格化输入)进行预测。但在实际应用中,这两种神经网络均很难处理时序型数据,因为:

  1. 它们的网络结构的输入维度是固定的,不能处理不同长度的输入;
  2. 它们的网络结构忽略了输入节点间的横向联系,如上下文关系等。

这两个缺点恰好妨碍了 NN 和 CNN 对于时序型数据的处理。于是,一种更新的能够处理不同维度输入的、且能利用前后顺序关系的网络结构就问世了,它就是循环神经网络 RNN。换句话说,RNN 存在的理由就是为了处理时序型数据

2. RNN 的基本结构

以机器翻译为例,句子中单词的翻译跟语境十分相关,也可以认为要翻译的单词含义与之前出现的单词十分相关,要想预测当前词的含义就必须 “记住” 之前出现过的词。我们以一个简单的语句 “John likes to drink the spring water.” 为例。通过上下文(如 “drink”),人类马上就知道这里的 “spring” 指的是 “泉水” 的意思(而非 “春天”),但是要让机器理解上下文则相对困难。
在这里插入图片描述
我们将每个单词视作一个输入,即这 7 个单词分别对应于 x < 1 > , x < 2 > , . . . , x < 7 > x^{<1>},x^{<2>},...,x^{<7>} x<1>,x<2>,...,x<7>,同时,记句子的输入长度为 T x T_x Tx,即 T x = 7 T_x=7 Tx=7。那么这 7 个单词的输出分别对应于 y ^ < 1 > , y ^ < 2 > , . . . , y ^ < 7 > \hat{y}^{<1>},\hat{y}^{<2>},...,\hat{y}^{<7>} y^<1>,y^<2>,...,y^<7>,输出的总长度为 T y T_y Ty,实际情况下允许 T x ≠ T y T_x \neq T_y Tx=Ty。同时,每个输入都有一个实时状态,记为 a < i > , i = 0 , 1 , 2 , 3 , 4 , 5 a^{<i>}, i=0,1,2,3,4,5 a<i>,i=0,1,2,3,4,5。这个实时状态其实就是 RNN 要记下来的前文知识。
在这里插入图片描述
上图展示了 RNN 两种不同的表达方式,图 (a) 常见于各种论文中,理解起来较难,将其横向展开就得到了(b);图 (b) 中的模式更适合大部分人的理解,它表示了 RNN 中信息的流动方式。

给定初始状态 a < 0 > a^{<0>} a<0> ,RNN 将依次从左往右传播,且当前状态 a < i > a^{<i>} a<i> 只与当前输入 x < i > x^{<i>} x<i> 和前个时刻的状态 a < i − 1 > a^{<i-1>} a<i1> 相关。如果当前 RNN 节点有输出值 y ^ < i > \hat{y}^{<i>} y^<i>,则该输出值与当前的状态 a < i > a^{<i>} a<i> 相关。具体的节点结构和计算方式见下节。

这里值得一提的是,在 RNN 中我们没有强制规定输入维度 T x T_x Tx 一定等于输出维度 T y T_y Ty,事实上根据处理问题的不同,二者存在多种关系,如下图所示的 many-to-one(分类), one-to-many(文本生成), many-to-many (机器翻译,摘要生成)等不同模式。
在这里插入图片描述

2.1 RNN 的前向传播方式 - Forward

下图展示了 RNN 第 t 时刻的节点的输入输出情况。为了更好的计算当前状态值,RNN 设定全局共享的参数 W a a W_{aa} Waa W a x W_{ax} Wax(即所有的 RNN 节点参数值均相同),通过激活函数 g 1 g_1 g1 来计算出第 t 时刻的状态值 a < t > a^{<t>} a<t>,通常这里的 g 1 g_1 g1 函数会被设置为 ReLU 或者 Tanh 函数。另外,计算 y ^ < t > \hat{y}^{<t>} y^<t> 时使用的 g 2 g2 g2 函数根据预测任务而定,例如在分类情况下使用 Softmax 函数。
在这里插入图片描述
实际的计算中为了节省计算量,在函数 g 1 g_1 g1 中,我们将 a < t − 1 > a^{<t-1>} a<t1> x < t > x^{<t>} x<t> 做一个垂直方向的拼接,这样参数 W a a W_{aa} Waa W a x W_{ax} Wax 就可以整合成一个参数 W a W_a Wa 来计算了。举个简单的例子,假设我们的输入 x < t > x^{<t>} x<t> 都是 1000 维的,状态 a < t > a^{<t>} a<t> 都是 100 维,那么 W a a W_{aa} Waa 应该是 100x100 的尺寸, W a x W_{ax} Wax 则是 100x1000 的尺寸,拼接起来的 W a W_a Wa 则是 100x1100 尺寸的。先假设 W y W_y Wy 的尺寸是 64x100 ,那么最终的 y ^ < t > \hat{y}^{<t>} y^<t> 的尺寸就是 64x1。这里的 “64” 就是 RNN 节点隐藏层的节点个数,它决定了输出的尺寸大小。

2.2 RNN 的反向传播方式 - Backward

RNN 与之前标准神经网络类似,依旧使用链式法则来反向更新参数。不同的是由于 RNN 每个节点都公用参数 W a , W y W_a, W_y Wa,Wy,因此求参数偏导时需要累加之前的偏导数的值,RNN 的更新方式也被称为 BPTT 算法。

为了书写方便,我们将上一小节的 g 2 g_2 g2 函数暂时省略。同时确定 RNN 的损失函数 L ( y ^ k , y k ) L(\hat{y}_{k}, y_k) L(y^k,yk) 如下均方差形式(假设 RNN 只有最后 T 时刻的输出,即 many-to-one),

L ( y ^ , y ) = 1 2 ( y ^ < T > − y < T > ) 2 L(\hat{y}, y) = \frac{1}{2} (\hat{y}^{<T>} - y^{<T>})^2 L(y^,y)=21(y^<T>y<T>)2

其中 y ^ < T > \hat{y}^{<T>} y^<T> 表示样本在时刻 T T T 的实际输出, y < T > y^{<T>} y<T> 表示样本在时刻 T T T 的预期输出。当 t ∈ [ 1 , T ] t \in [1,T] t[1,T] 时,前向传播过程可以记为,

a < t > = g 1 ( W a a   a < t − 1 > + W a x   x < t > + b x ) a^{<t>} = \bm{g_1}(W_{aa} \ a^{<t-1>} + W_{ax} \ x^{<t>} + b_x) a<t>=g1(Waa a<t1>+Wax x<t>+bx)

y ^ < t > = W y   a < t > + b y \hat{y}^{<t>} = W_y \ a^{<t>} + b_y y^<t>=Wy a<t>+by

通过上述两式可知,当我们使用 BPTT 算法对 W a a W_{aa} Waa W a x W_{ax} Wax 求导时,不但要考虑 t t t 时刻的情况,还要考虑 t t t 时刻以前的状况,因为 W a a W_{aa} Waa W a x W_{ax} Wax 影响了前面时刻的状态 a < t > a^{<t>} a<t>。于是乎,求这2个参数的偏导的过程是一个累加的过程。如下图的红色箭头所指。
在这里插入图片描述
首先,对 W y W_{y} Wy 求偏导。由于 W y W_{y} Wy 只与 y ^ < T > \hat{y}^{<T>} y^<T> 相关,因此其偏导可写为,

∂ L ( y ^ , y ) ∂ W y = ∂ L ∂ y ^ < T > ∂ y ^ < T > W y \frac{\partial L(\hat y, y)}{\partial{W_y}} = \frac{\partial{L}}{\partial{\hat{y}^{<T>}}}\frac{\partial{\hat{y}^{<T>}}}{W_y} WyL(y^,y)=y^<T>LWyy^<T>

然后,对 W a x W_{ax} Wax 求偏导。由于 E k E_k Ek 是关于 T y T_y Ty 个输出 y ^ k < t > \hat{y}^{<t>}_k y^k<t> 的函数,因此是一个累加的形式,

∂ L ( y ^ , y ) ∂ W a x = ∂ L ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ W a x + ∂ L ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ a < T − 1 > ∂ a < T − 1 > ∂ W a x + . . . + ∂ L ∂ y ^ < T > ∂ y ^ k < T > ∂ a < T > ∂ a < T − 1 > ∂ a < T − 2 > . . ∂ a < 1 > ∂ W a x \frac{\partial L(\hat y, y)}{\partial W_{ax}} = \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}}{\partial a^{<T>} }\frac{\partial a^{<T>}}{\partial W_{ax}} + \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}}{\partial a^{<T>}} \frac{\partial a^{<T>}}{\partial a^{<T-1>}} \frac{\partial a^{<T-1>}}{\partial W_{ax}}+ \\ ... + \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}_k}{\partial a^{<T>}} \frac{\partial a^{<T-1>}}{\partial a^{<T-2>}}..\frac{\partial a^{<1>}}{\partial W_{ax}} WaxL(y^,y)=y^<T>La<T>y^<T>Waxa<T>+y^<T>La<T>y^<T>a<T1>a<T>Waxa<T1>+...+y^<T>La<T>y^k<T>a<T2>a<T1>..Waxa<1>

上式可化简为,

∂ L ( y ^ , y ) ∂ W a x = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ a < i > ∂ a < i > ∂ W a x                           = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > ∂ a < i > ∂ W a x \frac{\partial L(\hat y, y)}{\partial W_{ax}} = \sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\frac{\partial a^{<T>}}{\partial a^{<i>}}\frac{\partial a^{<i>}}{\partial W_{ax}} \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} \frac{\partial a^{<i>}}{\partial W_{ax}} WaxL(y^,y)=i=1Ty^<T>L(y^,y)a<T>y^<T>a<i>a<T>Waxa<i>                         =i=1Ty^<T>L(y^,y)a<T>y^<T>k=iT1a<k>a<k+1>Waxa<i>

更加细致的推导请参考博客1,2。同理,我们也可以利用链式求导法则求出 W a a W_{aa} Waa 的偏导数,

∂ L ( y ^ , y ) ∂ W a a = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > ∂ a < i > ∂ W a a \frac{\partial L(\hat y, y)}{\partial W_{aa}} =\sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} \frac{\partial a^{<i>}}{\partial W_{aa}} WaaL(y^,y)=i=1Ty^<T>L(y^,y)a<T>y^<T>k=iT1a<k>a<k+1>Waaa<i>

综上可知,求解 W a x W_{ax} Wax W a a W_{aa} Waa 的梯度是复杂的累积之后,再累加的一个过程,其中连续的累积部分( ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > \prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} k=iT1a<k>a<k+1>)会造成“梯度爆炸”和“梯度消失”问题。因为 RNN 使用的是 Tanh 函数为 g 1 g1 g1 函数,因此 ∂ a < k + 1 > ∂ a < k > \frac{\partial a^{<k+1>}}{\partial a^{<k>}} a<k>a<k+1> 的值如果接近0,累积之后会更加接近于0,这就造成 RNN 的梯度只与靠近输出部分的梯度有很大关系,而越远的梯度则对总梯度影响不大。

2.3 RNN 到底有没有梯度消失?

“RNN 有梯度消失,LSTM解决了它” 可能是对 RNN 或者 LSTM 最经典的误解3。事实上,RNN 的 “梯度消失” 和传统的 NN 的 “梯度消失” 含义不同:

  1. 在传统的 NN 网络中,各层之间的参数 W i W_i Wi 不相同,因此在 BP 算法中,各求各的梯度,因此里输出层较远的参数的梯度会越来越小,直至为零。我们成这个为梯度消失 Gradient Vanishing。
  2. 在 RNN 中,参数 W a a , W a x , W y W_aa, W_ax, W_y Waa,Wax,Wy 在每层之间是共享的,求解他们的梯度其实是一个累加求和的过程。因此对于参数梯度而言,更多的是由里输出层近的梯度所决定的,而较远的层的梯度对总体度几乎没有影响。RNN 所谓梯度消失的真正含义是:梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。

综上,虽然随着层数的增加,越远离输出层的梯度会减少,但是总的梯度和是不会消失的。因此 RNN 中不存在传统的梯度消失问题的!。因此,RNN 的弱点并不是梯度为趋近于0或消失,而是“健忘”,记不住较远距离对其的影响。

3. RNN 变体 - 长短记忆网络 LSTM

为了解决 RNN 训练过程中的 “健忘” 的问题,人们提出了 长短期记忆网络(Long short-term memory, LSTM4)的网络结构。LSTM 是 RNN 的一种变体,相比传统的 RNN 网络结构,LSTM 能够在更长的序列中有更好的表现。RNN 与 LSTM 的结构对比如下,可见 LSTM 肉眼可见的复杂了,每个节点有 2 个输入和 2 个输出。下图中的 h t h_t ht 指的就是 t t t 时刻的输出, x t x_t xt 指的是 t t t 时刻的输入, C t C_t Ct t t t 时刻的状态。
在这里插入图片描述

3.1 LSTM 的“门”结构

LSTM 结构最大的创新点在于它引入了 “门” 这个网络结构。具体来讲,“门” 结构相当于一个阈值控制机关,它控制了当前信息有多少比例可以通过。LSTM 设置有 “遗忘门”,“输入门”,“输出门” 3 个门结构,他们分别控制了之前状态,当前输入和当前状态有多少信息被保留下来。

在这里插入图片描述
遗忘门 Forget Gate: 遗忘门决定了前一时刻状态 C t − 1 C_{t-1} Ct1 有多少信息保留到当前状态 C t C_t Ct。遗忘门的输入是前一时刻输出 h t − 1 h_{t-1} ht1 和当前输入 x t x_t xt,其输出 f t f_t ft 表示应该保留的比例, σ \sigma σ 符号表示 Sigmoid 函数,其输出是 0 至 1 间的实数,数值越大表示信息保留的越多。

在这里插入图片描述
输入门 Input Gate: 输入门决定当前输入 x t x_t xt 有多少信息输入到 C t C_t Ct 状态中。这个步骤分为两步,首先通过 h t − 1 h_{t-1} ht1 x t x_t xt 生成 i t i_t it C ~ t \widetilde{C}_t C t,然后与前一时刻状态 C t − 1 C_{t-1} Ct1 保留的信息相加,从而得到当前时刻的状态 C t C_t Ct。注意这里的 σ \sigma σ 函数的输出范围是 0 到 1,而 t a n h tanh tanh 函数的输出范围是 -1 到 1。因此, i t i_t it 是一个0-1的实数,而 C ~ t \widetilde{C}_t C t 则表示的是向量。

在这里插入图片描述
输出门 Output Gate: 输出门决定了当前时刻的输入 x t x_t xt 有多少比例影响当前时刻的输出 h t h_t ht。由图可知 h t h_t ht 的值不但跟 h t − 1 h_{t-1} ht1 x t x_t xt 有关,还与当前时刻的状态 C t C_t Ct 有关。

在这里插入图片描述
根据上述的几个步骤,可以清晰的了解到 LSTM 节点(或细胞)中的信息流向。换个角度看,标准的 RNN 中,状态与输入是相同的,即 h t = C t h_t = C_t ht=Ct,而 LSTM 则是分开计算的;其实在实际应用中,人们往往会直接使用 LSTM 而非标准的 RNN。

3.2 LSTM 如何解决梯度消失问题?

回到本小节开头部分,LSTM 的提出是为了解决传统 RNN 存在的 “健忘” 的问题。BPTT 算法中那一段累积部分 ∂ a < t + 1 > ∂ a < t > \frac{\partial a^{<t+1>}}{\partial a^{<t>}} a<t>a<t+1> 就是罪魁祸首。根据 LSTM 的网络结构,该段状态 C C C 的累积部分变为,

∂ C t + 1 ∂ C t = f t \frac{\partial C_{t+1}}{\partial C_{t}}=f_t CtCt+1=ft

关于输出 h h h 的累积则略微复杂一点,

∂ h t + 1 ∂ h t = ∂ h t ∂ o t ∂ o t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ C ~ t ∂ C ~ t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ i t ∂ i t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ f t ∂ f t ∂ h t − 1 \frac{\partial h_{t+1}}{\partial h_{t}} = \frac{\partial h_t}{\partial o_t}\frac{\partial o_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial \widetilde{C}_t}\frac{\partial \widetilde{C}_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}} htht+1=oththt1ot+CthtC tCtht1C t+CthtitCtht1it+CthtftCtht1ft

(注:本小节的图片均来自 Christopher Olah 的博文 Understanding LSTM Networks. )

4. RNN 变体 - 双向循环神经网络 BRNN

双向循环神经网络(Biodirectional Recurrent Neural NetworkBRNN)也是 RNN 的一种变体,与传统 RNN 相比,BRNN 增加了后续时刻的输入对当前状态的影响。设想在人名识别的任务重有下属两个句子,

He said, “Teddy bears are on sale!”
He said, “Teddy Roosevelt was not a good President!”

(注:Teddy Roosevelt 西奥多·罗斯福(大罗斯福)是美国第 26 任总统。他的侄子 富兰克林·罗斯福 (小罗斯福)是美国第 32 任总统,并且是二战中重要的同盟国领袖之一,他成功连任 4 次美国总统!)

两个语句的开头都是一样的,但很明显第一个句子中的 Teddy 指的是泰迪熊而非人名,第二个句子的 Teddy 才指的是名字。由于 2 个句子开头都是 He said,故使用标准 RNN 则很难判断出来谁是人名。造成 RNN 瓶颈的原因正式 标准 RNN 不能够联系后文的内容

为了解决 RNN 的这个瓶颈,人们于是提出了双向 RNN 网络,即 BRNN。在 BRNN 中 影响单个时刻输出的不但有之前的内容,并且也有后续的内容。一个标准 BRNN 的结构如下所示, a → < t > \overrightarrow{a}^{<t>} a <t> 表示正向传播时的 t t t 时刻状态, a ← < t > \overleftarrow{a}^{<t>} a <t> 表示反向传播时 t t t 时刻的状态。
在这里插入图片描述
当我们计算 y < 3 > y^{<3>} y<3> 的输出时,我们首先通过输入 x < 1 > , x < 2 > x^{<1>}, x^{<2>} x<1>,x<2> 得到 1 和 2 时刻的状态 a → < 1 > , a → < 2 > \overrightarrow{a}^{<1>},\overrightarrow{a}^{<2>} a <1>,a <2>,正向传播得到当前正向的状态 a → < 3 > \overrightarrow{a}^{<3>} a <3>;再结合输入 x < 4 > x^{<4>} x<4> 的输入得到 4 时刻的状态 a ← < 4 > \overleftarrow{a}^{<4>} a <4>,反向传播得到当前正向的状态 a ← < 3 > \overleftarrow{a}^{<3>} a <3>。最终 y < 3 > y^{<3>} y<3> 是由 a → < 3 > , a ← < 3 > \overrightarrow{a}^{<3>}, \overleftarrow{a}^{<3>} a <3>,a <3> 共同决定的,这就是 BRNN 的基本思想。

5. 小结

RNN 网络主要处理时序型数据,通过给不同时段的输入之间建立联系,RNN 生成最终的预测结果(many-to-one 或其他模式)。他具有一定的记忆功能,但同时记忆功能有待改进。它的后续改进版本 LSTM 注重于记忆更长时刻的知识,BRNN 注重于利用后续时刻的输入知识。当然 RNN 很多的改进版本,如 GRU 或 深层 RNN 网络。RNN 的提出给语音识别、机器翻译,文本生成等时序型数据的处理提供了基石方法,期待未来会在此基础上提出更厉害的网络模型。


附录

在 Pytorch 中定义一个 RNN 方式如下,注意如果没有给 RNN 喂 h0 变量的话,那么 RNN 被会自动喂一个全 0 的向量作为 h0。

import torch
import torch.nn as nn
from torch.autograd import Variable

toy_input = Variable(torch.rand(100, 32, 20))  # t 时刻的输入,(seq, batch_size, input_size)
h0 = Variable(torch.rand(2, 32, 50))  # 0 时刻的状态输入,(layer_num*direction_num, batch_size, hidden_size)
basic_rnn = nn.RNN(input_size=20,  # 输入 xt 的维度
                   hidden_size=50,  # 隐藏层 h 的维度
                   bidirectional=False,  # 是否是双向
                   num_layers=2,
                   batch_first=False)

rnn_out, rnn_hn = basic_rnn(toy_input, h0)
print('RNN Network ===>', basic_rnn)
# RNN 输出的尺寸 (seq, bath_size, hidden_size), 即 (100, 32, 50)
print('Output:', rnn_out.size())  
# RNN 状态的尺寸 (layer_num*direction_num, batch_size, hidden_size), 即 (2, 32, 50)
print('Hidden state:', rnn_hn.size())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

定义 LSTM 和 RNN 稍微有些不同,因为 LSTM 的输入状态除了 h0 之外中多了一个 c0。

toy_input = Variable(torch.rand(100, 32, 20))  # t 时刻的输入,(seq, batch_size, input_size)
h0 = Variable(torch.rand(2, 32, 50))  # 0 时刻的状态 h 输入,(layer_num*direction_num, batch_size, hidden_size)
c0 = Variable(torch.rand(2, 32, 50))  # 0 时刻的状态 c 输入,(layer_num*direction_num, batch_size, hidden_size)
basic_lstm = nn.LSTM(input_size=20,
                     hidden_size=50,
                     bidirectional=False,
                     num_layers=2,
                     batch_first=False)

lstm_out, (lstm_hn, lstm_cn) = basic_lstm(toy_input, (h0, c0))  # LSTM 比 RNN 的输入/输出多了状态 C0
print('LSTM Network ===>', basic_lstm)
# LSTM 输出的尺寸 (seq, bath_size, hidden_size), 即 (100, 32, 50)
print('Output:', lstm_out.size())
# LSTM 状态的尺寸 (layer_num*direction_num, batch_size, hidden_size), 即 (2, 32, 50)
print('Hidden state:', lstm_hn.size(), lstm_cn.size())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

定义了基本的 RNN 或者 LSTM(类似的还有 GRU)之后,我们可以利用他们构建一个稍微复杂的分类网络。该网络首先由 LSTM 挖掘出特征,再由 LR 模型进行分类,其定义如下,

class RNNetwork(nn.Module):
    """搭建一个基于 RNN 的神经网络"""
    def __init__(self, input_size, hidden_size, class_num):
        super(RNNetwork, self).__init__()
        # RNN 网络提取特征
        self.feature = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=2,
                              bidirectional=False, batch_first=False)
        # LR 网络进行分类
        self.classifier = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=120),
            nn.ReLU(),
            nn.Linear(in_features=120, out_features=class_num),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        out, _ = self.feature(x)  # out = (seq_num, batch_size, hidden_size)
        out = self.classifier(out[-1, :, :])
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

  1. 习翔宇. “RNN/LSTM BPTT详细推导以及梯度消失问题分析.” Link ↩︎

  2. Hiroki. “BPTT 算法推导.” Link ↩︎

  3. 知乎. “LSTM 如何来避免梯度消失和梯度爆炸的?” Link ↩︎

  4. Sepp Hochreiter, Jürgen Schmidhuber. “Long Short-term Memory.” Link ↩︎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/360291
推荐阅读
相关标签
  

闽ICP备14008679号