当前位置:   article > 正文

​深度学习基础 | Seq2seq+Attention

seq2seq+attention

75becaaaafea080a4541af6c095fa764.png

作者 | Chilia  

整理 | NewBeeNLP

首先,请阅读先修知识:

1. Seq2seq

1.1 seq2seq的训练

82e29550e9929ebcc6f6bf9f06242122.png

可以看出,整个seq2seq模型分为两大部分:Encoder RNN和Decoder RNN。

「Encoder RNN」部分,先将待翻译的原文经过一个RNN (这里可以是vanilla RNN, LSTM,GRU等等),并且使用最后一个隐藏状态作为整句话的encoding表示,作为Decoder RNN的初始隐藏状态输入到Decoder RNN中去。

「Decoder RNN」部分,每个timestep的「输入」就是翻译后的词语embedding,将每一步的隐藏状态经过全连接层,得到整个词汇表每个词的概率分布,然后和实际的词语(one-hot编码)去对比,得到交叉熵损失。将所有的交叉熵损失求平均,即可得到整体的损失。

1.2 Seq2seq的测试
1.2.1 网络结构

d989532791305d24505be9de842d09eb.png

「decoder RNN」 的第一个输入是 「<START>」, 每一步的hidden state再经过一个全连接层得到整个词汇表的概率分布,之后取一个概率最大的词(argmax)作为此时的翻译。每一个 timestep 的翻译词作为下一个timestep的输入,以此继续,直到最后输出**<END>**

1.2.2 解决局部极小 – beam search

贪婪的decoding方法就是每一步都选概率最大的那个词语作为下一步的输入:

b8cdc277c906a2e0c185c335c4f7d52c.png

这种方法的问题就是无法回退->可能陷入局部极小值(local minima).

所以,我们引入了「beam search」的方法。

beam search的核心思想是,每一步都考虑 k 个最可能的翻译词语, k 叫 「beam size」。最后得到若干个翻译结果,这些结果叫hypothesis,然后选择一个概率最大的hypothesis即可。

例如,当选择k = 2时,翻译 il a'm' entarte:

da41275975df712a3040602cf20169b6.png

<START>后面概率最大的两个词是"I"和"he",概率的log值分别为-0.9和-0.7。我们就取这两个词继续往下递归,"<start>he"再往后概率最大的两个词是"hit"和"struck",其概率的对数加上之前的-0.7,分别得到-1.7和-2.9,也就是"he hit"和"he struck"的得分分别为 -1.7 和-2.9 。同理,“I was" 和 "I got"的得分分别为 -1.6 和-1.8. 在这一步,又取得分最高的两句话"he hit"(-1.7)和"I was"(-1.6)往下递归,在此省略若干步骤…

迭代若干步之后得到:

213cbec2e681134a87f68804940afba0.png

取最终得分最高的那句话,“he hit me with a pie.”

42f0eca5dd0ecfaf9c3ce77f3b3317d0.png

「beam search 的终止条件:」

在 beam search中,不同的词语选择方法会导致在不同的时候出现<END>,所以每个hypothesis(翻译句子)的长度都不一样。当一个hypothesis最终预测出了<END>,就说明这句话已经预测完毕了,就可以把这句话先放在一边,然后再用beam search取搜索其他的句子。

所以,beam search的终止条件可以为:

  • 达到 timestep T

  • 获得了 n 个完整的hypothesis

「选择最好的hypothesis:」

对于一个需要我们翻译的外文句子,有若干个可能的hypothesis,对每个hypothesis都计算一个score:

6437abd3fbb4a8c0e4211bc0802bbb8e.png

但是这个有一个非常明显的问题!就是越长的句子score越低,因为连乘导致了概率值越来越小。当然解决这个问题的方法也很简单了,就是对每个句子的长度做一个平均。

但是,回忆之前的这个例子:

bcae3eb304161ec519e8bbadbb3060d9.png

我们好像没有对长度做平均对吧?这是因为每一步的句子长度都是一样的,平均和不平均没有任何区别。

1.3 seq2seq 的其他应用
  • 文本摘要 (长文本->短文本)

  • 对话(previous utterance->next utterance)

  • 代码生成(自然语言->python代码),当然这个太难了…

1.4 机器翻译评估

BLEU(Bilingual Evaluation Understudy) 比较了我们的机器翻译结果和人翻译的结果(ground truth)进行相似度计算。这里所谓的“相似度”就是用两者的1-gram, 2-gram… n-gram重合度来算的。

同时,别忘了对长度较短的翻译做一个惩罚,这是因为如果我们让翻译特别短,只翻译出那些特别确定的词,那么n-gram重合度一定高,但是这并不是一个很好的翻译!

