当前位置:   article > 正文

Transformer代码讲解_transformer分类模型代码

transformer分类模型代码

Transformer代码讲解

原论文:https://arxiv.org/pdf/1706.03762v5.pdf
本文将从以下部分进行讲解:

一、Transformer结构展开图

   1.原图

   2.结构展开图

二、Transformer代码

   1.数据预处理

   2. 代码拆分

      2.1 positional encoding

      2.2 pad mask

      2.3 subsequence mask

      2.4 ScaledDotProductAttention(计算 context vector

      2.5 multiheadattention

      2.6 feedforward layer

      2.7 encoder layer

      2.8 encoder

      2.9 decoder layer

      2.10 decoder

   3. transformer

   4. 模型、损失函数、优化器

   5. 训练

   6. 测试

正文:

一、Transformer结构展开图

  1.原图

1.在原论文中N=6,也就是分别有6个Encoder和Decoder。
在这里插入图片描述
2.原论文每一个Decoder的enc_inputs都是最后一个Encoder的输出。如下图:
在这里插入图片描述

  2.结构展开图

将N=6带入到上图中,得到Transformer结构展开图。原论文每一个Decoder的enc_inputs都是最后一个Encoder的输出。如下图:
在这里插入图片描述

二、Transformer代码

  1.数据预处理

  2. 代码拆分

      2.1 positional encoding

      2.2 pad mask(针对句子不够长,加了 pad,因此需要对 pad 进行 mask,保证张量中pad的位置的数字在计算中,对结果不会产生影响)

      2.3 subsequence mask(Decoder input 不能看到未来时刻单词信息,因此需要 mask)

      2.4 ScaledDotProductAttention(计算 context vector)

      2.5 multiheadattention

      2.6 feedforward layer

      2.7 encoder layer

      2.8 encoder

      2.9 decoder layer

      2.10 decoder

  3. transformer

  4. 模型、损失函数、优化器

  5. 训练

  6. 测试

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/92898
推荐阅读
相关标签
  

闽ICP备14008679号