当前位置:   article > 正文

【深度学习】BERT是什么?怎么玩的?_bert rnn

bert rnn

RNN

也是一种Seq2Seq网络
RNN
这种RNN就不能并行运算,且对于长句子会造成损失遗忘或者梯度爆炸

Transfomer

Transformer由且仅由self-Attenion和Feed Forward Neural Network组成。一个基于Transformer的可训练的神经网络可以通过堆叠Transformer的形式进行搭建,作者的实验是通过搭建编码器和解码器各6层,总共12层的Encoder-Decoder,并在机器翻译中取得了BLEU值得新高。

解决的问题有2个:

  • 1)并行计算要求
  • 2) 解决RNN中 对远距离的词记忆效果弱

总体结构

结构

Transformer的编码部分是由6个编码器,同样解码有6个解码器组成
Transformer
每个编码器的结构是先做一个self-attention,得到attention z,然后用一个全连接FNN对z的降维;输入的序列,先流入第一层self-attention,计算出当前词与其他单词的

Encoder

Encoder
上述主要会用到两个网络部分:1)self-attention计算序列中的token与其他token的attetion值; 2) FNN 全连接层,两层第一层ReLU,第二层线性激活函数

encoder中的详细过程

Decoder

Decoder
Decoder结构
decoder是一个自回归的网络,根据前面的token,计算后面的token;

  • 1)Self-Attention:当前翻译和已经翻译的前文之间的关系;
  • 2)Encoder-Decnoder Attention:当前翻译和编码的特征向量之间的关系。

计算过程

Self-attention

1)把每一个词编程词向量,文章用的是Xi是512维的,而乘出来后的q,k,v这些新向量是64维的。这样做的目的是可以持续计算多头
计算Q、K、V
当得到Q、K、V三个矩阵时,便可以计算每个词与其他词的得分;这里用一个点积运算,可以求出某两个词的相关性;

2)计算socre
计算score
softmax得分表示出每一个单词在此位置的分量,比如thinking在这个句子中对machine只占0.12

5) 将每个词的value*softmax得分,凭直观可以看出只关注那些我们注意的词,而drop-out那些无关的词(只需要乘以一个足够小的数)

6)把softmax*v加起来便可得到当前词对于整个序列而言的attetion
attention

关于计算attention进一步解释

要把分值转换成一个概率,所以这里用到一个softmax,便可以得到一个词与其他词的分值,然后与其他词value做点积,便可得到self-attention的值
self-attention
这个的最后用softmax输出与z做一个点积运算 求sum,单个词的对其他词的attention,然后在乘以V(实际特征信息)

attention计算公式

多头机制multi-headed

  • 1)让模型可以关注到不同的位置,要知道每一个attention都是体现当前词对整个序列的影响,当我的W不同的时候,得到的Q,K,V都不同,直接计算出来的attention都不一样;
  • 2)可以得到attetion 层多种不同的子空间;
    多头
    多头机制
    多头计算过程详细解释
    首先会有多个W*矩阵去跟词嵌入矩阵做乘积,得到QKV,然后分别计算attention,拿到多个头的注意力,最后拼接成一个大的Z,在经过FF网络输出每个词经过模型后的embed
    多头attention计算过程

位置信息表达

词的位置会产生影响,所以add一个vector到输入的嵌入层,引入的这个vector要能够表达出整个序列词的顺序,还有不同词之间的distance信息;
位置信息

BERT

结构

BERT其实就是transformer的编码器部分,其结构如下所示,首先是embedding层,分成三个部分,词嵌入、位置潜入、token类型嵌入;

embeddings.word_embeddings.weight torch.Size([173347, 768]) 
embeddings.position_embeddings.weight torch.Size([512, 768]) 
embeddings.token_type_embeddings.weight torch.Size([2, 768]) 
embeddings.LayerNorm.weight torch.Size([768]) 
embeddings.LayerNorm.bias torch.Size([768])

encoder.layer.0.attention.self.query.weight torch.Size([768, 768]) 
encoder.layer.0.attention.self.query.bias torch.Size([768]) 
encoder.layer.0.attention.self.key.weight torch.Size([768, 768]) 
encoder.layer.0.attention.self.key.bias torch.Size([768]) 
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768]) 
encoder.layer.0.attention.output.dense.weight torch.Size([768, 768]) 
encoder.layer.0.attention.output.dense.bias torch.Size([768]) 
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768]) 
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768]) 
encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768]) 
encoder.layer.0.intermediate.dense.bias torch.Size([3072]) 
encoder.layer.0.output.dense.weight torch.Size([768, 3072]) 
encoder.layer.0.output.dense.bias torch.Size([768]) 
encoder.layer.0.output.LayerNorm.weight torch.Size([768]) 
encoder.layer.0.output.LayerNorm.bias torch.Size([768]) 