BLEU评测很有效,但有的时候并不完善。这是因为对于一句外文,有很多可能的翻译方式,如果只用n-gram进行精确匹配,可能会导致一个原本很好的翻译评分很低。

1.5 机器翻译的困难
  • 对于一个我们根本没见过的词 (Out-of-vocabulary words, 「OOV」) 该如何翻译呢?(可以随机初始化input embedding,这样让decoder根据language model随便翻译一个;或者直接写下来; 或者使用sub-word model)

  • 训练集和测试集必须非常相似。(如果你用Wikipedia这种非常正式的语料库训练,再用人们在twitter上聊天做测试,效果一定不好)

  • 对于长文本的翻译比较困难

  • 某些语言平行语料非常少(如Thai-Eng)

  • 训练语料的一些bias也会被机器翻译算法学去,导致在翻译测试的时候会体现出这种bias。

例如:

49b05936ebc2e23404c95a240d28a667.png

在训练语料里面she比较有可能是nurse;he比较有可能是programmer,所以出现了&amp;amp;quot;性别歧视&amp;amp;quot;

再看一个更可怕的例子:

2220b88c3d8fe24efcda608b54ba6f2f.png

这是因为Somali-Eng平行语料库主要是基于《圣经》,所以在这里机器翻译算法只是在用language model生成一些随机的词语而已…

2. Attention机制

2.1 没有Attention会怎样?

e7388f798dbdebcdedd0dc030f31eb62.png

  • the orange box is the encoding of the 「whole」 source sentence, it need to capture 「all」 the information of the source sentence —> 「information bottleneck」! (too much pressure on this single vector to be a good representation of the source sentence.)

  • the target sentence just have one input, i.e. the orange box, so it does not provide 「location」 information.

所以,attention机制就是为了解决这种「information bottleneck」的问题才引入的。

2.2 Attention的核心思想

on each step of the decoder, use 「direct connection」 to the encoder to focus on a particular part of the source sentence.

「step1.」 对于decoder的每一个timestep t,都计算它和encoder的每一步的点乘,作为score。之后再把这些score做softmax,变成概率分布。可以看到,第一个柱子最高,说明我们在翻译<start>的时候,需要格外注意source sentence的第一个位置。

f4c19313a3a399a1b5beeca33d722cda.png

用概率分布去乘以encoder每一步的hidden state,得到一个加权的source sentence表示:

**step2:**用概率分布去乘以encoder每一步的hidden state,得到一个加权的source sentence表示。

10bf3f44456f6227d11f44693929713a.png

「step3:「之后,把source sentence的表示和decoderRNN每一步的hidden state」拼接」在一起,得到一个长向量。然后再经过一层全连接网络,得到整个词汇表的概率分布。取一个argmax即得到这一步的预测值。

ee244a3d779ef7c3cff1e6015faa44b6.png

52cefdce11b6c06f3ae664548af98c5e.png

.... ....

59b961b6b8e0c5509a57330dda4b03c2.png

2.3 Attention的好处
  • Attention大大提高了机器翻译的表现

    • Decoder在每一步都更加关注源语言的不同部分

  • Attention 解决了bottleneck problem

    • Decoder 可以直接关注到源语言的每个词,而不需要用一个向量来表示源句子

  • Attention 可以缓解梯度消失问题

    • 类似skip-connection

  • Attention 增加了可解释性

    • 可以直观的看出来decoder在每一步更关注源语言的哪些部分

d0d7f27051e24288ba1d9fb3c3c4959c.png

Attention helps us get alignment for FREE!

一起交流

想和你一起学习进步!『NewBeeNLP』目前已经建立了多个不同方向交流群(机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 / 等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定要备注信息才能通过)

9d5828a047d05aec5118a7fa5540f362.png

本文参考资料

[1]

cs224n-2019-lecture08-nmt: https://web.stanford.edu/class/cs224n/slides/cs224n-2019-lecture08-nmt.pdf

END -

c6c3ddfc4f47f6e7d6bbaeda98c2a216.png

4c9941533a5c54a0ec0a4dbcf4705dcb.png

2021最新 北京互联网公司

2021-10-24

f1d107da67012fc21c01bdf3f7efe45b.png

『优势特征知识蒸馏』在淘宝推荐中的应用

2021-10-22

9aece9dbe53112edb38c2d08ffed5935.png

【NLP保姆级教程】手把手带你RNN文本分类(附代码)

2021-10-22

f49047540b73077c8f14ab5365db270a.png

预训练模型,NLP的版本答案!

2021-10-20

5dfa636d8ac60a11eb7aaac4bb4d97e4.png

051c29222fd9a90faaecd828f2c99524.gif

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/869994?site
推荐阅读
相关标签
  

闽ICP备14008679号