当前位置:   article > 正文

用于序列建模的深度学习:门控循环单元 (GRU)_门控循环单元网络gru

门控循环单元网络gru

概述

门控循环单元 (GRU) 是一种循环神经网络 (RNN) 架构类型。与其他 RNN 一样,GRU 可以处理时间序列、自然语言和语音等顺序数据。GRU 与其他 RNN 架构(例如长短期记忆 (LSTM) 网络)之间的主要区别在于网络如何处理随时间推移的信息流

先决条件

学习门控循环单元的先决条件是:

  • 神经网络和深度学习的基础知识。
  • 熟悉梯度下降和反向传播等概念。
  • 了解递归神经网络 (RNN)(链接到递归神经网络博客和神经网络博客)以及它们可能遭受的梯度消失问题。
  • 基本线性代数,主要是矩阵运算及其性质。
  • 熟悉python和Tensorflow、Keras、Pytorch等库的编程。

介绍

请看下面这句话:

“我妈妈在我生日那天送了我一辆自行车,因为她知道我想和朋友一起骑自行车。

从上面的句子中我们可以看出,相互影响的单词可以相距更远。例如,“自行车”和“骑自行车”密切相关,但在句子中相距甚远。

RNN 网络发现很难在如此长的上下文中跟踪状态。它需要找出哪些信息是重要的。然而,GRU单元大大缓解了这个问题。

GRU 网络发明于 2014 年。它解决了涉及长序列的问题,上下文相距更远,就像上面的自行车示例一样。这是可能的,因为 GRU 体系结构中的 GRU 单元是如何构建的。现在让我们更深入地了解GRU网络的理解和工作。

了解 GRU 单元

门控循环单元 (GRU) 单元是 GRU 网络的基本构建块。它由三个主要组件组成:更新门复位门候选隐藏状态

GRU单元的主要优点之一是其简单性。由于它的参数比长短期记忆 (LSTM) 单元少,因此训练和运行速度更快,并且不易出现过拟合

此外,要记住的一件事是,GRU单元的架构很简单,单元本身就是一个黑匣子,关于我们应该考虑多少过去状态以及应该忘记多少的最终决定是由这个GRU单元决定的。我们需要观察内部并了解细胞在想什么。

GRU 与 LSTM 比较

以下是门控循环单元 (GRU) 和长短期记忆 (LSTM) 网络的比较

GRULSTM
结构结构更简单,有两个闸门(更新和复位闸门)具有三个门(输入、忘记和输出门)的更复杂结构
参数更少的参数(3 个权重矩阵)更多参数(4 个权重矩阵)
训练训练速度更快训练缓慢
空间复杂性在大多数情况下,GRU 由于其结构更简单、参数更少,往往使用较少的内存资源,因此更适合大型数据集或序列。LSTM 具有更复杂的结构和更多的参数,因此可能需要更多的内存资源,并且对于大型数据集或序列可能不太有效。
性能在许多任务上通常与 LSTM 类似,但在某些情况下,GRU 已被证明优于 LSTM,反之亦然。最好同时尝试这两种方法,看看哪种方法更适合您的数据集和任务。LSTM 通常在许多任务上表现良好,但计算成本更高,需要更多的内存资源。LSTM 在自然语言理解和机器翻译任务方面比 GRU 更具优势。

GRU的架构

GRU 单元跟踪整个网络中维护的重要信息。GRU 网络通过以下两个门来实现此目的:

  • 复位门
  • 更新 Gate。

下面给出的是 GRU 单元的最简单架构形式。

如下图所示,GRU 单元接受两个输入:

  1. 以前的隐藏状态
  2. 当前时间戳中的输入。

该单元将这些组合在一起,并将它们传递到更新和复位门。为了获得当前时间步的输出,我们必须通过具有softmax激活的密集层传递此隐藏状态以预测输出。这样做,将获得新的隐藏状态,然后传递到下一个时间步骤。

更新门

更新门确定当前哪个 GRU 单元将信息传递到下一个 GRU 单元。它有助于跟踪最重要的信息

