
(Deep Learning) Implementing an LSTM in PyTorch by Hand, Without Calling the Library

This is the third article in the PyTorch deep learning series. In the previous article, (Deep Learning) Advanced PyTorch: Implementing AlexNet, we built AlexNet by hand without calling library modules. Today we attempt something more challenging: implementing an LSTM by hand.

LSTM (Long Short-Term Memory) is a special kind of RNN. Through a carefully designed cell structure, it alleviates the loss of information caused by long-distance propagation when training on long sequences.

A standard RNN is formed by unrolling a simple neural network module into a chain over time. This repeated module usually has a simple, uniform structure, such as a single tanh layer. Such a crude way of accumulating memory easily causes vanishing and exploding gradients when training on long sequences. By contrast, an LSTM cell has the more elaborate internal structure shown in the figure below: gated states select and adjust the information being passed along, retaining what needs to be remembered over long spans and forgetting what is unimportant, which mitigates the vanishing and exploding gradient problems.

[Figure: internal structure of an LSTM cell]
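For reference, the gate updates that the code below implements can be written compactly. This is the standard LSTM formulation; the symbols mirror the variable names used in the code (W for input-to-hidden weights, U for hidden-to-hidden weights, b for biases, and ⊙ for element-wise multiplication):

$$
\begin{aligned}
i_t &= \sigma(x_t W_i + h_{t-1} U_i + b_i) \\
f_t &= \sigma(x_t W_f + h_{t-1} U_f + b_f) \\
g_t &= \tanh(x_t W_g + h_{t-1} U_g + b_g) \\
o_t &= \sigma(x_t W_o + h_{t-1} U_o + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$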

The LSTM code presented in this article behaves equivalently to nn.LSTM with batch_first=True; a shape-level sanity check against nn.LSTM is sketched after the complete code at the end of this article.

First, in the initialization, the parameters of the input gate i_t, forget gate f_t, output gate o_t, and candidate cell state g_t are created and initialized:

def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    # Input gate i_t
    self.W_i = Parameter(torch.Tensor(input_size, hidden_size))
    self.U_i = Parameter(torch.Tensor(hidden_size, hidden_size))
    self.b_i = Parameter(torch.Tensor(hidden_size))
    # Forget gate f_t
    self.W_f = Parameter(torch.Tensor(input_size, hidden_size))
    self.U_f = Parameter(torch.Tensor(hidden_size, hidden_size))
    self.b_f = Parameter(torch.Tensor(hidden_size))
    # Candidate cell state g_t
    self.W_g = Parameter(torch.Tensor(input_size, hidden_size))
    self.U_g = Parameter(torch.Tensor(hidden_size, hidden_size))
    self.b_g = Parameter(torch.Tensor(hidden_size))
    # Output gate o_t
    self.W_o = Parameter(torch.Tensor(input_size, hidden_size))
    self.U_o = Parameter(torch.Tensor(hidden_size, hidden_size))
    self.b_o = Parameter(torch.Tensor(hidden_size))

    # Initialize the parameters
    self._initialize_weights()

def _initialize_weights(self):
    for p in self.parameters():
        if p.data.ndimension() >= 2:
            nn.init.xavier_uniform_(p.data)
        else:
            nn.init.zeros_(p.data)


In the forward function, the internal states are first initialized. Note that h_t and c_t are created with shape (1, batch, hidden_size), matching the state layout nn.LSTM uses for a single unidirectional layer:

def _init_states(self, x):
    h_t = torch.zeros(1, x.size(0), self.hidden_size, dtype=x.dtype).to(x.device)
    c_t = torch.zeros(1, x.size(0), self.hidden_size, dtype=x.dtype).to(x.device)
    return h_t, c_t

def forward(self, x, init_states=None):
    """
    Here the input x is expected in (batch, sequence, feature) format
    """
    batch_size, seq_size, _ = x.size()
    hidden_seq = []

    # State initialization
    if init_states is None:
        h_t, c_t = self._init_states(x)
    else:
        h_t, c_t = init_states

Then we iterate over the time steps: at each step the gate components and the candidate cell state are computed, and the cell state and hidden state are updated from them:

# Iterate over the time steps
for t in range(seq_size):
    x_t = x[:, t, :]
    # Update the gates and the candidate cell state
    # (Tip: in PyTorch, @ is matrix multiplication and * is element-wise multiplication)
    i_t = torch.sigmoid(x_t @ self.W_i + h_t @ self.U_i + self.b_i)
    f_t = torch.sigmoid(x_t @ self.W_f + h_t @ self.U_f + self.b_f)
    g_t = torch.tanh(x_t @ self.W_g + h_t @ self.U_g + self.b_g)
    o_t = torch.sigmoid(x_t @ self.W_o + h_t @ self.U_o + self.b_o)
    # Update the cell state and hidden state
    c_t = f_t * c_t + i_t * g_t
    h_t = o_t * torch.tanh(c_t)
    hidden_seq.append(h_t)

Finally, the sequence of hidden-state outputs is returned, together with the latest cell state and hidden state:

hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
return hidden_seq, (h_t, c_t)




The complete code is as follows:

import torch
import torch.nn as nn
from torch.nn import Parameter
from enum import IntEnum


class Dim(IntEnum):
    batch = 0
    seq = 1
    feature = 2


class LSTM_batchfirst(nn.Module):
    """
    A hand-built LSTM,
    equivalent to nn.LSTM with batch_first=True
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Input gate i_t
        self.W_i = Parameter(torch.Tensor(input_size, hidden_size))
        self.U_i = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_i = Parameter(torch.Tensor(hidden_size))
        # Forget gate f_t
        self.W_f = Parameter(torch.Tensor(input_size, hidden_size))
        self.U_f = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_f = Parameter(torch.Tensor(hidden_size))
        # Candidate cell state g_t
        self.W_g = Parameter(torch.Tensor(input_size, hidden_size))
        self.U_g = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_g = Parameter(torch.Tensor(hidden_size))
        # Output gate o_t
        self.W_o = Parameter(torch.Tensor(input_size, hidden_size))
        self.U_o = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_o = Parameter(torch.Tensor(hidden_size))

        # Initialize the parameters
        self._initialize_weights()

    def _initialize_weights(self):
        for p in self.parameters():
            if p.data.ndimension() >= 2:
                nn.init.xavier_uniform_(p.data)
            else:
                nn.init.zeros_(p.data)

    def _init_states(self, x):
        h_t = torch.zeros(1, x.size(0), self.hidden_size, dtype=x.dtype).to(x.device)
        c_t = torch.zeros(1, x.size(0), self.hidden_size, dtype=x.dtype).to(x.device)
        return h_t, c_t

    def forward(self, x, init_states=None):
        """
        Here the input x is expected in (batch, sequence, feature) format
        """
        batch_size, seq_size, _ = x.size()
        hidden_seq = []

        # State initialization
        if init_states is None:
            h_t, c_t = self._init_states(x)
        else:
            h_t, c_t = init_states

        # Iterate over the time steps
        for t in range(seq_size):
            x_t = x[:, t, :]
            # Update the gates and the candidate cell state
            # (Tip: in PyTorch, @ is matrix multiplication and * is element-wise multiplication)
            i_t = torch.sigmoid(x_t @ self.W_i + h_t @ self.U_i + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_f + h_t @ self.U_f + self.b_f)
            g_t = torch.tanh(x_t @ self.W_g + h_t @ self.U_g + self.b_g)
            o_t = torch.sigmoid(x_t @ self.W_o + h_t @ self.U_o + self.b_o)
            # Update the cell state and hidden state
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)
        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()
        return hidden_seq, (h_t, c_t)

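As a quick sanity check of the batch_first=True equivalence claimed above, here is a minimal sketch (assuming the class definition above has been run; it only compares output and state shapes with nn.LSTM, not numerical values, since the two modules hold different randomly initialized weights):

import torch
import torch.nn as nn

batch, seq, feat, hidden = 4, 7, 16, 32
x = torch.randn(batch, seq, feat)

lstm_custom = LSTM_batchfirst(feat, hidden)          # the hand-built LSTM defined above
lstm_ref = nn.LSTM(feat, hidden, batch_first=True)   # PyTorch's built-in reference

out_custom, (h_c, c_c) = lstm_custom(x)
out_ref, (h_r, c_r) = lstm_ref(x)

print(out_custom.shape, out_ref.shape)  # both torch.Size([4, 7, 32]) -> (batch, seq, hidden)
print(h_c.shape, h_r.shape)             # both torch.Size([1, 4, 32])
print(c_c.shape, c_r.shape)             # both torch.Size([1, 4, 32])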