当前位置:   article > 正文

【深度学习8】基于Pytorch实现门控循环单元(GRU)【附源代码】_gru pytorch 代码

gru pytorch 代码

目录

1. 概述

2. 更新门和重置门

3. 候选隐状态

4. 隐状态更新

5. GRU网络架构

6. 源代码

7. 参考资料


1. 概述

        本文是作者自学深度学习的第9篇章(学习资料为李沐老师的深度学习课程),主要对GRU模型中的一些基本概念,如更新门重置门以及GRU的网络架构进行了整理和归纳。

2. 更新门和重置门

        在GRU模型中引入了更新门重置门两个控制单元,其中,重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。

        值得注意的是,重置门使用了sigmoid作为激活函数将输出映射到 [0,1] 实现 “软” 控制。具体而言,GRU使用下式来计算两个门:

R_t=\sigma (X_tW_{xr}+H_{t-1}W_{hr}+b_r)

Z_t=\sigma (X_tW_{xz}+H_{t-1}W_{hz}+b_z)

其中,R_t,Z_t 的大小和隐藏层单元大小一样,因此后续我们可以将其与隐藏层单元按元素相乘。 

3. 候选隐状态

        为了实现重置门的控制效果,我们定义了 “候选隐状态” \tilde{H_t} 这一概念,具体而言:

\tilde{H_t}=tanh(X_tW_{xh}+(R_t\bigodot H_t-1)W_{hh}+b_h)

其中,\bigodot 表示按元素相乘,可见,R_t 控制了候选隐状态被之前隐状态的影响程度,当R_t中的元素为0时,表示将隐状态“重置”。

4. 隐状态更新

        为了实现更新门的控制效果,我们使用下式对隐状态进行更新:

H_t=Z_t\bigodot H_{t-1}+(1-Z_t)\bigodot \tilde{H_t} 

可见若 Z_t  元素为1,表示当前隐状态为前1隐状态的副本,Z_t 元素为0, 表示当前原状态更新为候选隐藏状态。

5. GRU网络架构

        GRU网络的典型架构如下:

其计算步骤为:

  1. 计算重置门 R_t 和更新门 Z_t
  2. 利用重置门 R_t 计算候选隐状态
  3. 利用更新门 Z_t 对隐状态进行更新

6. 源代码

基于Pytorch实现GRU模型

7. 参考资料

李沐带你学深度学习【GRU】

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/412648
推荐阅读
相关标签
  

闽ICP备14008679号