当前位置:   article > 正文

时序模型:长短期记忆网络(LSTM)_长短期记忆模型

长短期记忆模型

1. 模型定义

循环神经网络(RNN)模型存在长期依赖问题,不能有效学习较长时间序列中的特征。长短期记忆网络(long short-term memory,LSTM)1是最早被承认能有效缓解长期依赖问题的改进方案。

2. 模型结构

LSTM的隐藏状态计算模块,在RNN基础上引入一个新的内部状态:记忆细胞(memory cell),和三个控制信息传递的逻辑门:输入门(input gate)、遗忘门(forget gate)、输出门(output gate)。其结构如下图所示:

在这里插入图片描述
图中,记忆细胞memory cell)与隐状态具有相同的形状(向量维度),其设计目的是用于记录附加的隐藏状态与输入信息,有些文献认为记忆细胞是一种特殊类型的隐状态;输入门input gate)控制(本时刻)输入观测和(上时刻)隐藏状态中哪些信息会添加进记忆细胞;遗忘门forget gate)控制忘记上时刻记忆细胞中的哪些内容;输出门output gate)控制记忆细胞中哪些信息会输出给隐藏状态。

3. 前向传播

为更容易理解 LSTM 模型的前向传播过程,我们将模型结构图改编为如下所示2图中 a t a^t at t t t 时刻的候选记忆细胞 C ~ t \tilde{C}_t C~t):

在这里插入图片描述

由此我们可以得到 LSTM 模型的前向传播公式:

{ 候 选 记 忆 细 胞 : C ~ t = t a n h ( X t W x c + H t − 1 W h c + b c ) ,      X t ∈ R m × d , H t − 1 ∈ R m × h , W x c ∈ R d × h , W ∈ R h × h 输 入 门 : I t = σ ( X t W x i + H t − 1 W h i + b i ) , W x i ∈ R d × h , W h i ∈ R h × h 遗 忘 门 : F t = σ ( X t W x f + H t − 1 W h f + b f ) , W x f ∈ R d × h , W h f ∈ R h × h 输 出 门 : O t = σ ( X t W x o + H t − 1 W h o + b o ) , W x o ∈ R d × h , W h o ∈ R h × h (3.1.1)

{C~t=tanh(XtWxc+Ht1Whc+bc),    XtRm×d,Ht1Rm×h,WxcRd×h,WRh×hIt=σ(XtWxi+Ht1Whi+bi),WxiRd×h,WhiRh×hFt=σ(XtWxf+Ht1Whf+bf),WxfRd×h,WhfRh×hOt=σ(XtWxo+Ht1Who+bo),WxoRd×h,WhoRh×h
\tag {3.1.1} C~t=tanh(XtWxc+Ht1Whc+bc),It=σ(XtWxi+Ht1Whi+bi),Ft=σ(XtWxf+Ht1Whf+bf),Ot=σ(XtWxo+Ht1Who+bo),    XtRm×d,Ht1Rm×h,WxcRd×h,WRh×hWxiRd×h,WhiRh×hWxfRd×h,WhfRh×hWxoRd×h,WhoRh×h(3.1.1)

{ 记 忆 细 胞 : C t = I t ⊙ C ~ t + F t ⊙ C t − 1 隐 藏 状 态 : H t = O t ⊙ t a n h ( C t ) 模 型 输 出 : Y ^ t = H t W h y + b y , W h y ∈ R h × q ,   Y ^ t ∈ R m × q 损 失 函 数 : L = 1 T ∑ t = 1 T l ( Y ^ t , Y t ) , L ∈ R (3.1.2)

{Ct=ItC~t+FtCt1Ht=Ottanh(Ct)Y^t=HtWhy+by,WhyRh×q, Y^tRm×qL=1Tt=1Tl(Y^t,Yt),LR
\tag {3.1.2} Ct=ItC~t+FtCt1Ht=Ottanh(Ct)Y^t=HtWhy+by,L=T1t=1Tl(Y^t,Yt),WhyRh×q, Y^tRm×qLR(3.1.2)

式中 m m m 为小批量随机梯度下降的批量大小(batch size), d d d 为输入单词的词向量维度 h h h q q q 为隐藏状态和模型输出的向量宽度(维度)。

4. LSTM缓解长期依赖的原理

RNN模型存在长期依赖问题,源自于其反向传播过程中存在的梯度消失现象。LSTM模型通过改进RNN模型的梯度传播过程,来缓解反向传播过程中,距离语句结尾处较远的单词容易出现梯度消失的现象。由第3节所述前向传播过程,将LSTM模型反向传播的计算图绘制如下3
请添加图片描述
所以根据计算图,可以推导出LSTM模型的反向传播公式为:

∂ L ∂ Y ^ t = ∂ l ( Y ^ t , Y t ) T ⋅ ∂ Y ^ t (4.1) \frac{\partial L}{\partial \hat{Y}_t} = \frac{\partial l(\hat{Y}_t, Y_t)}{T \cdot\partial \hat{Y}_t} \tag {4.1} Y^tL=TY^tl(Y^t,Yt)(4.1)

∂ L ∂ Y ^ t ⇒ { ∂ L ∂ W h y = ∂ L ∂ Y ^ t ∂ Y ^ t ∂ W h y ∂ L ∂ H t = { ∂ L ∂ Y ^ t ∂ Y ^ t ∂ H t , t = T ∂ L ∂ Y ^ t ∂ Y ^ t ∂ H t + ∂ L ∂ C t + 1 ∂ C t + 1 ∂ H t , t < T (4.2) \frac{\partial L}{\partial \hat{Y}_t} \Rightarrow

\begin{cases} \frac{\partial L}{\partial W_{hy}} = \frac{\partial L}{\partial \hat{Y}_t} \frac{\partial \hat{Y}_t}{\partial W_{hy}} \\ \\ \frac{\partial L}{\partial H_t} = \begin{cases} \frac{\partial L}{\partial \hat{Y}_t} \frac{\partial \hat{Y}_t}{\partial H_t}, & t=T \\ \\ \frac{\partial L}{\partial \hat{Y}_t} \frac{\partial \hat{Y}_t}{\partial H_t} + \frac{\partial L}{\partial C_{t+1}}\frac{\partial C_{t+1}}{\partial H_{t}}, & t<T \end{cases}
\end{cases} \tag {4.2} Y^tLWhyL=Y^tLWhyY^tHtL=Y^tLHtY^t,Y^tLHtY^t+Ct+1LHtCt+1,t=Tt<T(4.2)

∂ L ∂ H t ⇒ { ∂ L ∂ O t = ∂ L ∂ H t ∂ H t ∂ O t ∂ L ∂ C t = { ∂ L ∂ H t ∂ H t ∂ C t , t = T ∂ L ∂ H t ∂ H t ∂ C t + ∂ L ∂ H t ∂ C t + 1 ∂ C t , t < T (4.3) \frac{\partial L}{\partial H_t} \Rightarrow

\begin{cases} \frac{\partial L}{\partial O_t} = \frac{\partial L}{\partial H_t} \frac{\partial H_t}{\partial O_t} \\ \\ \frac{\partial L}{\partial C_t} = \begin{cases} \frac{\partial L}{\partial H_t} \frac{\partial H_t}{\partial C_t}, & t=T \\ \\ \frac{\partial L}{\partial H_t} \frac{\partial H_t}{\partial C_t} + \frac{\partial L}{\partial H_t} \frac{\partial C_{t+1}}{\partial C_{t}}, & t<T \end{cases}
\end{cases} \tag {4.3} HtLOtL=HtLOtHtCtL=HtLCtHt,HtLCtHt+HtLCtCt+1,t=Tt<T(4.3)

∂ L ∂ O t ⇒ { ∂ L ∂ W x o = ∂ L ∂ O t ∂ O t ∂ W x o ∂ L ∂ W h o = ∂ L ∂ O t ∂ O t ∂ W h o ∂ L ∂ b o = ∂ L ∂ O t ∂ O t ∂ b o ∂ L ∂ C t ⇒ { ∂ L ∂ C ~ t = ∂ L ∂ C t ∂ C t ∂ C ~ t ∂ L ∂ I t = ∂ L ∂ C t ∂ C t ∂ I t ∂ L ∂ F t = ∂ L ∂ C t ∂ C t ∂ F t (4.4)

LOt{LWxo=LOtOtWxoLWho=LOtOtWhoLbo=LOtOtboLCt{LC~t=LCtCtC~tLIt=LCtCtItLFt=LCtCtFt
\tag {4.4} OtLWxoL=OtLWxoOtWhoL=OtLWhoOtboL=OtLboOtCtLC~tL=CtLC~tCtItL=CtLItCtFtL=CtLFtCt(4.4)

∂ L ∂ C ~ t ⇒ { ∂ L ∂ W x c = ∂ L ∂ C ~ t ∂ C ~ t ∂ W x c ∂ L ∂ W h c = ∂ L ∂ C ~ t ∂ C ~ t ∂ W h c ∂ L ∂ b c = ∂ L ∂ C ~ t ∂ C ~ t ∂ b c ∂ L ∂ I t ⇒ { ∂ L ∂ W x i = ∂ L ∂ I t ∂ I t ∂ W x i ∂ L ∂ W h i = ∂ L ∂ I t ∂ I t ∂ W h i ∂ L ∂ b i = ∂ L ∂ I t ∂ I t ∂ b i ∂ L ∂ F t ⇒ { ∂ L ∂ W x f = ∂ L ∂ F t ∂ F t ∂ W x f ∂ L ∂ W h f = ∂ L ∂ F t ∂ F t ∂ W h f ∂ L ∂ b f = ∂ L ∂ F t ∂ F t ∂ b f (4.5)

LC~t{LWxc=LC~tC~tWxcLWhc=LC~tC~tWhcLbc=LC~tC~tbcLIt{LWxi=LItItWxiLWhi=LItItWhiLbi=LItItbiLFt{LWxf=LFtFtWxfLWhf=LFtFtWhfLbf=LFtFtbf
\tag {4.5} C~tLWxcL=C~tLWxcC~tWhcL=C~tLWhcC~tbcL=C~tLbcC~tItLWxiL=ItLWxiItWhiL=ItLWhiItbiL=ItLbiItFtLWxfL=FtLWxfFtWhfL=FtLWhfFtbfL=FtLbfFt(4.5)
可见反向传播公式的难点是对 式 ( 4.2 ) 式(4.2) (4.2) 式 ( 4.3 ) 式(4.3) (4.3)中,不同时间步间的(传递)梯度 ∂ C t + 1 / ∂ H t \partial C_{t+1} / \partial H_{t} Ct+1/Ht ∂ C t + 1 / ∂ C t \partial C_{t+1} / \partial C_{t} Ct+1/Ct 的求解;而其他梯度项求解十分容易,本文便不做过多展开了。

本文自 t = T t=T t=T 时刻,逐(时间)步反向传播推算出每时刻损失函数对模型隐藏状态的偏导数后,根据数学归纳法得到损失函数对模型隐藏状态的梯度公式为(推导过程见作者符号计算程序:LSTM模型缓解长期依赖问题的数学证明(符号计算程序)):

$$

$$

可见,LSTM模型是通过增加模型参数的低阶幂次项和在每个模型参数的幂次项前添加可变(通过模型训练改变)的乘数项,来缓解参数高阶幂次项趋近于0引起的梯度消失问题

关于参数高阶幂次项引发的梯度消失问题,更详细解释可见作者文章:时序模型:循环神经网络(RNN)中关于式(3.5)和式(3.9)的解释。

5. 模型的代码实现

5.1 TensorFlow 框架实现


  • 1

5.2 Pytorch 框架实现

"""
v2.0 修复RNN参数初始化不当,引起的时间步传播梯度消失问题。   2022.04.28
"""
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter


#
class GRU_Cell(nn.Module):
    # noinspection PyTypeChecker
    def __init__(self, token_dim, hidden_dim
                 , reset_act=nn.Sigmoid()
                 , update_act=nn.Sigmoid()
                 , hathid_act=nn.Tanh()
                 , device="cpu"):
        super().__init__()
        #
        self.hidden_dim = hidden_dim
        self.device = device
        #
        self.ResetG = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=reset_act, device=device
        )
        self.UpdateG = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=update_act, device=device
        )
        self.HatHidden = Simple_RNN_Cell(
            token_dim, hidden_dim, activation=hathid_act, device=device
        )

    def forward(self, inputs, last_state):
        Rg = self.ResetG(
            inputs, last_state
        )[-1]
        Zg = self.UpdateG(
            inputs, last_state
        )[-1]
        hat_hidden = self.HatHidden(
            inputs, [Rg * last_state[-1]]
        )[-1]
        hidden = Zg * last_state[-1] + (1-Zg) * hat_hidden
        return [hidden]

    def zero_initialization(self, batch_size):
        return [torch.zeros([batch_size, self.hidden_dim]).to(self.device)]


#
class RNN_Layer(nn.Module):
    """
    bidirectional:  If ``True``, becomes a bidirectional RNN network. Default: ``False``.
    padding:        String, 'pre' or 'post' (optional, defaults to 'pre'): pad either before or after each sequence.
    """
    def __init__(self, rnn_cell, bidirectional=False, pad_position='post'):
        super().__init__()
        self.RNNCell = rnn_cell
        self.bidirectional = bidirectional
        self.padding = pad_position

    def forward(self, inputs, mask=None, initial_state=None):
        """
        inputs:   it's shape is [batch_size, time_steps, token_dim]
        mask:     it's shape is [batch_size, time_steps]
        :return
        sequence:    it is hidden state sequence, and its' shape is [batch_size, time_steps, hidden_dim]
        last_state: it is the hidden state of input sequences at last time step,
                    but, attentively, the last token wouble be a padding token,
                    so this last state is not the real last state of input sequences;
                    if you want to get the real last state of input sequences, please use utils.get_rnn_last_state(hidden state sequence).
        """
        batch_size, time_steps, token_dim = inputs.shape
        #
        if initial_state is None:
            initial_state = self.RNNCell.zero_initialization(batch_size)
        if mask is None:
            if batch_size == 1:
                mask = torch.ones([1, time_steps]).to(inputs.device.type)
            elif self.padding == 'pre':
                raise ValueError('请给定掩码矩阵(mask)')
            elif self.padding == 'post' and self.bidirectional is True:
                raise ValueError('请给定掩码矩阵(mask)')

        # 正向时间步循环
        hidden_list = []
        hidden_state = initial_state
        last_state = None
        for i in range(time_steps):
            hidden_state = self.RNNCell(inputs[:, i], hidden_state)
            hidden_list.append(hidden_state[-1])
            if i == time_steps - 1:
                """获取最后一时间步的输出隐藏状态"""
                last_state = hidden_state
            if self.padding == 'pre':
                """如果padding值填充在序列尾端,则正向时间步传播应加 mask 操作"""
                hidden_state = [
                    hidden_state[j] * mask[:, i:i + 1] + initial_state[j] * (1 - mask[:, i:i + 1])  # 重新初始化(加数项作用)
                    for j in range(len(hidden_state))
                ]
        sequence = torch.reshape(
            torch.unsqueeze(
                torch.concat(hidden_list, dim=1)
                , dim=1)
            , [batch_size, time_steps, -1]
        )

        # 反向时间步循环
        if self.bidirectional is True:
            hidden_list = []
            hidden_state = initial_state
            for i in range(time_steps, 0, -1):
                hidden_state = self.RNNCell(inputs[:, i - 1], hidden_state)
                hidden_list.insert(0, hidden_state[-1])
                if i == time_steps:
                    """获取最后一时间步的cell_state"""
                    last_state = [
                        torch.concat([last_state[j], hidden_state[j]], dim=1)
                        for j in range(len(hidden_state))
                    ]
                if self.padding == 'post':
                    """如果padding值填充在序列首端,则正反时间步传播应加 mask 操作"""
                    hidden_state = [
                        hidden_state[j] * mask[:, i - 1:i] + initial_state[j] * (1 - mask[:, i - 1:i])  # 重新初始化(加数项作用)
                        for j in range(len(hidden_state))
                    ]
            sequence = torch.concat([
                sequence,
                torch.reshape(
                    torch.unsqueeze(
                        torch.concat(hidden_list, dim=1)
                        , dim=1)
                    , [batch_size, time_steps, -1]
                )
            ], dim=-1)

        return sequence, last_state

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138

  1. Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735–1780. ↩︎

  2. 图片摘自:LSTM Forward and Backward Pass Introduction ↩︎

  3. 图片摘自:LSTM Forward and Backward Pass Introduction ↩︎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/673976
推荐阅读
相关标签
  

闽ICP备14008679号