当前位置:   article > 正文

RNN和LSTM详解_lstm 分类 损失函数

lstm 分类 损失函数

1. Recurrent Neural Networks(RNN)

1.1 模型

在这里插入图片描述
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h] ht=tanh[WhxXt+Whhht1+bh]
z t = f ( W h y h t + b z ) z_t=f(W_{hy}h_t+b_z) zt=f(Whyht+bz)

  • t a n h ( v ) = e x p ( 2 v ) − 1 e x p ( 2 v ) + 1 tanh(v) = \frac{exp(2v)-1}{exp(2v)+1} tanh(v)=exp(2v)+1exp(2v)1
  • W h h , W x h , W h y W_{hh},W_{xh},W_{hy} Whh,Wxh,Why都是可训练的权重矩阵。
  • b h , b z b_h,b_z bh,bz都是可训练的偏差向量。
  • X t X_t Xt z t z_t zt分别是时间 t t t的输入和输出。

1.2 损失函数

L τ ( θ ) = ∑ t ∈ τ L ( y t , z t ) L_\tau(\theta) = \sum_{t\in\tau}L(y_t,z_t) Lτ(θ)=tτL(yt,zt)
这里的 τ \tau τ是输出序列。

1.3 不同形态的RNN

在这里插入图片描述
应用场景:

  • One-to-many: image captioning;
  • Many-to-one: text sentiment classification;
  • Many-to-many: machine translation.

1.4 多层RNN

回想一下单层RNN:
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] = t a n h [ W ( X t h t − 1 1 ) ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h]=tanh

[W(Xtht11)]
ht=tanh[WhxXt+Whhht1+bh]=tanhWXtht11

多层RNN是单层RNN堆叠而来的:
在这里插入图片描述

h t l = t a n h [ W ( h t l − 1 h t − 1 1 ) ] h_t^l =tanh

[W(htl1ht11)]
htl=tanhWhtl1ht11

高层的隐含状态 h t l h_t^l htl由老的状态 h t − 1 l h_{t-1}^l ht1l和低层的隐含状态 h t ( l − 1 ) h_t^(l-1) ht(l1)决定。

1.5 RNN存在的问题

普通RNN的一个显著缺点是,当序列长度很大时,RNN难以捕获序列数据中的长依赖项。这有时是梯度消失/爆炸造成的。
在下面的例子中,计算 ∂ L τ ∂ h 1 \frac{\partial L_\tau}{\partial h_1} h1Lτ时,根据链式求导法则,我们需要计算 ∏ t = 1 3 ( ∂ h t + 1 ∂ h t ) \prod_{t=1}^3(\frac{\partial h_{t+1}}{\partial h_t}) t=13(htht+1)
在这里插入图片描述
如果序列很长,这个乘积将是许多雅可比矩阵的乘积,这通常会得到指数大或指数小的奇异值。

2. LSTM/GRU

2.1 概述

先回顾一下单层RNN:
h t = t a n h [ W h x X t + W h h h t − 1 + b h ] = t a n h [ W ( X t h t − 1 1 ) ] h_t = tanh[W_{hx}X_t + W_{hh}h_{t-1}+b_h]=tanh

[W(Xtht11)]
ht=tanh[WhxXt+Whhht1+bh]=tanhWXtht11

对比LSTM:
( i t f t o t c t ) = ( σ σ σ t a n h ) W ( h t − 1 x t 1 )

(itftotct)
=
(σσσtanh)
W
(ht1xt1)
itftotct=σσσtanhWht1xt1

其中, σ \sigma σsigmoid函数

LSTM可以删除或者添加信息到状态,并被叫“门”的结构(包括遗忘门、输入门、输出门)所限制。
在这里插入图片描述

2.2 遗忘门(Forget gate)

在这里插入图片描述

功能:保存旧的信息
f t = σ [ W f ( X t h t − 1 1 ) ] f_t =\sigma

[Wf(Xtht11)]
ft=σWfXtht11

理想情况下,遗忘门的输出具有接近二进制的值,例如,当 f t f_t ft的输出接近1时可能表明输入序列中存在某个特征。

2.3 输入门(Input gate)

在这里插入图片描述
功能:更新记忆

i t = σ [ W i ( X t h t − 1 1 ) ] i_t =\sigma

[Wi(Xtht11)]
it=σWiXtht11
c ˉ t = t a n h [ W c ( X t h t − 1 1 ) ] \bar c_t=tanh
[Wc(Xtht11)]
cˉt=tanhWcXtht11

2.4 输入门和遗忘门的合并

在这里插入图片描述
c t = f t ⊙ c t − 1 + i t ⊙ c ˉ t c_t=f_t\odot c_{t-1}+i_t \odot \bar c_t ct=ftct1+itcˉt

⊙ \odot 表示两个矩阵对应位置元素进行乘积

2.4 输出门(Output gate)

在这里插入图片描述
功能:决定有多少记忆 c t c_t ct影响输出 h t h_t ht

o t = σ [ W o ( X t h t − 1 1 ) ] o_t =\sigma

[Wo(Xtht11)]
ot=σWoXtht11

h t = o t ⊙ t a n h ( c t ) h_t=o_t \odot tanh(c_t) ht=ottanh(ct)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/488334
推荐阅读
相关标签
  

闽ICP备14008679号