当前位置:   article > 正文

循环神经网络(RNN)之门控循环单元(GRU)_门控循环单元(gru)是那一篇论文中提出来的

门控循环单元(gru)是那一篇论文中提出来的

        在实现门控循环单元的循环神经网络之前,可以先熟悉论文:Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling
另外一篇关于RNN编码-解码的论文,大家有兴趣也可以看下,其中一位是来自“深度学习三巨头”之一的约书亚·本吉奥:
Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
        我们在上一篇文章:RNN模型参数与变量的依赖以及时间反向传播梯度的推导,熟悉了RNN的一些相关知识,我们发现当时间步较大的时候容易出现梯度爆炸,这个我们用了裁剪梯度来应对,但对于时间步较小时,无法应对梯度出现的衰减问题。通常因为这样的原因,就较难捕捉到时间序列中时间步距离较大的依赖关系。

        于是就提出了带门控的循环单元的RNN,通过学习门来控制信息的流动。在一个门控循环单元中,包含着重置门和更新门,两者有什么特点和优势,如下:

1、重置门(reset gate):当重置门中元素值接近1,就保留上一时间步的隐藏状态,如果元素值接近0时,将忽略上一步的隐藏状态,仅使用当前输入进行复位,有效地丢弃与预测无关的历史信息,从而允许更紧凑的表示。有助于捕捉时间序列里短期的依赖关系。
2、更新门(update gate):控制从上一步的隐藏状态有多少信息传递到当前的隐藏状态。有助于捕捉时间序列里长期的依赖关系。
除了上述两个控制门之外,还有一个候选隐藏状态(candidate hidden state),这个设计主要是应对梯度衰减了,因为它可以保存较早时刻的隐藏状态一直通过时间保存并传递到当前的时间步来。
门控循环单元最终的计算输出是来自上一步的隐藏状态与更新门按元素乘法,再跟当前时间的候选隐藏状态做组合(相加)。

对于这样一个门控循环单元,在流程中如何进行计算的,我个人还是比较喜欢使用画图来直观表现,如下:

 代码即解释,我们来实现它:

  1. import d2lzh as d2l
  2. from mxnet import nd
  3. from mxnet.gluon import rnn
  4. (corpus_indices,char_to_idx,idx_to_char,vocab_size)=d2l.load_data_jay_lyrics()
  5. num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size
  6. ctx=d2l.try_gpu()
  7. #ctx=None
  8. def get_params():
  9. def _one(shape):
  10. return nd.random.normal(scale=0.01,shape=shape,ctx=ctx)
  11. def _three():
  12. return (_one((num_inputs,num_hiddens)),_one((num_hiddens,num_hiddens)),nd.zeros(num_hiddens,ctx=ctx))
  13. W_xz,W_hz,b_z=_three()#更新门参数
  14. W_xr,W_hr,b_r=_three()#重置门参数
  15. W_xh,W_hh,b_h=_three()#候选隐藏状态参数
  16. #输出层参数
  17. W_hq=_one((num_hiddens,num_outputs))
  18. b_q=nd.zeros(num_outputs,ctx=ctx)
  19. #附上梯度
  20. params=[W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q]
  21. for param in params:
  22. param.attach_grad()
  23. return params
  24. #定义模型
  25. #隐藏状态初始化函数
  26. def init_gru_state(batch_size,num_hiddens,ctx):
  27. return (nd.zeros(shape=(batch_size,num_hiddens),ctx=ctx),)
  28. def gru(inputs,state,params):
  29. W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q=params
  30. H,=state
  31. outputs=[]
  32. for X in inputs:
  33. Z=nd.sigmoid(nd.dot(X,W_xz)+nd.dot(H,W_hz)+b_z)
  34. R=nd.sigmoid(nd.dot(X,W_xr)+nd.dot(H,W_hr)+b_r)
  35. H_tilda=nd.tanh(nd.dot(X,W_xh)+nd.dot(R*H,W_hh)+b_h)
  36. H=Z*H+(1-Z)*H_tilda
  37. Y=nd.dot(H,W_hq)+b_q
  38. outputs.append(Y)
  39. return outputs,(H,)
  40. #训练模型(相邻采样)
  41. num_epochs,num_steps,batch_size,lr,clipping_theta=200,35,32,1e2,1e-2
  42. pred_period,pred_len,prefixes=40,50,['分开','不分开']
  43. #d2l.train_and_predict_rnn(gru,get_params,init_gru_state,num_hiddens,vocab_size,ctx,corpus_indices,idx_to_char,char_to_idx,False,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)
  44. #简洁实现
  45. gru_layer=rnn.GRU(num_hiddens)
  46. model=d2l.RNNModel(gru_layer,vocab_size)
  47. d2l.train_and_predict_rnn_gluon(model,num_hiddens,vocab_size,ctx,corpus_indices,idx_to_char,char_to_idx,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)

epoch 40, perplexity 155.968500, time 0.12 sec
 - 分开 我不的让我 我不的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想
 - 不分开 我不的让我 我不的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想
epoch 80, perplexity 34.249728, time 0.12 sec
 - 分开 我想要你的微笑 一定在美不人 你的让我有多多 爱你在我不多 我爱你的爱笑 让我想这样 我不要再想你
 - 不分开 我想要你的微笑 一定在美不人 你的让我有多多 爱你在我不多 我爱你的爱笑 让我想这样 我不要再想你
epoch 120, perplexity 5.254933, time 0.12 sec
 - 分开 我想带这样打 但知后觉 你已经离不舍 不知不觉 我跟了这节奏 后知后觉 我该好好生活 我该好好生活
 - 不分开 我已经这样奏 后知后觉 我该好好生活 我该好好生活 不知不觉 我跟了这节奏 后知后觉 我该好好生活
epoch 160, perplexity 1.513565, time 0.11 sec
 - 分开 我想轻这里 我妈著好恼 我后 这样 我不要再想要你 不知不觉 你已经离开我 不知不觉 我跟了这节奏
 - 不分开 我已天这样奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活 不知不觉 你已经离开我
epoch 200, perplexity 1.072794, time 0.12 sec
 - 分开 让弄堂的太快否听的见 它一定实现它一定实现 载著你 彷彿载著阳光 不管到哪里都是晴天 蝴蝶自在飞
 - 不分开 我已 这样的玩奏就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/750217
推荐阅读
相关标签
  

闽ICP备14008679号