encoder.layer.1.attention.self.query.weight torch.Size([768, 768]) 
encoder.layer.1.attention.self.query.bias torch.Size([768]) 
encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
encoder.layer.1.attention.self.key.bias torch.Size([768])
encoder.layer.1.attention.self.value.weight torch.Size([768, 768]) 
encoder.layer.1.attention.self.value.bias torch.Size([768])
encoder.layer.1.attention.output.dense.weight torch.Size([768, 768]) 
encoder.layer.1.attention.output.dense.bias torch.Size([768]) 
encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768]) 
encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768]) 
encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.1.intermediate.dense.bias torch.Size([3072]) encoder.layer.1.output.dense.weight torch.Size([768, 3072]) encode
r.layer.1.output.dense.bias torch.Size([768]) encoder.layer.1.output.LayerNorm.weight torch.Size([768]) encoder.layer.1.output.LayerNorm.bias torch.Size([768]) encoder.layer.2.attention.self.query.weight torch.Size([768, 768]) encoder.layer.2.attention.self.query.bias torch.Size([768]) 

encoder.layer.2.attention.self.key.weight torch.Size([768, 768]) encoder.layer.2.attention.self.key.bias torch.Size([768]) encoder.layer.2.attention.self.value.weight torch.Size([768, 768]) encoder.layer.2.attention.self.value.bias torch.Size([768]) encoder.layer.2.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.2.attention.output.dense.bias torch.Size([768]) encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.2.intermediate.dense.bias torch.Size([3072]) encoder.layer.2.output.dense.weight torch.Size([768, 3072]) encoder.layer.2.output.dense.bias torch.Size([768]) encoder.layer.2.output.LayerNorm.weight torch.Size([768]) encoder.layer.2.output.LayerNorm.bias torch.Size([768]) encoder.layer.3.attention.self.query.weight torch.Size([768, 768]) encoder.layer.3.attention.self.query.bias torch.Size([768]) encoder.layer.3.attention.self.key.weight torch.Size([768, 768]) encoder.layer.3.attention.self.key.bias torch.Size([768]) encoder.layer.3.attention.self.value.weight torch.Size([768, 768]) encoder.layer.3.attention.self.value.bias torch.Size([768]) encoder.layer.3.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.3.attention.output.dense.bias torch.Size([768]) encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.3.intermediate.dense.bias torch.Size([3072]) encoder.layer.3.output.dense.weight torch.Size([768, 3072]) encoder.layer.3.output.dense.bias torch.Size([768]) encoder.layer.3.output.LayerNorm.weight torch.Size([768]) encoder.layer.3.output.LayerNorm.bias torch.Size([768]) encoder.layer.4.attention.self.query.weight torch.Size([768, 768]) encoder.layer.4.attention.self.query.bias torch.Size([768]) encoder.layer.4.attention.self.key.weight torch.Size([768, 768]) encoder.layer.4.attention.self.key.bias torch.Size([768]) encoder.layer.4.attention.self.value.weight torch.Size([768, 768]) encoder.layer.4.attention.self.value.bias torch.Size([768]) encoder.layer.4.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.4.attention.output.dense.bias torch.Size([768]) encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.4.intermediate.dense.bias torch.Size([3072]) encoder.layer.4.output.dense.weight torch.Size([768, 3072]) encoder.layer.4.output.dense.bias torch.Size([768]) encoder.layer.4.output.LayerNorm.weight torch.Size([768]) encoder.layer.4.output.LayerNorm.bias torch.Size([768]) encoder.layer.5.attention.self.query.weight torch.Size([768, 768]) encoder.layer.5.attention.self.query.bias torch.Size([768]) encoder.layer.5.attention.self.key.weight torch.Size([768, 768]) encoder.layer.5.attention.self.key.bias torch.Size([768]) encoder.layer.5.attention.self.value.weight torch.Size([768, 768]) encoder.layer.5.attention.self.value.bias torch.Size([768]) encoder.layer.5.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.5.attention.output.dense.bias torch.Size([768]) encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.5.intermediate.dense.bias torch.Size([3072]) encoder.layer.5.output.dense.weight torch.Size([768, 3072]) encoder.layer.5.output.dense.bias torch.Size([768]) encoder.layer.5.output.LayerNorm.weight torch.Size([768]) encoder.layer.5.output.LayerNorm.bias torch.Size([768]) encoder.layer.6.attention.self.query.weight torch.Size([768, 768]) encoder.layer.6.attention.self.query.bias torch.Size([768]) encoder.layer.6.attention.self.key.weight torch.Size([768, 768]) encoder.layer.6.attention.self.key.bias torch.Size([768]) encoder.layer.6.attention.self.value.weight torch.Size([768, 768]) encoder.layer.6.attention.self.value.bias torch.Size([768]) encoder.layer.6.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.6.attention.output.dense.bias torch.Size([768]) encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.6.intermediate.dense.bias torch.Size([3072]) encoder.layer.6.output.dense.weight torch.Size([768, 3072]) encoder.layer.6.output.dense.bias torch.Size([768]) encoder.layer.6.output.LayerNorm.weight torch.Size([768]) encoder.layer.6.output.LayerNorm.bias torch.Size([768]) encoder.layer.7.attention.self.query.weight torch.Size([768, 768]) encoder.layer.7.attention.self.query.bias torch.Size([768]) encoder.layer.7.attention.self.key.weight torch.Size([768, 768]) encoder.layer.7.attention.self.key.bias torch.Size([768]) encoder.layer.7.attention.self.value.weight torch.Size([768, 768]) encoder.layer.7.attention.self.value.bias torch.Size([768]) encoder.layer.7.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.7.attention.output.dense.bias torch.Size([768]) encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.7.intermediate.dense.bias torch.Size([3072]) encoder.layer.7.output.dense.weight torch.Size([768, 3072]) encoder.layer.7.output.dense.bias torch.Size([768]) encoder.layer.7.output.LayerNorm.weight torch.Size([768]) encoder.layer.7.output.LayerNorm.bias torch.Size([768]) encoder.layer.8.attention.self.query.weight torch.Size([768, 768]) encoder.layer.8.attention.self.query.bias torch.Size([768]) encoder.layer.8.attention.self.key.weight torch.Size([768, 768]) encoder.layer.8.attention.self.key.bias torch.Size([768]) encoder.layer.8.attention.self.value.weight torch.Size([768, 768]) encoder.layer.8.attention.self.value.bias torch.Size([768]) encoder.layer.8.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.8.attention.output.dense.bias torch.Size([768]) encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.8.intermediate.dense.bias torch.Size([3072]) encoder.layer.8.output.dense.weight torch.Size([768, 3072]) encoder.layer.8.output.dense.bias torch.Size([768]) encoder.layer.8.output.LayerNorm.weight torch.Size([768]) encoder.layer.8.output.LayerNorm.bias torch.Size([768]) encoder.layer.9.attention.self.query.weight torch.Size([768, 768]) encoder.layer.9.attention.self.query.bias torch.Size([768]) encoder.layer.9.attention.self.key.weight torch.Size([768, 768]) encoder.layer.9.attention.self.key.bias torch.Size([768]) encoder.layer.9.attention.self.value.weight torch.Size([768, 768]) encoder.layer.9.attention.self.value.bias torch.Size([768]) encoder.layer.9.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.9.attention.output.dense.bias torch.Size([768]) encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.9.intermediate.dense.bias torch.Size([3072]) encoder.layer.9.output.dense.weight torch.Size([768, 3072]) encoder.layer.9.output.dense.bias torch.Size([768]) encoder.layer.9.output.LayerNorm.weight torch.Size([768]) encoder.layer.9.output.LayerNorm.bias torch.Size([768]) encoder.layer.10.attention.self.query.weight torch.Size([768, 768]) encoder.layer.10.attention.self.query.bias torch.Size([768]) encoder.layer.10.attention.self.key.weight torch.Size([768, 768]) encoder.layer.10.attention.self.key.bias torch.Size([768]) encoder.layer.10.attention.self.value.weight torch.Size([768, 768]) encoder.layer.10.attention.self.value.bias torch.Size([768]) encoder.layer.10.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.10.attention.output.dense.bias torch.Size([768]) encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.10.intermediate.dense.bias torch.Size([3072]) encoder.layer.10.output.dense.weight torch.Size([768, 3072]) encoder.layer.10.output.dense.bias torch.Size([768]) encoder.layer.10.output.LayerNorm.weight torch.Size([768]) encoder.layer.10.output.LayerNorm.bias torch.Size([768]) encoder.layer.11.attention.self.query.weight torch.Size([768, 768]) encoder.layer.11.attention.self.query.bias torch.Size([768]) encoder.layer.11.attention.self.key.weight torch.Size([768, 768]) encoder.layer.11.attention.self.key.bias torch.Size([768]) encoder.layer.11.attention.self.value.weight torch.Size([768, 768]) encoder.layer.11.attention.self.value.bias torch.Size([768]) encoder.layer.11.attention.output.dense.weight torch.Size([768, 768]) encoder.layer.11.attention.output.dense.bias torch.Size([768]) encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768]) encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768]) encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768]) encoder.layer.11.intermediate.dense.bias torch.Size([3072]) encoder.layer.11.output.dense.weight torch.Size([768, 3072]) encoder.layer.11.output.dense.bias torch.Size([768]) encoder.layer.11.output.LayerNorm.weight torch.Size([768]) encoder.layer.11.output.LayerNorm.bias torch.Size([768]) pooler.dense.weight torch.Size([768, 768]) pooler.dense.bias torch.Size([768])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

说明

按照官网给出的 BERT 技术,一方面因为是 unsupervised ,另一方面是因为用于预训练 NLP 的深度双向系统。预训练表示既可以是 context-free 也可是 contexual

  • context-free: word2vec 和 glove
  • 上下文相关的模型:BERT、ELMO;例如 I made a bank deposit中 对bank的理解 bert会看bank的前后。
    而其他模型要么只看左边,要么只看右边 pre-training & fine-tunings 难的是对词的编码,就是transfomer的encoder部分;

如何训练bert

MLM
NSP
BERT 是一种端到端的模型
BERT
可以讲 BERT 模型已经训练好的参数加载进来,然后直接在下游任务进行微调是目前主流做法。

参考资料

  • https://zhuanlan.zhihu.com/p/48508221
  • http://jalammar.github.io/illustrated-transformer/
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/823125
推荐阅读
相关标签
  

闽ICP备14008679号