当前位置:   article > 正文

从LSTM到GRU基于门控的循环神经网络总结_写出两种基于门控的循环神经网络

写出两种基于门控的循环神经网络
1.概述
  • 为了改善基本RNN的长期依赖问题,一种方法是引入门控机制来控制信息的累积速度,包括有选择性地加入新的信息,并有选择性遗忘之前累积的信息。下面主要介绍两种基于门控的循环神经网络:长短时记忆网络和门控循环单元网络。因为基本的RNN即 h t = f ( U h t − 1 + W x t + b ) \mathbf{h}_{t}=f\left(U \mathbf{h}_{t-1}+W \mathbf{x}_{t}+\mathbf{b}\right) ht=f(Uht1+Wxt+b),每层的隐状态都是由前一层的隐状态经变换和激活函数得到的,反向传播求导时,最终得到的导数会包含每步梯度的连乘,会导致梯度爆炸或消失。所以,基本的RNN很难处理长期依赖问题,即无法学习到序列中蕴含的间隔时间较长的规律。
2.长短时记忆网络LSTM
  • 2.1长短时记忆网络是基本的循环神经网络的一种变体,可以有效的解决简单RNN的梯度爆炸或消失问题。LSTM网络主要改进在下面两个方面

    • 1.新的内部状态 c t \mathbf{c}_{t} ct:LSTM网络引入一个新的内部状态 c t \mathbf{c}_{t} ct,专门进行线性的循环信息传递,同时输出信息给隐藏层的外部状态 h t \mathbf{h}_{t} ht.
      c t = f t ⊙ c t − 1 + i t ⊙ c ~ t h t = o t ⊙ tanh ⁡ ( c t )
      ctamp;=ftct1+itc~thtamp;=ottanh(ct)
      ctht=ftct1+itc~t=ottanh(ct)

      符号说明: f t \mathbf{f}_{t} ft i t \mathbf{i}_{t} it o t \mathbf{o}_{t} ot分别代表遗忘门、输入门、输出门用来控制信息传递的路径;⊙表示向量元素的点乘; c t − 1 \mathbf{c}_{t-1} ct1表示上一时刻的记忆单元; c ~ t \tilde{\mathbf{c}}_{t} c~t表示通过非线性函数得到的候选状态。
      c ~ t = tanh ⁡ ( W c x t + U c h t − 1 + b c ) \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf{b}_{c}\right) c~t=tanh(Wcxt+Ucht1+bc)
      在每个时刻t,LSTM网络的内部状态 c t \mathbf{c}_{t} ct记录了到当前时刻为止的历史信息。
    • 2.门控机制:LSTM网络引入了门控机制,用来控制信息传递的路径, f t \mathbf{f}_{t} ft i t \mathbf{i}_{t} it o t \mathbf{o}_{t} ot分别代表遗忘门、输入门、输出门。这里的门概念类似于电路中的逻辑门概念,1表示开放状态,允许信息通过;0表示关闭状态,阻止信息通过。LSTM网络中的门是一个抽象的概念,借助sigmiod函数,使得输出值在(0,1)之间,表示以一定的比例运行信息通过。三个门的作用如下:
      • 遗忘门 f t \mathbf{f}_{t} ft控制上一时刻的内部状态 c t − 1 \mathbf{c}_{t-1} ct1需要遗忘多少信息
      • 输入门 i t \mathbf{i}_{t} it控制当前时刻的候选状态 c ~ t \tilde{\mathbf{c}}_{t} c~t有多少信息需要保存
      • 输出门 o t \mathbf{o}_{t} ot控制当前时刻的内部状态 c t \mathbf{c}_{t} ct有多少信息需要输出给外部状态 h t \mathbf{h}_{t} ht
        f t = 0 , i t = 1 \mathbf{f}_{t}=0, \mathbf{i}_{t}=1 ft=0,it=1时,记忆单元 c t \mathbf{c}_{t} ct将历史信息清空,并将候选状态向量 c ~ t \tilde{\mathbf{c}}_{t} c~t写入。但此时记忆单元 c t \mathbf{c}_{t} ct依然和上一时刻的历史信息相关。当 f t = 1 , i t = 0 \mathbf{f}_{t}=1, \mathbf{i}_{t}=0 ft=1,it=0时,记忆单元将复制上一时刻的内容,不写入新的信息。三个门的计算公式如下:
        i t = σ ( W i x t + U i h t − 1 + b i ) f t = σ ( W f x t + U f h t − 1 + b f ) o t = σ ( W o x t + U o h t − 1 + b o )
        itamp;=σ(Wixt+Uiht1+bi)ftamp;=σ(Wfxt+Ufht1+bf)otamp;=σ(Woxt+Uoht1+bo)
        itftot=σ(Wixt+Uiht1+bi)=σ(Wfxt+Ufht1+bf)=σ(Woxt+Uoht1+bo)

        其中,激活函数使用sigmoid函数,其输出区间是(0,1), x t \mathbf{x}_{t} xt表示当前时刻的输入, h t − 1 \mathbf{h}_{t-1} ht1表示上一时刻的外部状态。
  • 2.2 LSTM网络的循环单元结构如下图所示,计算过程如下:

    • a.利用上一时刻的外部状态 h t − 1 \mathbf{h}_{t-1} ht1和当前时刻的输入 x t \mathbf{x}_{t} xt,计算出三个门,已经候选状态 c ~ t \tilde{\mathbf{c}}_{t} c~t
      i t = σ ( W i x t + U i h t − 1 + b i ) f t = σ ( W f x t + U f h t − 1 + b f ) o t = σ ( W o x t + U o h t − 1 + b o )
      itamp;=σ(Wixt+Uiht1+bi)ftamp;=σ(Wfxt+Ufht1+bf)otamp;=σ(Woxt+Uoht1+bo)
      itftot=σ(Wixt+Uiht1+bi)=σ(Wfxt+Ufht1+bf)=σ(Woxt+Uoht1+bo)

    c ~ t = tanh ⁡ ( W c x t + U c h t − 1 + b c ) \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf{b}_{c}\right) c~t=tanh(Wcxt+Ucht1+bc)

    • b.结合遗忘门 f t \mathbf{f}_{t} ft和输入门 i t \mathbf{i}_{t} it来更新记忆单元 c t \mathbf{c}_{t} ct
      c t = f t ⊙ c t − 1 + i t ⊙ c ~ t \mathbf{c}_{t}=\mathbf{f}_{t} \odot \mathbf{c}_{t-1}+\mathbf{i}_{t} \odot \tilde{\mathbf{c}}_{t} ct=ftct1+itc~t
    • c.结合输出门 o t \mathbf{o}_{t} ot,将内部状态的信息传递给外部状态 h t \mathbf{h}_{t} ht
      h t = o t ⊙ tanh ⁡ ( c t ) \mathbf{h}_{t}=\mathbf{o}_{t} \odot \tanh \left(\mathbf{c}_{t}\right) ht=ottanh(ct)