让我们看看如何在 GRU 单元中获得更新门的输出。更新门的输入是上一个时间步长  和电流输入。两者都有与之相关的权重,这些权重是在训练过程中学习的。假设与​是,以及​是。更新门的输出​由下式给出,

复位门

重置门识别不必要的信息,并决定从GRU网络中传递哪些信息。简单地说,它决定在特定时间戳删除哪些信息。

让我们看看如何在 GRU 单元中获得复位门的输出。复位门的输入是上一个时间步的隐藏层和电流输入。两者都有与之相关的权重,这些权重是在训练过程中学习的。假设与​是​,以及​是。更新门的输出​由下式给出,

PS的: 需要注意的是,对于两个门,在前一个时间步和当前输入中与隐藏层相关的权重是不同的。这些权重的值是在训练过程中学习的。

GRU是如何工作的?

门控循环单元 (GRU) 网络处理顺序数据,例如时间序列或自然语言,绕过从一个时间步到下一个时间步的隐藏状态。隐藏状态是一个向量,用于捕获与当前时间步相关的过去时间步长中的信息。GRU 背后的主要思想是允许网络决定上一个时间步的哪些信息与当前时间步相关,以及哪些信息可以丢弃。

候选隐藏状态

候选项的隐藏状态是从复位门计算得出的。这用于确定过去存储的信息。这通常称为 GRU 单元中的内存组件。它的计算公式是,

这里 W- 与电流输入相关的权重​- 复位门的输出 U- 与上一个时间步的隐藏层相关的权重​- 候选隐藏状态

隐藏状态

以下公式给出了新的隐藏状态,并取决于更新入口和候选隐藏状态。

这里​- 更新门 KaTeX 解析错误的输出:预期为“EOF”,在位置 2 处得到“”:h_t - 候选隐藏状态​- 上一个时间步的隐藏状态

正如我们在上面的公式中看到的,每当​为 0,则先前隐藏层的信息会被遗忘。它使用新的候选隐藏层的值进行更新(如1−​将是 1)。如果​为 1,则保留先前隐藏层的信息。这就是最相关的信息从一个状态传递到另一个状态的方式。

现在,我们已经掌握了理解 GRU 网络前向传播(即工作)的所有基础知识。事不宜迟,让我们开始吧。

GRU 信元中的前向传播

在门控循环单元 (GRU) 信元中,前向传播过程包括几个步骤:

  • 计算更新门(​) 使用更新门公式:

  • 计算复位门的输出(​) 使用复位门公式

  • 计算候选人的隐藏状态

  • 计算新的隐藏状态

这就是 GRU 网络的 GRU 信元中前向传播的方式。

我们有一个关于如何在 GRU 网络中学习权重以做出正确预测的问题。让我们在下一节中了解这一点。

GRU 信元中的反向传播

请看下面的图片。让每个隐藏层(橙色)代表一个 GRU 单元格。

在上图中,我们可以看到,每当网络预测错误时,网络都会将其与原始标签进行比较,然后损失会在整个网络中传播。这种情况一直持续到识别所有权重的值,以便用于计算损失的损失函数的值最小。在此期间,与隐藏层和输入相关的权重和偏差会进行微调。

让我们看看如何在下面的示例的帮助下在 GRU 网络中微调单个权重值。 让我们概括一个变量的概念;我们称之为 θ1。

让我们考虑一个参数的值​(theta)使一些任意成本函数最小化.

首先,让我们绘制成本函数作为​如下:

附言: 为了简单易用,我们只考虑了一个参数。这可以很容易地扩展到多维度(或多个参数)

我们为并绘制其对应的成本函数值。我们可以看到成本函数值相当高。

我们可以找到合适的值​使得成本函数的值最小化,如下所示:

  • 由于我们最初为​并发现成本函数值较高,求该点的斜率。
  • 如果斜率为正,我们可以从下图中看到,减小​将降低成本函数值。因此,继续降低少量。

  • 如果斜率为负,我们可以从下图中看出,增加 θ1 的值会减小成本函数值。因此,继续增加 θ1 的值。

让我们将上面的梯度下降概念(是的,这个概念称为梯度下降)总结为一个广义公式:

