当前位置:   article > 正文

动手学深度学习(三十九)——门控循环单元GRU_门控循环单元(gru)

门控循环单元(gru)

再次声明:本文主要参考李沐老师B站动手学深度学习课程进行笔记整理和代码复现。如果需要看视频,可以click。感谢沐神的分享!!!

门控循环单元(GRU

  我们想想对于一个序列而言,有的早期观测值对所有的未来观测值都非常有用,有的观测值对所有的未来预测都没有用,或者说有的序列各个部分之间有逻辑中断。总结起来就是:

  • 并不是每个观测值都是同等重要
  • 想要只记住相关的观察需要:
    • 能关注的机制(更新门)
    • 能遗忘的机制(重置门)

  在学术界已经提出了许多方法来解决这个问题。其中最早的方法是"长-短期记忆" (long-short-term memory, LSMT):(Hochreiter.Schmidhuber.1997) 。门控循环单元(gated recurrent unit, GRU)(Cho.Van-Merrienboer.Bahdanau.ea.2014) 是一个稍微简化的变体,通常能够提供同等的效果,并且计算 (Chung.Gulcehre.Cho.ea.2014) 的速度明显更快。由于它更简单,就让我们从门控循环单元开始。

来都来了,把这几篇文章都贴出来吧:
【1】LSTM:Long Short-Term Memory
【2】GRU:Learning Phrase Representations using RNN Encoder–Decoderfor Statistical Machine Translation

一、门控隐藏状态

  普通的循环神经网络和门控循环单元之间的关键区别在于后者支持隐藏状态的门控(或者说选通)。这意味着有专门的机制来确定应该何时 更新 隐藏状态,以及应该何时 重置 隐藏状态。这些机制是可学习的,并且能够解决了上面列出的问题。

例如,如果第一个标记非常重要,我们将学会在第一次观测之后不更新隐藏状态。同样,我们也可以学会跳过不相关的临时观测。最后,我们还将学会在需要的时候重置隐藏状态。

1.1 重置门和更新门

  我们首先要介绍的是 重置门(reset gate)和 更新门(update gate)。我们把它们设计成 ( 0 , 1 ) (0, 1) (0,1) 区间中的向量,这样我们就可以进行凸组合。例如,重置门允许我们控制可能还想记住的过去状态的数量。同样,更新门将允许我们控制新状态中有多少个是旧状态的副本。

  我们从构造这些门控开始。下图描述了门控循环单元中的重置门和更新门的输入,输入是由当前时间步的输入和前一时间步的隐藏状态给出。两个门的输出是由使用 sigmoid 激活函数的两个全连接层给出。

数学描述,对于给定的时间步 t t t,假设输入是一个小批量 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d (样本个数: n n n,输入个数: d d d),上一个时间步的隐藏状态是 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h(隐藏单元个数: h h h)。然后,重置门 R t ∈ R n × h \mathbf{R}_t \in \mathbb{R}^{n \times h} RtRn×h 和更新门 Z t ∈ R n × h \mathbf{Z}_t \in \mathbb{R}^{n \times h} ZtRn×h 的计算如下:

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) ,

Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),
Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),

其中 W x r , W x z ∈ R d × h \mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h \mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h} Whr,WhzRh×h 是权重参数, b r , b z ∈ R 1 × h \mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h} br,bzR1×h 是偏置参数。请注意,在求和过程中会触发广播机制(请参阅 :numref:subsec_broadcasting )。我们使用 sigmoid 函数(如 :numref:sec_mlp 中介绍的)将输入值转换到区间 ( 0 , 1 ) (0, 1) (0,1)

1.2候选隐藏状态

  接下来,让我们将重置门 R t \mathbf{R}_t Rt 与RNN中的常规隐状态更新机制集成,得到在时间步 t t t 的候选隐藏状态 H ~ t ∈ R n × h \tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h} H~tRn×h

H ~ t = tanh ⁡ ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h), H~t=tanh(XtWxh+(RtHt1)Whh+bh),

  其中 W x h ∈ R d × h \mathbf{W}_{xh} \in \mathbb{R}^{d \times h} WxhRd×h W h h ∈ R h × h \mathbf{W}_{hh} \in \mathbb{R}^{h \times h} WhhRh×h 是权重参数, b h ∈ R 1 × h \mathbf{b}_h \in \mathbb{R}^{1 \times h} bhR1×h 是偏置项,符号 ⊙ \odot 是哈达码乘积(按元素乘积)运算符。在这里,我们使用 tanh 非线性激活函数来确保候选隐藏状态中的值保持在区间 ( − 1 , 1 ) (-1, 1) (1,1) 中。

   计算的结果是 候选者(candidate),因为我们仍然需要结合更新门的操作。与基础的RNN相比 候选隐藏状态中的 R t \mathbf{R}_t Rt H t − 1 \mathbf{H}_{t-1} Ht1 的元素相乘可以减少以往状态的影响。每当重置门 R t \mathbf{R}_t Rt 中的项接近 1 1 1 时,我们恢复一个如基本RNN中的普通的循环神经网络。对于重置门 R t \mathbf{R}_t Rt 中所有接近 0 0 0 的项,候选隐藏状态是以 X t \mathbf{X}_t Xt 作为输入的多层感知机的结果。因此,任何预先存在的隐藏状态都会被 重置 为默认值。下图明了应用重置门之后的计算流程。

