赞
踩
一、GRU 概述
GRU 是 LSTM 神经网络的一种效果很好的变体,GRU 保持了 LSTM 的效果同时又使得结构更加简单计算量更小,它较 LSTM 网络的结构更加简单,而且效果也更好,因此也是当前非常流行的一种神经网络。GRU 既然是 LSTM 的变体,因此也是可以解决 RNN 神经网络中的长依赖问题的。
GRU 把 LSTM 中的 forget gate 和 input gate 用 update gate 来替代,把 cell state 和 隐状态 进行合并,在计算当前时刻信息的方法和 LSTM 有所不同。
GRU 中只有两个门:
(1)reset gate(重置门):决定了如何将新的输入信息与前面的记忆相结合(控制前一状态有多少信息被写入到当前的候选集上),reset gate 的值越小说明丢弃的越多。
(2)update gate(更新门):定义了前面记忆保存到当前时间步的量(控制前一时刻的状态信息被被带入到当前状态中的程度),update gate 的值越大说明前一时刻的状态信息带入越多。
GRU 的结构如下图所示:
图中 和 分别代表 update gate 和 reset gate。
GRU 与 LSTM 的区别:
(1)GRU 有两个门(重置门与更新门),而 LSTM 有三个门(输入门、遗忘门和输出门)。
(2)GRU 并不会控制并保留内部记忆(),且没有 LSTM 中的输出门。
(3)LSTM 中的输入与遗忘门对应于 GRU 的更新门,重置门直接作用于前面的隐藏状态。
(4)在计算输出时并不应用二阶非线性。
GRU 原理图:
二、GRU 前向传播
GRU 前向传播公式:
其中,[] 表示两个向量相连,* 表示矩阵的乘积。
三、GRU 更新过程
1、update gate()
在时间步 ,我们首先需要使用以下公式计算更新门 :
其中 为第 个时间步的输入向量,即输入序列 的第 个分量,它会经过一个线性变换(与权重矩阵 相乘)。保存的是前一个时间步 的信息,它同样也会经过一个线性变换。更新门将这两部分信息相加并投入到 Sigmoid() 激活函数中,因此将激活结果压缩到 0 到 1 之间。以下是更新门在整个单元的位置与表示方法:
更新门帮助模型决定到底要将多少过去的信息传递到未来,或到底前一时间步和当前时间步的信息有多少是需要继续传递的。这一点非常强大,因为模型能决定从过去复制所有的信息以减少梯度消失的风险。
2、reset gate()
本质上来说,重置门主要决定了到底有多少过去的信息需要遗忘,我们可以使用以下表达式计算:
该表达式与更新门的表达式是一样的,只不过线性变换的参数和用处不一样而已。下图展示了该运算过程的表示方法:
前面更新门所述, 和 先经过一个线性变换,再相加投入 Sigmoid() 激活函数以输出激活值。
3、其次是计算候选隐藏层(candidate hidden layer),这个 候选隐藏层 和 LSTM 中的 类似,可以看成是当前时刻的新信息,其中 用来控制需要保留多少之前的记忆,如果 为 0,那么 只包含当前词的信息:
输入 与上一时间步信息 先经过一个线性变换,即分别右乘矩阵 。
计算重置门 与 的 Hadamard 乘积,即 与 的对应元素乘积。因为前面计算的重置门是一个由 0 到 1 组成的向量,它会衡量门控开启的大小。例如某个元素对应的门控值为 0,那么它就代表这个元素的信息完全被遗忘掉。该 Hadamard 乘积将确定所要保留与遗忘的以前信息。该计算过程可表示为:
4、最后 控制需要从前一时刻的隐藏层 中遗忘多少信息,需要加入多少当前时刻的隐藏层信息 ,最后得到 ,直接得到最后输出的隐藏层信息,这里与 LSTM 的区别是 GRU 中没有 output gate:
为更新门的激活结果,它同样以门控的形式控制了信息的流入。 与 的 Hadamard 乘积表示前一时间步保留到最终记忆的信息,该信息加上当前记忆保留至最终记忆的信息就等于最终门控循环单元输出的内容。该计算过程可表示为:
如果 reset gate 接近 0,那么之前的隐藏层信息就会丢弃,允许模型丢弃一些和未来无关的信息。
update gate 控制当前时刻的隐藏层输出 需要保留多少之前的隐藏层信息,若 接近 1 相当于我们把之前的隐藏层信息拷贝到当前时刻,可以学习长距离依赖。
一般来说,那些具有短距离依赖的 unit(单元) reset gate 比较活跃(如果 为 1,而 为 0,那么相当于变成了一个标准的 RNN,能处理短距离依赖),具有长距离依赖的 unit(单元) update gate 比较活跃。
四、GRU 训练过程
从前向传播公式可以看出要学习的参数有 、、、。其中前三个参数都是拼接的(因为后面的向量也是拼接的),所以在训练过程中需要将它们分割出来:
输出层的输入:
输出层的输出:
在得到最终的输出后,就可以写出网络传递的损失,单个样本某时刻的损失为:
则单个样本在所有时刻的损失为:
采用后向误差传播算法来学习网络,求损失函数对各参数的偏导(共7个):
其中各中间参数为:
在计算出了对各参数的偏导之后,就可以更新参数,依次迭代直到损失收敛。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。