当前位置:   article > 正文

门控循环单元(GRU)

门控循环单元(gru)

1. 什么是GRU

在循环神经⽹络中的梯度计算⽅法中,我们发现,当时间步数较⼤或者时间步较小时,**循环神经⽹络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但⽆法解决梯度衰减的问题。**通常由于这个原因,循环神经⽹络在实际中较难捕捉时间序列中时间步距离较⼤的依赖关系。

**门控循环神经⽹络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较⼤的依赖关系。**它通过可以学习的⻔来控制信息的流动。其中,门控循环单元(gatedrecurrent unit,GRU)是⼀种常⽤的门控循环神经⽹络。

2. ⻔控循环单元

2.1 重置门和更新门

GRU它引⼊了**重置⻔(reset gate)和更新⻔(update gate)**的概念,从而修改了循环神经⽹络中隐藏状态的计算⽅式。

门控循环单元中的重置⻔和更新⻔的输⼊均为当前时间步输⼊ 与上⼀时间步隐藏状态,输出由激活函数为sigmoid函数的全连接层计算得到。 如下图所示:

 

具体来说,假设隐藏单元个数为 h,给定时间步 t 的小批量输⼊ (样本数为n,输⼊个数为d)和上⼀时间步隐藏状态 。重置⻔ 和更新⻔ 的计算如下:

 

 

sigmoid函数可以将元素的值变换到0和1之间。因此,重置⻔ 和更新⻔ 中每个元素的值域都是[0*,* 1]。

2.2 候选隐藏状态

接下来,⻔控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。我们将当前时间步重置⻔的输出与上⼀时间步隐藏状态做按元素乘法(符号为)。如果重置⻔中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上⼀时间步的隐藏状态。如果元素值接近1,那么表⽰保留上⼀时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输⼊连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为[-1,1]。

具体来说,时间步 t 的候选隐藏状态 的计算为:

 

从上⾯这个公式可以看出,重置⻔控制了上⼀时间步的隐藏状态如何流⼊当前时间步的候选隐藏状态。而上⼀时间步的隐藏状态可能包含了时间序列截⾄上⼀时间步的全部历史信息。因此,重置⻔可以⽤来丢弃与预测⽆关的历史信息。

2.3 隐藏状态

最后,时间步t的隐藏状态 的计算使⽤当前时间步的更新⻔ 来对上⼀时间步的隐藏状态 和当前时间步的候选隐藏状态 做组合:

 

值得注意的是,**更新⻔可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,**如上图所⽰。假设更新⻔在时间步之间⼀直近似1。那么,在时间步间的输⼊信息⼏乎没有流⼊时间步 t 的隐藏状态 实际上,这可以看作是较早时刻的隐藏状态 直通过时间保存并传递⾄当前时间步 t。这个设计可以应对循环神经⽹络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较⼤的依赖关系。

 

我们对⻔控循环单元的设计稍作总结:

  • 重置⻔有助于捕捉时间序列⾥短期的依赖关系;
  • 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。

3. 代码实现GRU(tensorflow实现)

  1. import tensorflow as tf
  2. from tensorflow.contrib import rnn
  3. old_v = tf.compat.v1.logging.get_verbosity()
  4. tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
  5. # 导入 MINST 数据集
  6. from tensorflow.examples.tutorials.mnist import input_data
  7. mnist = input_data.read_data_sets("data/", one_hot=True)
  8. tf.compat.v1.logging.set_verbosity(old_v)
  1. batch_size = 100
  2. time_step =28 # 时间步(每个时间步处理图像的一行)
  3. data_length = 28 # 每个时间步输入数据的长度(这里就是图像的宽度)
  4. learning_rate = 0.01
  1. # 定义占位符
  2. X_ = tf.placeholder(tf.float32, [None, 784])
  3. Y_ = tf.placeholder(tf.int32, [None, 10])
  4. # dynamic_rnn的输入数据(batch_size, max_time, ...)
  5. inputs = tf.reshape(X_, [-1, time_step, data_length])
  6. # 验证集
  7. validate_data = {X_: mnist.validation.images, Y_: mnist.validation.labels}
  8. # 测试集
  9. test_data = {X_: mnist.test.images, Y_: mnist.test.labels}
  10. # 定义一个两层的GRU模型
  11. gru_layers = rnn.MultiRNNCell([rnn.GRUCell(num_units=num) for num in [100, 100]], state_is_tuple=True)
  12. outputs, h_ = tf.nn.dynamic_rnn(gru_layers, inputs, dtype=tf.float32)
  13. output = tf.layers.dense(outputs[:, -1, :], 10) #获取GRU网络的最后输出状态
  14. # 定义交叉熵损失函数和优化器
  15. loss = tf.losses.softmax_cross_entropy(onehot_labels=Y_, logits=output)
  16. train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
  17. # 计算准确率
  18. accuracy = tf.metrics.accuracy(labels=tf.argmax(Y_, axis=1), predictions=tf.argmax(output, axis=1))[1]
  19. ## 初始化变量
  20. sess = tf.Session()
  21. init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  22. sess.run(init)
  23. for step in range(3000):
  24. # 获取一个batch的训练数据
  25. train_x, train_y = mnist.train.next_batch(batch_size)
  26. _, loss_ = sess.run([train_op, loss], {X_: train_x, Y_: train_y})
  27. # 在验证集上计算准确率
  28. if step % 100 == 0:
  29. val_acc = sess.run(accuracy, feed_dict=validate_data)
  30. print('step:', step,'train loss: %.4f' % loss_, '| val accuracy: %.2f' % val_acc)
  31. ## 计算测试集史上的准确率
  32. test_acc = sess.run(accuracy, feed_dict=test_data)
  33. print('test loss: %.4f' % test_acc)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/449837
推荐阅读
相关标签
  

闽ICP备14008679号