LSTM Cell

3.门控循环单元网络GRU
  • GRU与LSTM的不同之处在于:GRU不引入额外的记忆单元 c t \mathbf{c}_{t} ct,GRU网络引入一个更新门来控制当前状态需要从历史状态中保留多少信息(不经过非线性变换),以及需要从候选状态中接收多少新的信息。
    h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ g ( x t , h t − 1 ; θ ) \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right) ht=ztht1+(1zt)g(xt,ht1;θ)
    其中, z t ∈ [ 0 , 1 ] \mathbf{z}_{t} \in[0,1] zt[0,1]为更新门
    z t = σ ( W z x t + U z h t − 1 + b z ) \mathbf{z}_{t}=\sigma\left(\mathbf{W}_{z} \mathbf{x}_{t}+\mathbf{U}_{z} \mathbf{h}_{t-1}+\mathbf{b}_{z}\right) zt=σ(Wzxt+Uzht1+bz)
    在GRU网络中,函数 g ( x t , h t − 1 ; θ ) g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right) g(xt,ht1;θ)定义为:
    h ~ t = tanh ⁡ ( W h x t + U h ( r t ⊙ h t − 1 ) + b h ) \tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h}\left(\mathbf{r}_{t} \odot \mathbf{h}_{t-1}\right)+\mathbf{b}_{h}\right) h~t=tanh(Whxt+Uh(rtht1)+bh)
    上式中的符号说明: h ~ t \tilde{\mathbf{h}}_{t} h~t表示当前时刻的候选状态, r t ∈ [ 0 , 1 ] \mathbf{r}_{t} \in[0,1] rt[0,1]为重置门,用来控制候选状态 h ~ t \tilde{\mathbf{h}}_{t} h~t的计算是否依赖上一时刻的状态 h t − 1 \mathbf{h}_{t-1} ht1
    r t = σ ( W r x t + U r h t − 1 + b r ) \mathbf{r}_{t}=\sigma\left(W_{r} \mathbf{x}_{t}+U_{r} \mathbf{h}_{t-1}+\mathbf{b}_{r}\right) rt=σ(Wrxt+Urht1+br)
    r t \mathbf{r}_{t} rt=0时,候选状态 h ~ t = tanh ⁡ ( W c x t + b ) \tilde{\mathbf{h}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+\mathbf{b}\right) h~t=tanh(Wcxt+b)只和当前输入 x t \mathbf{x}_{t} xt相关而与历史状态无关。当 r t \mathbf{r}_{t} rt=1时,候选状态 h ~ t = tanh ⁡ ( W h x t + U h h t − 1 + b h ) \tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h} \mathbf{h}_{t-1}+\mathbf{b}_{h}\right) h~t=tanh(Whxt+Uhht1+bh)和当前输入 x t \mathbf{x}_{t} xt相关,也和历史状态 h t − 1 \mathbf{h}_{t-1} ht1相关,此时和简单的RNN是一样的。
    综合上述各式,GRU网络的状态更新方式为:
    h t = z t ⊙ h t − 1 + ( 1 − z t ) ⊙ h ~ t \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot \tilde{\mathbf{h}}_{t} ht=ztht1+(1zt)h~t
  • 总结:当 z t = 0 , r = 1 \mathbf{z}_{t}=0, \mathbf{r}=1 zt=0,r=1时,GRU网络退化为简单的RNN;若 z t = 0 , r = 0 \mathbf{z}_{t}=0, \mathbf{r}=0 zt=0,r=0时,当前状态 h t \mathbf{h}_{t} ht只和当前输入 x t \mathbf{x}_{t} xt相关,和历史状态 h t − 1 \mathbf{h}_{t-1} ht1无关。当 z t = 1 \mathbf{z}_{t}=1 zt=1时,当前状态 h t \mathbf{h}_{t} ht等于上一时刻状态 h t − 1 \mathbf{h}_{t-1} ht1和当前输入 x t \mathbf{x}_{t} xt无关。
    GRU Cell
