当前位置:   article > 正文

Attention及其pytorch代码实现_pytorch attention

pytorch attention

基于RNN的Seq2Seq的基本假设:原始序列的最后一个隐含状态(一个向量)包含了该序列的全部信息。(这显然是不合理的)

Seg2Seg问题:记忆长序列能力不足
在这里插入图片描述

解决:当要生成一个目标语言单词的时候,不光考虑前一个时刻的状态和已经生成的单词,还要考虑当前要生成的单词和源句子中的哪些单词更加相关,即更关注源句子中的哪些词,这种做法就叫做注意力机制(Attention)
在这里插入图片描述

Attention

Luong等人在2015年发布的Effective Approaches to Attention-based Neural Machine Translation论文中,提出了attention技术,通过attention技术,seg2seg模型极大地提高了机器翻译的质量。

归其原因是:attention机制使得seg2seg模型可以有区分度、有重点地关注输入序列。

实例:请添加图片描述

  1. 假设模型已经生成单词“我”之后,要生成下一个单词;
  2. 显然和源语言中“love”关系最大,因此将源语言句子中的“love”对应的状态乘以一个比较大的权重,如0.6,而其余词的权重则较小;
  3. 最终将源语言句子中每个单词对应的状态加权求和,并用作新状态更新一个额外输出。

结合Attention机制的Seg2Seg模型

  • 结合attention, seg2seg模型decoder每次更新状态的时候都会再看一遍encoder的所有状态(decoder会知道去关注哪里)

  • 在Encoder结束工作之后,Attention和Decoder同时开始工作

  • attention可以简单的理解为:一种有效的加权求和技术,关键点在于如何获得权重。

  • 计算权重: α i = a l i g h ( h i , s 0 ) \alpha_i=aligh(h_i,s_0) αi=aligh(hi,s0)
    (相当于计算 h i h_i hi s 0 的 相 关 性 s_0的相关性 s0
    h i 为 E n c o d e r 的 隐 藏 层 状 态 , s 0 为 E n c o d e r 的 最 后 一 个 隐 藏 层 状 态 h_i为Encoder的隐藏层状态,s_0为Encoder的最后一个隐藏层状态 hiEncoders0Encoder
    (权重为0-1之间的数,加起来等于1)
  • 计算方法:
  1. Linear maps(线性变换):
    k i = W K ⋅ h i , f o r i = 1 t o m k_i=W_K·h_i,for i = 1 to m ki=WKhifori=1tom q 0 = W Q ⋅ s 0 q_0=W_Q·s_0 q0=WQs0
  2. Inner product(计算内积):
    α i ~ = k i T q 0 \tilde{\alpha_i} = \mathbf{k}^\mathrm{T}_iq_0 αi~=kiTq0
  3. Normalization:
    [ α 1 , ⋅ ⋅ ⋅ , α m ] = s o f t m a x ( [ α 1 ~ , ⋅ ⋅ ⋅ , α m ~ ] ) [\alpha_1,···,\alpha_m] = softmax([\tilde{\alpha_1},···,\tilde{\alpha_m}]) [α1,αm]=softmax([α1~,,αm~])

计算权重还有另一种方法,本文的代码中用到了,但现在更为主流的是第一种方法
在这里插入图片描述

  • 获得权重之后就是求取Context Vector:

c 0 = α 1 h 1 + ⋅ ⋅ ⋅ + α m h m c_0=\alpha_1h_1+···+\alpha_mh_m c0=α1h1++αmhm

  • 更新Decoder状态向量( s 0 = h m s_0=h_m s0=hm)
  1. SimpleRNN
    s 1 = t a n h ( A ′ ⋅ [ X 1 ′ s 0 ] + b ) s_1=tanh(A'·
    [X1s0]
    +b)
    s1=tanh(A[X1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/348342
推荐阅读
相关标签
  

闽ICP备14008679号