赞
踩
循环神经网络(Recurrent Neural Network, RNN)是一种处理时序型输入的神经网络。它被广泛应用在语音识别、机器翻译、人名识别、文本生成等任务上。这些任务要处理的都是时序型的数据,一言以蔽之,这些时序型数据有着 输入是不定长度的,输入的上下文是由关联的 的特征。
经典的深度神经网络(NN)通过构建多层全连接的神经网络来对固定的多维输入进行预测;卷积神经网络(CNN)通过设计不同的卷积和池化层组合来对图像(网格化输入)进行预测。但在实际应用中,这两种神经网络均很难处理时序型数据,因为:
这两个缺点恰好妨碍了 NN 和 CNN 对于时序型数据的处理。于是,一种更新的能够处理不同维度输入的、且能利用前后顺序关系的网络结构就问世了,它就是循环神经网络 RNN。换句话说,RNN 存在的理由就是为了处理时序型数据。
以机器翻译为例,句子中单词的翻译跟语境十分相关,也可以认为要翻译的单词含义与之前出现的单词十分相关,要想预测当前词的含义就必须 “记住” 之前出现过的词。我们以一个简单的语句 “John likes to drink the spring water.” 为例。通过上下文(如 “drink”),人类马上就知道这里的 “spring” 指的是 “泉水” 的意思(而非 “春天”),但是要让机器理解上下文则相对困难。
我们将每个单词视作一个输入,即这 7 个单词分别对应于
x
<
1
>
,
x
<
2
>
,
.
.
.
,
x
<
7
>
x^{<1>},x^{<2>},...,x^{<7>}
x<1>,x<2>,...,x<7>,同时,记句子的输入长度为
T
x
T_x
Tx,即
T
x
=
7
T_x=7
Tx=7。那么这 7 个单词的输出分别对应于
y
^
<
1
>
,
y
^
<
2
>
,
.
.
.
,
y
^
<
7
>
\hat{y}^{<1>},\hat{y}^{<2>},...,\hat{y}^{<7>}
y^<1>,y^<2>,...,y^<7>,输出的总长度为
T
y
T_y
Ty,实际情况下允许
T
x
≠
T
y
T_x \neq T_y
Tx=Ty。同时,每个输入都有一个实时状态,记为
a
<
i
>
,
i
=
0
,
1
,
2
,
3
,
4
,
5
a^{<i>}, i=0,1,2,3,4,5
a<i>,i=0,1,2,3,4,5。这个实时状态其实就是 RNN 要记下来的前文知识。
上图展示了 RNN 两种不同的表达方式,图 (a) 常见于各种论文中,理解起来较难,将其横向展开就得到了(b);图 (b) 中的模式更适合大部分人的理解,它表示了 RNN 中信息的流动方式。
给定初始状态 a < 0 > a^{<0>} a<0> ,RNN 将依次从左往右传播,且当前状态 a < i > a^{<i>} a<i> 只与当前输入 x < i > x^{<i>} x<i> 和前个时刻的状态 a < i − 1 > a^{<i-1>} a<i−1> 相关。如果当前 RNN 节点有输出值 y ^ < i > \hat{y}^{<i>} y^<i>,则该输出值与当前的状态 a < i > a^{<i>} a<i> 相关。具体的节点结构和计算方式见下节。
这里值得一提的是,在 RNN 中我们没有强制规定输入维度
T
x
T_x
Tx 一定等于输出维度
T
y
T_y
Ty,事实上根据处理问题的不同,二者存在多种关系,如下图所示的 many-to-one(分类), one-to-many(文本生成), many-to-many (机器翻译,摘要生成)等不同模式。
下图展示了 RNN 第 t 时刻的节点的输入输出情况。为了更好的计算当前状态值,RNN 设定全局共享的参数
W
a
a
W_{aa}
Waa 和
W
a
x
W_{ax}
Wax(即所有的 RNN 节点参数值均相同),通过激活函数
g
1
g_1
g1 来计算出第 t 时刻的状态值
a
<
t
>
a^{<t>}
a<t>,通常这里的
g
1
g_1
g1 函数会被设置为 ReLU 或者 Tanh 函数。另外,计算
y
^
<
t
>
\hat{y}^{<t>}
y^<t> 时使用的
g
2
g2
g2 函数根据预测任务而定,例如在分类情况下使用 Softmax 函数。
实际的计算中为了节省计算量,在函数
g
1
g_1
g1 中,我们将
a
<
t
−
1
>
a^{<t-1>}
a<t−1> 和
x
<
t
>
x^{<t>}
x<t> 做一个垂直方向的拼接,这样参数
W
a
a
W_{aa}
Waa 和
W
a
x
W_{ax}
Wax 就可以整合成一个参数
W
a
W_a
Wa 来计算了。举个简单的例子,假设我们的输入
x
<
t
>
x^{<t>}
x<t> 都是 1000 维的,状态
a
<
t
>
a^{<t>}
a<t> 都是 100 维,那么
W
a
a
W_{aa}
Waa 应该是 100x100 的尺寸,
W
a
x
W_{ax}
Wax 则是 100x1000 的尺寸,拼接起来的
W
a
W_a
Wa 则是 100x1100 尺寸的。先假设
W
y
W_y
Wy 的尺寸是 64x100 ,那么最终的
y
^
<
t
>
\hat{y}^{<t>}
y^<t> 的尺寸就是 64x1。这里的 “64” 就是 RNN 节点隐藏层的节点个数,它决定了输出的尺寸大小。
RNN 与之前标准神经网络类似,依旧使用链式法则来反向更新参数。不同的是由于 RNN 每个节点都公用参数 W a , W y W_a, W_y Wa,Wy,因此求参数偏导时需要累加之前的偏导数的值,RNN 的更新方式也被称为 BPTT 算法。
为了书写方便,我们将上一小节的 g 2 g_2 g2 函数暂时省略。同时确定 RNN 的损失函数 L ( y ^ k , y k ) L(\hat{y}_{k}, y_k) L(y^k,yk) 如下均方差形式(假设 RNN 只有最后 T 时刻的输出,即 many-to-one),
L ( y ^ , y ) = 1 2 ( y ^ < T > − y < T > ) 2 L(\hat{y}, y) = \frac{1}{2} (\hat{y}^{<T>} - y^{<T>})^2 L(y^,y)=21(y^<T>−y<T>)2
其中 y ^ < T > \hat{y}^{<T>} y^<T> 表示样本在时刻 T T T 的实际输出, y < T > y^{<T>} y<T> 表示样本在时刻 T T T 的预期输出。当 t ∈ [ 1 , T ] t \in [1,T] t∈[1,T] 时,前向传播过程可以记为,
a < t > = g 1 ( W a a a < t − 1 > + W a x x < t > + b x ) a^{<t>} = \bm{g_1}(W_{aa} \ a^{<t-1>} + W_{ax} \ x^{<t>} + b_x) a<t>=g1(Waa a<t−1>+Wax x<t>+bx)
y ^ < t > = W y a < t > + b y \hat{y}^{<t>} = W_y \ a^{<t>} + b_y y^<t>=Wy a<t>+by
通过上述两式可知,当我们使用 BPTT 算法对
W
a
a
W_{aa}
Waa 或
W
a
x
W_{ax}
Wax 求导时,不但要考虑
t
t
t 时刻的情况,还要考虑
t
t
t 时刻以前的状况,因为
W
a
a
W_{aa}
Waa 或
W
a
x
W_{ax}
Wax 影响了前面时刻的状态
a
<
t
>
a^{<t>}
a<t>。于是乎,求这2个参数的偏导的过程是一个累加的过程。如下图的红色箭头所指。
首先,对
W
y
W_{y}
Wy 求偏导。由于
W
y
W_{y}
Wy 只与
y
^
<
T
>
\hat{y}^{<T>}
y^<T> 相关,因此其偏导可写为,
∂ L ( y ^ , y ) ∂ W y = ∂ L ∂ y ^ < T > ∂ y ^ < T > W y \frac{\partial L(\hat y, y)}{\partial{W_y}} = \frac{\partial{L}}{\partial{\hat{y}^{<T>}}}\frac{\partial{\hat{y}^{<T>}}}{W_y} ∂Wy∂L(y^,y)=∂y^<T>∂LWy∂y^<T>
然后,对 W a x W_{ax} Wax 求偏导。由于 E k E_k Ek 是关于 T y T_y Ty 个输出 y ^ k < t > \hat{y}^{<t>}_k y^k<t> 的函数,因此是一个累加的形式,
∂ L ( y ^ , y ) ∂ W a x = ∂ L ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ W a x + ∂ L ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ a < T − 1 > ∂ a < T − 1 > ∂ W a x + . . . + ∂ L ∂ y ^ < T > ∂ y ^ k < T > ∂ a < T > ∂ a < T − 1 > ∂ a < T − 2 > . . ∂ a < 1 > ∂ W a x \frac{\partial L(\hat y, y)}{\partial W_{ax}} = \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}}{\partial a^{<T>} }\frac{\partial a^{<T>}}{\partial W_{ax}} + \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}}{\partial a^{<T>}} \frac{\partial a^{<T>}}{\partial a^{<T-1>}} \frac{\partial a^{<T-1>}}{\partial W_{ax}}+ \\ ... + \frac{\partial L}{\partial \hat{y}^{<T>}} \frac{\partial \hat{y}^{<T>}_k}{\partial a^{<T>}} \frac{\partial a^{<T-1>}}{\partial a^{<T-2>}}..\frac{\partial a^{<1>}}{\partial W_{ax}} ∂Wax∂L(y^,y)=∂y^<T>∂L∂a<T>∂y^<T>∂Wax∂a<T>+∂y^<T>∂L∂a<T>∂y^<T>∂a<T−1>∂a<T>∂Wax∂a<T−1>+...+∂y^<T>∂L∂a<T>∂y^k<T>∂a<T−2>∂a<T−1>..∂Wax∂a<1>
上式可化简为,
∂ L ( y ^ , y ) ∂ W a x = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∂ a < T > ∂ a < i > ∂ a < i > ∂ W a x = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > ∂ a < i > ∂ W a x \frac{\partial L(\hat y, y)}{\partial W_{ax}} = \sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\frac{\partial a^{<T>}}{\partial a^{<i>}}\frac{\partial a^{<i>}}{\partial W_{ax}} \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} \frac{\partial a^{<i>}}{\partial W_{ax}} ∂Wax∂L(y^,y)=i=1∑T∂y^<T>∂L(y^,y)∂a<T>∂y^<T>∂a<i>∂a<T>∂Wax∂a<i> =i=1∑T∂y^<T>∂L(y^,y)∂a<T>∂y^<T>k=i∏T−1∂a<k>∂a<k+1>∂Wax∂a<i>
更加细致的推导请参考博客1,2。同理,我们也可以利用链式求导法则求出 W a a W_{aa} Waa 的偏导数,
∂ L ( y ^ , y ) ∂ W a a = ∑ i = 1 T ∂ L ( y ^ , y ) ∂ y ^ < T > ∂ y ^ < T > ∂ a < T > ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > ∂ a < i > ∂ W a a \frac{\partial L(\hat y, y)}{\partial W_{aa}} =\sum_{i=1}^{T}\frac{\partial L(\hat y, y)}{\partial \hat y^{<T>}}\frac{\partial \hat y^{<T>}}{\partial a^{<T>}}\prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} \frac{\partial a^{<i>}}{\partial W_{aa}} ∂Waa∂L(y^,y)=i=1∑T∂y^<T>∂L(y^,y)∂a<T>∂y^<T>k=i∏T−1∂a<k>∂a<k+1>∂Waa∂a<i>
综上可知,求解 W a x W_{ax} Wax 和 W a a W_{aa} Waa 的梯度是复杂的累积之后,再累加的一个过程,其中连续的累积部分( ∏ k = i T − 1 ∂ a < k + 1 > ∂ a < k > \prod^{T-1}_{k=i}\frac{\partial a^{<k+1>}}{\partial a^{<k>}} ∏k=iT−1∂a<k>∂a<k+1>)会造成“梯度爆炸”和“梯度消失”问题。因为 RNN 使用的是 Tanh 函数为 g 1 g1 g1 函数,因此 ∂ a < k + 1 > ∂ a < k > \frac{\partial a^{<k+1>}}{\partial a^{<k>}} ∂a<k>∂a<k+1> 的值如果接近0,累积之后会更加接近于0,这就造成 RNN 的梯度只与靠近输出部分的梯度有很大关系,而越远的梯度则对总梯度影响不大。
“RNN 有梯度消失,LSTM解决了它” 可能是对 RNN 或者 LSTM 最经典的误解3。事实上,RNN 的 “梯度消失” 和传统的 NN 的 “梯度消失” 含义不同:
综上,虽然随着层数的增加,越远离输出层的梯度会减少,但是总的梯度和是不会消失的。因此 RNN 中不存在传统的梯度消失问题的!。因此,RNN 的弱点并不是梯度为趋近于0或消失,而是“健忘”,记不住较远距离对其的影响。
为了解决 RNN 训练过程中的 “健忘” 的问题,人们提出了 长短期记忆网络(Long short-term memory, LSTM4)的网络结构。LSTM 是 RNN 的一种变体,相比传统的 RNN 网络结构,LSTM 能够在更长的序列中有更好的表现。RNN 与 LSTM 的结构对比如下,可见 LSTM 肉眼可见的复杂了,每个节点有 2 个输入和 2 个输出。下图中的
h
t
h_t
ht 指的就是
t
t
t 时刻的输出,
x
t
x_t
xt 指的是
t
t
t 时刻的输入,
C
t
C_t
Ct 是
t
t
t 时刻的状态。
LSTM 结构最大的创新点在于它引入了 “门” 这个网络结构。具体来讲,“门” 结构相当于一个阈值控制机关,它控制了当前信息有多少比例可以通过。LSTM 设置有 “遗忘门”,“输入门”,“输出门” 3 个门结构,他们分别控制了之前状态,当前输入和当前状态有多少信息被保留下来。
遗忘门 Forget Gate: 遗忘门决定了前一时刻状态
C
t
−
1
C_{t-1}
Ct−1 有多少信息保留到当前状态
C
t
C_t
Ct。遗忘门的输入是前一时刻输出
h
t
−
1
h_{t-1}
ht−1 和当前输入
x
t
x_t
xt,其输出
f
t
f_t
ft 表示应该保留的比例,
σ
\sigma
σ 符号表示 Sigmoid 函数,其输出是 0 至 1 间的实数,数值越大表示信息保留的越多。
输入门 Input Gate: 输入门决定当前输入
x
t
x_t
xt 有多少信息输入到
C
t
C_t
Ct 状态中。这个步骤分为两步,首先通过
h
t
−
1
h_{t-1}
ht−1 和
x
t
x_t
xt 生成
i
t
i_t
it 和
C
~
t
\widetilde{C}_t
C
t,然后与前一时刻状态
C
t
−
1
C_{t-1}
Ct−1 保留的信息相加,从而得到当前时刻的状态
C
t
C_t
Ct。注意这里的
σ
\sigma
σ 函数的输出范围是 0 到 1,而
t
a
n
h
tanh
tanh 函数的输出范围是 -1 到 1。因此,
i
t
i_t
it 是一个0-1的实数,而
C
~
t
\widetilde{C}_t
C
t 则表示的是向量。
输出门 Output Gate: 输出门决定了当前时刻的输入
x
t
x_t
xt 有多少比例影响当前时刻的输出
h
t
h_t
ht。由图可知
h
t
h_t
ht 的值不但跟
h
t
−
1
h_{t-1}
ht−1 和
x
t
x_t
xt 有关,还与当前时刻的状态
C
t
C_t
Ct 有关。
根据上述的几个步骤,可以清晰的了解到 LSTM 节点(或细胞)中的信息流向。换个角度看,标准的 RNN 中,状态与输入是相同的,即
h
t
=
C
t
h_t = C_t
ht=Ct,而 LSTM 则是分开计算的;其实在实际应用中,人们往往会直接使用 LSTM 而非标准的 RNN。
回到本小节开头部分,LSTM 的提出是为了解决传统 RNN 存在的 “健忘” 的问题。BPTT 算法中那一段累积部分 ∂ a < t + 1 > ∂ a < t > \frac{\partial a^{<t+1>}}{\partial a^{<t>}} ∂a<t>∂a<t+1> 就是罪魁祸首。根据 LSTM 的网络结构,该段状态 C C C 的累积部分变为,
∂ C t + 1 ∂ C t = f t \frac{\partial C_{t+1}}{\partial C_{t}}=f_t ∂Ct∂Ct+1=ft
关于输出 h h h 的累积则略微复杂一点,
∂ h t + 1 ∂ h t = ∂ h t ∂ o t ∂ o t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ C ~ t ∂ C ~ t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ i t ∂ i t ∂ h t − 1 + ∂ h t ∂ C t ∂ C t ∂ f t ∂ f t ∂ h t − 1 \frac{\partial h_{t+1}}{\partial h_{t}} = \frac{\partial h_t}{\partial o_t}\frac{\partial o_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial \widetilde{C}_t}\frac{\partial \widetilde{C}_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}} + \frac{\partial h_t}{\partial C_t}\frac{\partial C_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}} ∂ht∂ht+1=∂ot∂ht∂ht−1∂ot+∂Ct∂ht∂C t∂Ct∂ht−1∂C t+∂Ct∂ht∂it∂Ct∂ht−1∂it+∂Ct∂ht∂ft∂Ct∂ht−1∂ft
(注:本小节的图片均来自 Christopher Olah 的博文 Understanding LSTM Networks. )
双向循环神经网络(Biodirectional Recurrent Neural Network,BRNN)也是 RNN 的一种变体,与传统 RNN 相比,BRNN 增加了后续时刻的输入对当前状态的影响。设想在人名识别的任务重有下属两个句子,
He said, “Teddy bears are on sale!”
He said, “Teddy Roosevelt was not a good President!”
(注:Teddy Roosevelt 西奥多·罗斯福(大罗斯福)是美国第 26 任总统。他的侄子 富兰克林·罗斯福 (小罗斯福)是美国第 32 任总统,并且是二战中重要的同盟国领袖之一,他成功连任 4 次美国总统!)
两个语句的开头都是一样的,但很明显第一个句子中的 Teddy 指的是泰迪熊而非人名,第二个句子的 Teddy 才指的是名字。由于 2 个句子开头都是 He said,故使用标准 RNN 则很难判断出来谁是人名。造成 RNN 瓶颈的原因正式 标准 RNN 不能够联系后文的内容。
为了解决 RNN 的这个瓶颈,人们于是提出了双向 RNN 网络,即 BRNN。在 BRNN 中 影响单个时刻输出的不但有之前的内容,并且也有后续的内容。一个标准 BRNN 的结构如下所示,
a
→
<
t
>
\overrightarrow{a}^{<t>}
a
<t> 表示正向传播时的
t
t
t 时刻状态,
a
←
<
t
>
\overleftarrow{a}^{<t>}
a
<t> 表示反向传播时
t
t
t 时刻的状态。
当我们计算
y
<
3
>
y^{<3>}
y<3> 的输出时,我们首先通过输入
x
<
1
>
,
x
<
2
>
x^{<1>}, x^{<2>}
x<1>,x<2> 得到 1 和 2 时刻的状态
a
→
<
1
>
,
a
→
<
2
>
\overrightarrow{a}^{<1>},\overrightarrow{a}^{<2>}
a
<1>,a
<2>,正向传播得到当前正向的状态
a
→
<
3
>
\overrightarrow{a}^{<3>}
a
<3>;再结合输入
x
<
4
>
x^{<4>}
x<4> 的输入得到 4 时刻的状态
a
←
<
4
>
\overleftarrow{a}^{<4>}
a
<4>,反向传播得到当前正向的状态
a
←
<
3
>
\overleftarrow{a}^{<3>}
a
<3>。最终
y
<
3
>
y^{<3>}
y<3> 是由
a
→
<
3
>
,
a
←
<
3
>
\overrightarrow{a}^{<3>}, \overleftarrow{a}^{<3>}
a
<3>,a
<3> 共同决定的,这就是 BRNN 的基本思想。
RNN 网络主要处理时序型数据,通过给不同时段的输入之间建立联系,RNN 生成最终的预测结果(many-to-one 或其他模式)。他具有一定的记忆功能,但同时记忆功能有待改进。它的后续改进版本 LSTM 注重于记忆更长时刻的知识,BRNN 注重于利用后续时刻的输入知识。当然 RNN 很多的改进版本,如 GRU 或 深层 RNN 网络。RNN 的提出给语音识别、机器翻译,文本生成等时序型数据的处理提供了基石方法,期待未来会在此基础上提出更厉害的网络模型。
在 Pytorch 中定义一个 RNN 方式如下,注意如果没有给 RNN 喂 h0 变量的话,那么 RNN 被会自动喂一个全 0 的向量作为 h0。
import torch import torch.nn as nn from torch.autograd import Variable toy_input = Variable(torch.rand(100, 32, 20)) # t 时刻的输入,(seq, batch_size, input_size) h0 = Variable(torch.rand(2, 32, 50)) # 0 时刻的状态输入,(layer_num*direction_num, batch_size, hidden_size) basic_rnn = nn.RNN(input_size=20, # 输入 xt 的维度 hidden_size=50, # 隐藏层 h 的维度 bidirectional=False, # 是否是双向 num_layers=2, batch_first=False) rnn_out, rnn_hn = basic_rnn(toy_input, h0) print('RNN Network ===>', basic_rnn) # RNN 输出的尺寸 (seq, bath_size, hidden_size), 即 (100, 32, 50) print('Output:', rnn_out.size()) # RNN 状态的尺寸 (layer_num*direction_num, batch_size, hidden_size), 即 (2, 32, 50) print('Hidden state:', rnn_hn.size())
定义 LSTM 和 RNN 稍微有些不同,因为 LSTM 的输入状态除了 h0 之外中多了一个 c0。
toy_input = Variable(torch.rand(100, 32, 20)) # t 时刻的输入,(seq, batch_size, input_size)
h0 = Variable(torch.rand(2, 32, 50)) # 0 时刻的状态 h 输入,(layer_num*direction_num, batch_size, hidden_size)
c0 = Variable(torch.rand(2, 32, 50)) # 0 时刻的状态 c 输入,(layer_num*direction_num, batch_size, hidden_size)
basic_lstm = nn.LSTM(input_size=20,
hidden_size=50,
bidirectional=False,
num_layers=2,
batch_first=False)
lstm_out, (lstm_hn, lstm_cn) = basic_lstm(toy_input, (h0, c0)) # LSTM 比 RNN 的输入/输出多了状态 C0
print('LSTM Network ===>', basic_lstm)
# LSTM 输出的尺寸 (seq, bath_size, hidden_size), 即 (100, 32, 50)
print('Output:', lstm_out.size())
# LSTM 状态的尺寸 (layer_num*direction_num, batch_size, hidden_size), 即 (2, 32, 50)
print('Hidden state:', lstm_hn.size(), lstm_cn.size())
定义了基本的 RNN 或者 LSTM(类似的还有 GRU)之后,我们可以利用他们构建一个稍微复杂的分类网络。该网络首先由 LSTM 挖掘出特征,再由 LR 模型进行分类,其定义如下,
class RNNetwork(nn.Module): """搭建一个基于 RNN 的神经网络""" def __init__(self, input_size, hidden_size, class_num): super(RNNetwork, self).__init__() # RNN 网络提取特征 self.feature = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=2, bidirectional=False, batch_first=False) # LR 网络进行分类 self.classifier = nn.Sequential( nn.Linear(in_features=hidden_size, out_features=120), nn.ReLU(), nn.Linear(in_features=120, out_features=class_num), nn.Softmax(dim=1) ) def forward(self, x): out, _ = self.feature(x) # out = (seq_num, batch_size, hidden_size) out = self.classifier(out[-1, :, :]) return out
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。