3.实战:基于Keras的LSTM和GRU的文本分类
	import random
	import jieba
	import pandas as pd
	import numpy as np
	
	stopwords = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\stopwords.txt", index_col=False, quoting=3, sep="\t", names=["stopword"], encoding="utf-8")
	stopwords = stopwords["stopword"].values
	
	# 加载语料
	laogong_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaogongda.csv", encoding="utf-8", sep=",")
	laopo_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaopoda.csv", encoding="utf-8", sep=",")
	erzi_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beierzida.csv", encoding="utf-8", sep=",")
	nver_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beinverda.csv", encoding="utf-8", sep=",")
	
	# 删除语料的nan行
	laogong_df.dropna(inplace=True)
	laopo_df.dropna(inplace=True)
	erzi_df.dropna(inplace=True)
	nver_df.dropna(inplace=True)
	
	# 转换
	laogong = laogong_df.segment.values.tolist()
	laopo = laopo_df.segment.values.tolist()
	erzi = erzi_df.segment.values.tolist()
	nver = nver_df.segment.values.tolist()
	
	# 分词和去掉停用词
	
	## 定义分词和打标签函数preprocess_text
	def preprocess_text(content_lines, sentences, category):
	    # content_lines是上面转换得到的list
	    # sentences是空的list,用来存储打上标签后的数据
	    # category是类型标签
	    for line in content_lines:
	        try:
	            segs = jieba.lcut(line)
	            segs = [v for v in segs if not str(v).isdigit()]  # 除去数字
	            segs = list(filter(lambda x: x.strip(), segs))  # 除去左右空格
	            segs = list(filter(lambda x: len(x) > 1, segs))  # 除去长度为1的字符
	            segs = list(filter(lambda x: x not in stopwords, segs))  # 除去停用词
	            sentences.append((" ".join(segs), category))  # 打标签
	        except Exception:
	            print(line)
	            continue
	
	# 调用上面函数,生成训练数据
	sentences = []
	preprocess_text(laogong, sentences, 0)
	preprocess_text(laopo, sentences, 1)
	preprocess_text(erzi, sentences, 2)
	preprocess_text(nver, sentences, 3)
	
	# 先打乱数据,使得数据分布均匀,然后获取特征和标签列表
	random.shuffle(sentences)  # 打乱数据,生成更可靠的训练集
	for sentence in sentences[:10]:    # 输出前10条数据,观察一下
	    print(sentence[0], sentence[1])
	
	# 所有特征和对应标签
	all_texts = [sentence[0] for sentence in sentences]
	all_labels = [sentence[1] for sentence in sentences]
	
	
	# 使用LSTM对数据进行分类
	from keras.preprocessing.text import Tokenizer
	from keras.preprocessing.sequence import pad_sequences
	from keras.utils import to_categorical
	from keras.layers import Dense, Input, Flatten, Dropout
	from keras.layers import LSTM, Embedding, GRU
	from keras.models import Sequential
	
	
	# 预定义变量
	MAX_SEQENCE_LENGTH = 100   # 最大序列长度
	EMBEDDING_DIM = 200   # 词嵌入维度
	VALIDATION_SPLIT = 0.16   # 验证集比例
	TEST_SPLIT = 0.2  # 测试集比例
	
	# 使用keras的sequence模块文本序列填充
	tokenizer = Tokenizer()
	tokenizer.fit_on_texts(all_texts)
	sequences = tokenizer.texts_to_sequences(all_texts)
	word_index = tokenizer.word_index
	print("Found %s unique tokens." % len(word_index))
	
	
	data = pad_sequences(sequences, maxlen=MAX_SEQENCE_LENGTH)
	labels = to_categorical(np.asarray(all_labels))
	print("data shape:", data.shape)
	print("labels shape:", labels.shape)
	
	# 数据切分
	p1 = int(len(data) * (1 - VALIDATION_SPLIT - TEST_SPLIT))
	p2 = int(len(data) * (1 - TEST_SPLIT))
	
	# 训练集
	x_train = data[:p1]
	y_train = labels[:p1]
	
	# 验证集
	x_val = data[p1:p2]
	y_val = labels[p1:p2]
	
	# 测试集
	x_test = data[p2:]
	y_test = labels[p2:]
	
	# LSTM训练模型
	model = Sequential()
	model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
	model.add(LSTM(200, dropout=0.2, recurrent_dropout=0.2))
	model.add(Dropout(0.2))
	model.add(Dense(64, activation="relu"))
	model.add(Dense(labels.shape[1], activation="softmax"))
	model.summary()
	
	# 模型编译
	model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
	print(model.metrics_names)
	
	model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
	model.save("lstm.h5")
	# 模型评估
	print(model.evaluate(x_test, y_test))
	
	
	
	# 使用GRU模型
	model = Sequential()
	model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
	model.add(GRU(200, dropout=0.2, recurrent_dropout=0.2))
	model.add(Dropout(0.2))
	model.add(Dense(64, activation="relu"))
	model.add(Dense(labels.shape[1], activation="softmax"))
	model.summary()
	
	model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
	print(model.metrics_names)
	
	model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
	model.save("gru.h5")
	
	print(model.evaluate(x_test, y_test))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
4.本文代码及数据集下载
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/663212
推荐阅读
相关标签
  

闽ICP备14008679号