当前位置:   article > 正文

Transformer_transformer解码器的输入

transformer解码器的输入

模型

Transformer是由编码器和解码器组成。Transformer的编码器和解码器是基于自注意力的模块叠加而成的,源(输入)序列和目标(输出)序列的嵌入(embedding)表示将加上位置编码(positional encoding),再分别输入到编码器和解码器中。

在此之前已经描述并实现了基于缩放点积多头注意力和位置编码。接下来将实现Transformer模型的剩余部分。

  1. import math
  2. import pandas as pd
  3. from mxnet import autograd, np, npx
  4. from mxnet.gluon import nn
  5. from d2l import mxnet as d2l
  6. npx.set_np()

基于位置的前馈网络

基于位置的前馈网络对序列中的所有位置的表示进行变换时使用的是同一个多层感知机(MLP),这就是称前馈网络是基于位置的(positionwise)的原因。在下面的实现中,输入X的形状(批量大小,时间步数或序列长度,隐单元数或特征维度)将被一个两层的感知机转换成形状为(批量大小,时间步数,ffn_num_outputs)的输出张量。

  1. #@save
  2. class PositionWiseFFN(nn.Block):
  3. """基于位置的前馈网络"""
  4. def __init__(self, ffn_num_hiddens, ffn_num_outputs, **kwargs):
  5. super(PositionWiseFFN, self).__init__(**kwargs)
  6. self.dense1 = nn.Dense(ffn_num_hiddens, flatten=False,
  7. activation='relu')
  8. self.dense2 = nn.Dense(ffn_num_outputs, flatten=False)
  9. def forward(self, X):
  10. return self.dense2(self.dense1(X))

下面的例子显示,改变张量的最里层维度的尺寸,会改变成基于位置的前馈网络的输出尺寸。因为用同一个多层感知机对所有位置上的输入进行变换,所以当所有这些位置的输入相同时,它们的输出也是相同的。

  1. ffn = PositionWiseFFN(4, 8)
  2. ffn.initialize()
  3. ffn(np.ones((2, 3, 4)))[0]

残差连接和层规范化

层规范化和批量规范化的目标相同,但层规范化是基于特征维度进行规范化。尽管批量规范化在计算机视觉中被广泛应用,但在自然语言处理任务中(输入通常是变长序列)批量规范化通常不如层规范化的效果好。

以下代码对比不同维度的层规范化和批量规范化的效果。

  1. ln = nn.LayerNorm()
  2. ln.initialize()
  3. bn = nn.BatchNorm()
  4. bn.initialize()
  5. X = np.array([[1, 2], [2, 3]])
  6. # 在训练模式下计算X的均值和方差
  7. with autograd.record():
  8. print('层规范化:', ln(X), '\n批量规范化:', bn(X))

现在可以使用残差连接和层规范化来实现AddNorm类。暂退法也被作为正则化方法使用。

  1. #@save
  2. class AddNorm(nn.Block):
  3. """残差连接后进行层规范化"""
  4. def __init__(self, dropout, **kwargs):
  5. super(AddNorm, self).__init__(**kwargs)
  6. self.dropout = nn.Dropout(dropout)
  7. self.ln = nn.LayerNorm()
  8. def forward(self, X, Y):
  9. return self.ln(self.dropout(Y) + X)

残差连接要求两个输入的形状相同,以便加法操作后输出张量的形状相同。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/389167
推荐阅读
相关标签
  

闽ICP备14008679号