1.3 隐藏状态

  最后,我们需要结合更新门 Z t \mathbf{Z}_t Zt 的效果。这确定新的隐藏状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} HtRn×h 在多大程度上就是旧的状态 H t − 1 \mathbf{H}_{t-1} Ht1 ,以及对新的候选状态 H ~ t \tilde{\mathbf{H}}_t H~t 的使用量。更新门 Z t \mathbf{Z}_t Zt 仅需要在 H t − 1 \mathbf{H}_{t-1} Ht1 H ~ t \tilde{\mathbf{H}}_t H~t 之间进行按元素的凸组合就可以实现这个目标。这就得出了门控循环单元的最终更新公式:

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t . \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. Ht=ZtHt1+(1Zt)H~t.

  每当更新门 Z t \mathbf{Z}_t Zt 接近 1 1 1 时,我们就只保留旧状态。此时,来自 X t \mathbf{X}_t Xt 的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t。相反,当 Z t \mathbf{Z}_t Zt 接近 0 0 0 时,新的隐藏状态 H t \mathbf{H}_t Ht 就会接近候选的隐藏状态 H ~ t \tilde{\mathbf{H}}_t H~t。==这些设计可以帮助我们处理循环神经网络中的梯度消失问题,并更好地捕获时间步距离很长的序列的依赖关系。==例如,如果整个子序列的所有时间步的更新门都接近于 1 1 1,则无论序列的长度如何,在序列起始时间步的旧隐藏状态都将很容易保留并传递到序列结束。下图说明了更新门起作用后的计算流。

总之,门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。

二、从零实现GRU

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2.1 初始化模型参数

从标准差为0.01的高斯分布中提取权重,偏置设置为0,使用超参数num_hidden定义隐藏单元的数量,实例化与更新门、重置门、候选状态和输出层相关的所有权重和偏置

def get_params(vocab_size,num_hiddens,device):
    num_inputs= num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape,device=device)*0.01
    
    def three():
        return (normal((num_inputs,num_hiddens)),
                normal((num_hiddens,num_hiddens)),
                torch.zeros(num_hiddens,device=device))
    
    W_xz,W_hz,b_z = three() # 更新门参数
    W_xr,W_hr,b_r = three() # 重置门参数
    W_xh,W_hh,b_h = three() # 候选状态参数
    
    # 输出层参数
    W_hq = normal((num_hiddens,num_outputs))
    b_q = torch.zeros(num_outputs,device=device)
    # 附加梯度
    # 附加梯度
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

2.2 定义模型

# 定义一个隐藏状态的初始化函数,返回一个形状为(批量大小,隐藏单元数)的张量,值全为0
def init_gru_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)
  • 1
  • 2
  • 3

R t = σ ( X t W x r + H t − 1 W h r + b r ) , Z t = σ ( X t W x z + H t − 1 W h z + b z ) , H ~ t = tanh ⁡ ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) , H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t .

Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),H~t=tanh(XtWxh+(RtHt1)Whh+bh),Ht=ZtHt1+(1Zt)H~t.
Rt=σ(XtWxr+Ht1Whr+br),Zt=σ(XtWxz+Ht1Whz+bz),H~t=tanh(XtWxh+(RtHt1)Whh+bh),Ht=ZtHt1+(1Zt)H~t.

# 定义GRU模型
def gru(inputs,state,params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H)@W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

2.3 训练与预测

训练和预测的工作方式与RNN中的实现完全相同。训练结束后,我们分别打印输出训练集的困惑度和前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
  • 1
  • 2
  • 3
  • 4
  • 5
perplexity 1.0, 57290.3 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi
  • 1
  • 2
  • 3

2.4 简洁实现

高级API包含了前文介绍地全部配置细节,所以可以直接实例化GRU。其使用编译好的运算符来进行计算,而非python处理其中的许多细节

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs,num_hiddens)
model = d2l.RNNModel(gru_layer,len(vocab))
model = model.to(device)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
  • 1
perplexity 1.0, 353447.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
  • 1
  • 2
  • 3

三、小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。
  • 重置门有助于捕获序列中的短期依赖关系。
  • 更新门有助于捕获序列中的长期依赖关系。
  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

四、练习

  1. 假设我们只想使用时间步 t ′ t' t 的输入来预测时间步 t > t ′ t > t' t>t 的输出。对于每个时间步,重置门和更新门的最佳值是什么?

更新们和重置门都为0表示不使用之前的隐藏状态数据

  1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
  2. 比较 rnn.RNNrnn.GRU 的不同实现对运行时间、困惑度和输出字符串的影响。
  3. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/347755
推荐阅读
相关标签
  

闽ICP备14008679号