在这里,“α”表示学习率(即梯度下降步长应该有多大)。大多数情况下,“α”采用 0.01、0.001、0.0001 等值。

上面的公式适用于单个感兴趣的参数。对于两个参数,例如 θ1 和 θ2 的值如下所示:

重复直到收敛 {

(同时更新= 0和 j = 1)

}

正如我们所看到的,首先,您可以任意选择 θ1 和 θ2 的值。然后,您找到此时函数的斜率,并更新公式中 and 的值,并执行相同的操作,直到该值减少/增加很少(这是因为当参数接近其最佳值(局部/全局最小值)时,它们会采取非常小的步骤。因此,短语“重复直到收敛”)。

通过遵循上述方法,我们可以找到最小化成本函数所需的参数值数量。反过来,这将帮助我们找到权重和偏差的最佳值,并以最小的损失做出良好的预测。需要注意的是,当前 GRU 中的权重会根据下一个 GRU 单元进行更新。

在 Python 中实现 GRU

让我们在IMDB数据集上实现GRU网络。

1) 加载所需的库

  1. from keras.datasets import imdb
  2. from keras.models import Sequential
  3. from keras.layers import Dense,GRU,Flatten,LSTM
  4. from keras.layers import Embedding
  5. from keras.utils import pad_sequences

2)下载数据集,每条评论仅获得500字

  1. #Keeping only the top n words
  2. word_count = 5000
  3. (x_train, y_train), (x_test,y_test) = imdb.load_data(num_words = word_count)
  4. #Taking the top 500 words
  5. word_max = 500
  6. x_train = pad_sequences(x_train, maxlen=word_max)
  7. x_test = pad_sequences(x_test, maxlen=word_max)

3) 构建模型

  1. model = Sequential()
  2. model.add(Embedding(word_count,100,input_length=word_max))
  3. model.add(GRU(100))
  4. model.add(Dense(1,activation='sigmoid'))
  5. model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  6. print(model.summary())

输出

  1. Model: "sequential"
  2. _________________________________________________________________
  3. Layer (type) Output Shape Param #
  4. =================================================================
  5. embedding (Embedding) (None, 500, 100) 500000
  6. gru (GRU) (None, 100) 60600
  7. dense (Dense) (None, 1) 101
  8. =================================================================
  9. Total params: 560,701
  10. Trainable params: 560,701
  11. Non-trainable params: 0
  12. _________________________________________________________________
  13. None

4) 模型拟合

  1. model.fit(x_train,y_train, epochs=3, batch_size=64)

输出

  1. Epoch 1/3
  2. 391/391 [==============================] - 13s 25ms/step - loss: 0.4977 - accuracy: 0.7450
  3. Epoch 2/3
  4. 391/391 [==============================] - 8s 21ms/step - loss: 0.2910 - accuracy: 0.8800
  5. Epoch 3/3
  6. 391/391 [==============================] - 8s 21ms/step - loss: 0.2206 - accuracy: 0.9146

5)预测每个时间步的下一个单词

  1. y_predict = model.predict(x_test)
  2. print(y_predict)

输出

  1. [[0.07935733]
  2. [0.9810135 ]
  3. [0.5182 ]
  4. ...
  5. [0.08272645]
  6. [0.15271361]
  7. [0.9131225 ]]

6)寻找模型的准确性

  1. result = model.evaluate(x_test,y_test, verbose=0)
  2. print("Accuracy = %.2f%%" % (result[1]*100))

输出

Accuracy = 88.39%

结论

  • GRU 网络是 RNN 网络的修改版。GRU单元由两个门组成:更新门和复位门。
  • 更新门确定哪个模型将信息传递到下一个 GRU 单元。重置门识别不必要的信息,并决定从GRU网络中传递哪些信息。
  • 候选项的隐藏状态是从复位门计算得出的。这用于确定过去存储的信息。这通常称为 GRU 单元中的内存组件。
  • 使用复位门、更新门和候选隐藏状态的输出计算新的隐藏状态。与各种门、输入和隐藏状态相关的权重是通过称为反向传播的概念学习的。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/449875
推荐阅读
相关标签
  

闽ICP备14008679号