当前位置:   article > 正文

攻克 Transformer 之代码精讲+实战,以及《变形金刚》结构_变形金刚transformer 编码

变形金刚transformer 编码

Transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层。尽管transformer最初是应⽤于在⽂本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语⾔、视觉、语音和强化学习领域。

本文章进行实战:利用Transformer把英语翻译成法语(Pytorch框架)

Transformer系列往期博客链接直达:攻克 Transformer & 注意力机制的查询、键和值 & 有无参数的Nadaraya-Watson核回归
攻克 Transformer && 评分函数(加性注意力、缩放点积注意力)

目录

(1)多头注意力

(2)基于位置的前馈网络

(3)残差网络后进行层规范化

(4)transformer编码器块

(5)transformer编码器

(6)transformer解码器块

(7)transformer解码器

(8)训练

(9)测试

(10)可视化权重

(11)总结

(12)完整代码


Transformer作为编码器-解码器架构的⼀个实例,其整体架构图在图1 中展示。

正如所见到的,transformer是由编码器和解码器组成的;transformer的编码器和解码器是基于⾃注意力的模块叠加而成的;源序列和目标序列的嵌⼊层先加上位置编码(positional encoding),再分别输⼊到编码器和解码器中。

图1 Transformer架构

=============Transformer 编码器=============

从宏观角度来看,transformer的编码器是由多个相同的层叠加而成的,每个层都有两个子层。第⼀个⼦层是多头⾃注意力汇聚;第⼆个子层是基于位置的前馈网络。编码器由n个编码块组成,前n-1个编码块的输出作为编码块的输入,具体来说,在计算编码器的⾃注意力时,查询、键和值都来⾃前⼀个编码器层的输出。每个子层都采用了残差连接。在残差连接的加法计算之后,紧接着应用层规范化。因此,输入序列对应的每个位置,transformer 编码器都将输出⼀个 d 维表示向量。

=============Transformer 解码器=============

Transformer解码器也是由多个相同的层叠加而成的,并且层中使用了残差连接和层规范化。除了编码器中描述的两个子层之外,解码器还在这两个子层之间插⼊了第三个子层,称为编码器—解码器注意力层在编码器-解码器注意力中,查询来自前⼀个解码器层的输出,而键和值来自整个编码器的输出。在解码器自注意力中,查询、键和值都来自上⼀个解码器层的输出。但是,解码器中的每个位置只能考虑该位置之前的所有位置。这种掩蔽注意力保留了自回归(auto-regressive)属性,确保预测仅依赖于已生成的输出词元。

补充:自回归模型_百度百科

自回归模型(英语:Autoregressive model,简称AR模型),是统计上一种处理时间序列的方法,用同一变数例如x的之前各期,亦即x1至xt-1来预测本期xt的表现,并假设它们为一线性关系。因为这是从回归分析中的线性回归发展而来,只是不用x预测y,而是用x预测 x(自己);所以叫做自回归

Transformer“编码器-解码器”网络模型结构:
EncoderDecoder(

  编码器
  (encoder): TransformerEncoder(

    嵌入层
    (embedding): Embedding(184, 32)

    位置编码
    (pos_encoding): PositionalEncoding(

    舍弃
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (blks): Sequential(

     第一个编码器块
      (block0): EncoderBlock(

          多头注意力
        (attention): MultiHeadAttention(

           缩放点击注意力
          (attention): DotProductAttention(

            舍弃
            (dropout): Dropout(p=0.1, inplace=False)
          )

          全连接层(线性变换)
          (W_q): Linear(in_features=32, out_features=32, bias=False)
          (W_k): Linear(in_features=32, out_features=32, bias=False)
          (W_v): Linear(in_features=32, out_features=32, bias=False)
          (W_o): Linear(in_features=32, out_features=32, bias=False)
        )

        第一个 加&规范化
        (addnorm1): AddNorm(
          (dropout): Dropout(p=0.1, inplace=False)
          (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )

        基于位置的前馈神经网络
        (ffn): PositionWiseFFN(
          (dense1): Linear(in_features=32, out_features=64, bias=True)
          (relu): ReLU()  激活函数
          (dense2): Linear(in_features=64, out_features=32, bias=True)
        )

        第二个 加&规范化
        (addnorm2): AddNorm(
          (dropout): Dropout(p=0.1, inplace=False)
          (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
      )

     第二个编码器块
      (block1): EncoderBlock(
        (attention): MultiHeadAttention(
          (attention): DotProductAttention(
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (W_q): Linear(in_features=32, out_features=32, bias=False)
          (W_k): Linear(in_features=32, out_features=32, bias=False)
          (W_v): Linear(in_features=32, out_features=32, bias=False)
          (W_o): Linear(in_features=32, out_features=32, bias=False)
        )
        (addnorm1): AddNorm(
          (dropout): Dropout(p=0.1, inplace=False)
          (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (ffn): PositionWiseFFN(
          (dense1): Linear(in_features=32, out_features=64, bias=True)
          (relu): ReLU()
          (dense2): Linear(in_features=64, out_features=32, bias=True)
        )
        (addnorm2): AddNorm(
          (dropout): Dropout(p=0.1, inplace=False)
          (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        ))))

解码器同编码器类似

(1)多头注意力

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意⼒机制学习到不同的⾏为,然后将不同的⾏为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同子空间表示可能是有益的。

为此,与其只使⽤单独⼀个注意力汇聚,我们可以⽤独⽴学习得到的h组不同的线性投影(linear projections) 来变换查询、键和值。然后,这h组变换后的查询、键和值将并⾏地送到注意⼒汇聚中。最后,将这h个注意力汇聚的输出拼接在⼀起,并且通过另⼀个可以学习的线性投影进⾏变换,以产⽣最终输出,这种设计被称为多头注意力(multihead attention)。对于h个注意力汇聚输出,每⼀个注意⼒汇聚都被称作⼀个头。

图2 展示了使用全连接层来实现可学习的线性变换的多头注意力。

 图2 多头注意力:多个头连结然后线性变换 

多头注意力代码:(此处选择缩放点积注意力作为注意力汇聚)

  1. class MultiHeadAttention(nn.Module):
  2. """多头注意力"""
  3. # 100, 100, 100, 100, 5, 0.5
  4. def __init__(self, key_size, query_size, value_size, num_hiddens,
  5. num_heads, dropout, bias=False, **kwargs):
  6. super(MultiHeadAttention, self).__init__(**kwargs)
  7. self.num_heads = num_heads
  8. self.DP_Attention = d2l.DotProductAttention(dropout) # 缩放点积注意力,舍弃50%的神经元参数
  9. # 输入样本的大小、输出样本的大小、偏置设置为False
  10. self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) # [100, 100]
  11. self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) # [100, 100]
  12. self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) # [100, 100]
  13. self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) # [100, 100]
  14. # [2, 4, 100], [2, 6, 100], [2, 6, 100], torch.tensor([3, 2])
  15. def forward(self, queries, keys, values, valid_lens):
  16. # queries,keys,values的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  17. # valid_lens 的形状:(batch_size,)或(batch_size,查询的个数)
  18. # 经过变换后,输出的 queries,keys,values 的形状:
  19. # (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
  20. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  21. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  22. values = transpose_qkv(self.W_v(values), self.num_heads)
  23. # print(queries)
  24. if valid_lens is not None:
  25. # 在轴0,将第一项(标量或者矢量)复制num_heads次,
  26. # 然后如此复制第二项,然后诸如此类。
  27. valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
  28. # print(valid_lens) # tensor([3, 3, 3, 3, 3, 2, 2, 2, 2, 2])
  29. print("valid_lens:", valid_lens.size())
  30. # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
  31. # torch.Size([10, 4, 20])、torch.Size([10, 6, 20])、torch.Size([10, 6, 20])、torch.Size([10])
  32. output = self.DP_Attention(queries, keys, values, valid_lens)
  33. # output_concat的形状:(batch_size,查询的个数,num_hiddens)
  34. output_concat = transpose_output(output, self.num_heads)
  35. return self.W_o(output_concat)

为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。具体来说, transpose_output函数反转了transpose_qkv函数的操作。

  1. # [2, 4, 100]/[2, 6, 100], 5
  2. def transpose_qkv(X, num_heads):
  3. """为了多注意力头的并行计算而变换形状"""
  4. # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  5. # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
  6. # reshape(-1):首先把张量中的所有元素平铺,然后在变形成指定的形状
  7. X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # 2*4*100/2*4*5 = 20, -1就代表20
  8. # print(X.size())
  9. # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
  10. X = X.permute(0, 2, 1, 3) # 更改矩阵形状 torch.Size([2, 5, 4, 20])
  11. # print(X.size())
  12. # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
  13. output = X.reshape(-1, X.shape[2], X.shape[3]) # 2*5*4*20/4*20 = 10, -1就代表10
  14. print("transpose_qkv:", output.size())
  15. return output
  16. # [10, 4, 20], 5
  17. def transpose_output(X, num_heads):
  18. """逆转transpose_qkv函数的操作"""
  19. # print(X.size())
  20. X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) # [2, 5, 4, 20]
  21. X = X.permute(0, 2, 1, 3) # [2, 4, 5, 20]
  22. output = X.reshape(X.shape[0], X.shape[1], -1) # [2, 4, 100]
  23. print("transpose_output:", output.size())
  24. return output

 最后,使用键和值相同的例子,测试我们编写的MultiHeadAttention类。

  1. num_hiddens, num_heads = 100, 5
  2. attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
  3. num_hiddens, num_heads, 0.5)
  4. attention.eval()
  5. batch_size, num_queries = 2, 4
  6. num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
  7. X = torch.ones((batch_size, num_queries, num_hiddens)) # [2, 4, 100]
  8. Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) # [2, 6, 100]
  9. print("result:", attention(X, Y, Y, valid_lens).shape) # torch.Size([2, 4, 100])

 输出:

transpose_qkv: torch.Size([10, 4, 20])
transpose_qkv: torch.Size([10, 6, 20])
transpose_qkv: torch.Size([10, 6, 20])
valid_lens: torch.Size([10])
transpose_output: torch.Size([2, 4, 100])
result: torch.Size([2, 4, 100])

(2)基于位置的前馈网络

基于位置的前馈⽹络对序列中的所有位置的表⽰进⾏变换时使用的是同⼀个多层感知机(MLP),这就是称前馈⽹络是基于位置的原因。

下⾯的例⼦显示,改变张量的最⾥层维度的尺⼨,会改变成基于位置的前馈⽹络的输出尺⼨。因为⽤同⼀个多层感知机对所有位置上的输⼊进⾏变换,所以当所有这些位置的输⼊相同时,它们的输出也是相同的。

  1. class PositionWiseFFN(nn.Module):
  2. """基于位置的前馈网络"""
  3. def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
  4. **kwargs):
  5. super(PositionWiseFFN, self).__init__(**kwargs)
  6. self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
  7. self.relu = nn.ReLU()
  8. self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
  9. def forward(self, X):
  10. return self.dense2(self.relu(self.dense1(X)))
  11. # demo
  12. # [2, 3, 4] * [4, 4] * [4, 8] = [2, 3, 8]
  13. ffn = PositionWiseFFN(4, 4, 8)
  14. ffn.eval()
  15. print("基于位置的前馈网络:")
  16. print(ffn(torch.ones((2, 3, 4)))[0])

输出:

基于位置的前馈网络:
tensor([[-0.1039,  0.6010, -0.7257,  0.0406, -0.2380, -0.5354, -1.0672,  0.3957],
            [-0.1039,  0.6010, -0.7257,  0.0406, -0.2380, -0.5354, -1.0672,  0.3957],
            [-0.1039,  0.6010, -0.7257,  0.0406, -0.2380, -0.5354, -1.0672,  0.3957]],
            grad_fn=<SelectBackward0>) 

(3)残差网络后进行层规范化

下面代码由残差连接和紧随其后的层规范化组成,两者都是构建有效的深度架构的关键。

  1. class AddNorm(nn.Module):
  2. """残差连接后进行层规范化"""
  3. def __init__(self, normalized_shape, dropout, **kwargs):
  4. super(AddNorm, self).__init__(**kwargs)
  5. self.dropout = nn.Dropout(dropout)
  6. self.ln = nn.LayerNorm(normalized_shape)
  7. def forward(self, X, Y):
  8. return self.ln(self.dropout(Y) + X)
  9. # demo
  10. add_norm = AddNorm([3, 4], 0.5)
  11. add_norm.eval()
  12. print("残差连接后进行层规范化:")
  13. print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)

输出:

残差连接后进行层规范化:
torch.Size([2, 3, 4]) 

补充:

层规范化和批量规范化的⽬标相同,但层规范化是基于特征维度进⾏规范化。尽管批量规范化在计算机视觉中 被广泛应用,但在自然语言处理任务中(输入通常是变长序列)批量规范化通常不如层规范化的效果好。

以下代码对⽐不同维度的层规范化和批量规范化的效果。

  1. ln = nn.LayerNorm(2) # 层规范化
  2. bn = nn.BatchNorm1d(2) # 批标准化
  3. X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
  4. # 在训练模式下计算 X 的均值和方差
  5. print('layer norm:', ln(X), '\nbatch norm:', bn(X))

 (4)transformer编码器块

有了组成transformer编码器的基础组件,现在可以先实现编码器中的一个层。下⾯的EncoderBlock类包含两个⼦层:多头⾃注意力和基于位置的前馈⽹络,这两个子层都使用了残差连接和紧随的层规范化。

  1. class EncoderBlock(nn.Module):
  2. """transformer编码器块"""
  3. def __init__(self, key_size, query_size, value_size, num_hiddens,
  4. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  5. dropout, use_bias=False, **kwargs):
  6. super(EncoderBlock, self).__init__(**kwargs)
  7. # 多头注意力
  8. self.attention = d2l.MultiHeadAttention(key_size, query_size, value_size,
  9. num_hiddens, num_heads, dropout, use_bias)
  10. # 加 & 规范化(包含残差连接)
  11. self.addnorm1 = AddNorm(norm_shape, dropout)
  12. # 逐位前馈网络
  13. self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
  14. # 加 & 规范化(包含残差连接)
  15. self.addnorm2 = AddNorm(norm_shape, dropout)
  16. # 多头注意力>>>加&规范化>>>逐位前馈网络>>>加&规范化
  17. def forward(self, X, valid_lens):
  18. Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
  19. return self.addnorm2(Y, self.ffn(Y))
  20. # demo
  21. X = torch.ones((2, 100, 24))
  22. valid_lens = torch.tensor([3, 2])
  23. encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
  24. encoder_blk.eval()
  25. print("transformer编码器块:")
  26. print(encoder_blk(X, valid_lens).shape)

输出:(输入的张量尺寸和输出的张量尺寸大小相同)

transformer编码器块:
torch.Size([2, 100, 24]) 

(5)transformer编码器

在实现下⾯的transformer编码器的代码中,我们堆叠了num_layers个EncoderBlock类的实例。由于我们使⽤的是值范围在−1和1之间的固定位置编码,因此通过学习得到的输⼊的嵌⼊表⽰的值需要先乘以嵌⼊维度的平⽅根进⾏重新缩放,然后再与位置编码相加。

下面的代码指定了超参数来创建⼀个两层的transformer编码器。

  1. class TransformerEncoder(d2l.Encoder):
  2. """transformer编码器"""
  3. def __init__(self, vocab_size, key_size, query_size, value_size,
  4. num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
  5. num_heads, num_layers, dropout, use_bias=False, **kwargs):
  6. super(TransformerEncoder, self).__init__(**kwargs)
  7. self.num_hiddens = num_hiddens
  8. # 嵌入层
  9. self.embedding = nn.Embedding(vocab_size, num_hiddens)
  10. # 位置编码
  11. self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
  12. # 添加模块(num_layers个transformer编码器块)
  13. self.blks = nn.Sequential()
  14. for i in range(num_layers):
  15. self.blks.add_module("block"+str(i),
  16. EncoderBlock(key_size, query_size, value_size, num_hiddens,
  17. norm_shape, ffn_num_input, ffn_num_hiddens,
  18. num_heads, dropout, use_bias))
  19. def forward(self, X, valid_lens, *args):
  20. # 因为位置编码值在-1和1之间,因此嵌入值乘以嵌入维度的平方根进行缩放,然后再与位置编码相加。
  21. X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
  22. self.attention_weights = [None] * len(self.blks)
  23. for i, blk in enumerate(self.blks):
  24. X = blk(X, valid_lens)
  25. self.attention_weights[i] = blk.attention.attention.attention_weights
  26. return X
  27. # demo
  28. encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
  29. encoder.eval()
  30. print("transformer编码器:")
  31. print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)

输出:

transformer编码器:
torch.Size([2, 100, 24]) 

(6)transformer解码器块

transformer解码器也是由多个相同的层组成。在DecoderBlock类中实现的每个层包含了三个⼦层:解码器自注意力、“编码器-解码器”注意力和基于位置的前馈网络。这些⼦层也都被残差连接和紧随的层规范化围绕。

关于序列到序列模型(sequence-to-sequence model),训练阶段,其输出序列的所有位置(时间步)的词元都是已知的;然而,在预测阶段,其输出序列的词元是逐个生成的。因此,在任何解码器时间步中,只有⽣成的词元才能⽤于解码器的自注意力计算中。

为了在解码器中保留自回归的属性,其掩蔽自注意力设定了参数dec_valid_lens,以便任何查询都只会与解码器中所有已经生成词元的位置(即直到该查询位置为止)进⾏注意力计算。

  1. class DecoderBlock(nn.Module):
  2. """解码器中第i个块"""
  3. def __init__(self, key_size, query_size, value_size, num_hiddens,
  4. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  5. dropout, i, **kwargs):
  6. super(DecoderBlock, self).__init__(**kwargs)
  7. self.i = i
  8. # 多头注意力
  9. self.attention1 = d2l.MultiHeadAttention(
  10. key_size, query_size, value_size, num_hiddens, num_heads, dropout)
  11. # 加 & 规范化
  12. self.addnorm1 = AddNorm(norm_shape, dropout)
  13. # 多头注意力
  14. self.attention2 = d2l.MultiHeadAttention(
  15. key_size, query_size, value_size, num_hiddens, num_heads, dropout)
  16. self.addnorm2 = AddNorm(norm_shape, dropout)
  17. # 加 & 规范化
  18. self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
  19. num_hiddens)
  20. # 加 & 规范化
  21. self.addnorm3 = AddNorm(norm_shape, dropout)
  22. def forward(self, X, state):
  23. enc_outputs, enc_valid_lens = state[0], state[1]
  24. # 训练阶段,输出序列的所有词元都在同一时间处理,
  25. # 因此state[2][self.i]初始化为None。
  26. # 预测阶段,输出序列是通过词元一个接着一个解码的,
  27. # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
  28. if state[2][self.i] is None:
  29. key_values = X
  30. else:
  31. key_values = torch.cat((state[2][self.i], X), axis=1)
  32. state[2][self.i] = key_values
  33. if self.training:
  34. batch_size, num_steps, _ = X.shape
  35. # dec_valid_lens的开头:(batch_size,num_steps),
  36. # 其中每一行是[1,2,...,num_steps]
  37. dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
  38. else:
  39. dec_valid_lens = None
  40. # 多头注意力>>>加&规范化>>>多头注意力>>>加&规范化>>>逐位前馈网络>>>加&规范化
  41. X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
  42. Y = self.addnorm1(X, X2)
  43. # 编码器-解码器注意力。
  44. # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
  45. Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
  46. Z = self.addnorm2(Y, Y2)
  47. return self.addnorm3(Z, self.ffn(Z)), state
  48. # demo
  49. decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
  50. decoder_blk.eval()
  51. X = torch.ones((2, 100, 24))
  52. state = [encoder_blk(X, valid_lens), valid_lens, [None]]
  53. print("解码器中第i个块:")
  54. print(decoder_blk(X, state)[0].shape)

输出:

解码器中第i个块:
torch.Size([2, 100, 24]) 

(7)transformer解码器

 现在我们构建了由num_layers个DecoderBlock实例组成的完整的transformer解码器。最后,通过⼀个全连接层计算所有vocab_size个可能的输出词元的预测值。解码器的⾃注意力权重和编码器解码器注意力权重都被存储下来,⽅便日后可视化的需要。

  1. class TransformerDecoder(d2l.AttentionDecoder):
  2. """Transformer解码器"""
  3. def __init__(self, vocab_size, key_size, query_size, value_size,
  4. num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
  5. num_heads, num_layers, dropout, **kwargs):
  6. super(TransformerDecoder, self).__init__(**kwargs)
  7. self.num_hiddens = num_hiddens
  8. self.num_layers = num_layers
  9. # 嵌入层
  10. self.embedding = nn.Embedding(vocab_size, num_hiddens)
  11. # 位置编码
  12. self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
  13. # 添加模块(num_layers个解码器块)
  14. self.blks = nn.Sequential()
  15. for i in range(num_layers):
  16. self.blks.add_module("block"+str(i),
  17. DecoderBlock(key_size, query_size, value_size, num_hiddens,
  18. norm_shape, ffn_num_input, ffn_num_hiddens,
  19. num_heads, dropout, i))
  20. # 全连接层
  21. self.dense = nn.Linear(num_hiddens, vocab_size)
  22. def init_state(self, enc_outputs, enc_valid_lens, *args):
  23. return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
  24. def forward(self, X, state):
  25. # 位置编码+嵌入层
  26. X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
  27. # 定义二维列表,存放 解码器自注意力权重 和 “编码器=解码器”自注意力权重
  28. self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
  29. for i, blk in enumerate(self.blks):
  30. X, state = blk(X, state)
  31. # 解码器自注意力权重
  32. self._attention_weights[0][i] = blk.attention1.attention.attention_weights
  33. # “编码器-解码器”自注意力权重
  34. self._attention_weights[1][i] = blk.attention2.attention.attention_weights
  35. return self.dense(X), state # 最后把参数传入到全连接层之中
  36. def attention_weights(self):
  37. return self._attention_weights

(8)训练

依照transformer架构来实例化编码器-解码器模型。在这⾥,指定transformer的编码器和解码器都是2层, 都使⽤4头注意⼒。

为了进⾏序列到序列的学习,我们在“英语-法语”机器翻译数据集上训练transformer模型。

  1. # # 训练
  2. # 隐藏层层数,编码器/解码器块数量,舍弃比率,批量大小,num_steps
  3. num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
  4. # 学习率,迭代次数,设备选取
  5. lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
  6. # 前馈网络输入层,前馈网络隐藏层,注意力头的数量
  7. ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
  8. # 查询,键,值
  9. query_size, key_size, value_size = 32, 32, 32
  10. # AddNorm(规范化的维度)
  11. norm_shape = [32]
  12. # 加载文字数据集(英语翻译成法语)
  13. # train_iter:数据集信息
  14. # src_vocab:源(英语)词汇表
  15. # tgt_vocab:法语词汇表
  16. train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
  17. encoder = TransformerEncoder(
  18. len(src_vocab), key_size, query_size, value_size, num_hiddens,
  19. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  20. num_layers, dropout)
  21. # print("编码器返回的结果(Transformer编码器的结构):")
  22. # print(encoder)
  23. decoder = TransformerDecoder(
  24. len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
  25. ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
  26. # print("解码器返回的结果(Transformer解码器的结构):")
  27. # print(decoder)
  28. # Transformer“编码器-解码器”网络模型
  29. net = d2l.EncoderDecoder(encoder, decoder)
  30. print("Transformer“编码器-解码器”网络模型结构:")
  31. print(net)
  32. # 训练一个序列到序列的模型
  33. # 网络模型,数据集信息,学习率,迭代次数,法语词汇表(标志输出),设备选取
  34. d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

图3 Epoch-Loss曲线图 

数据集截图:

(9)测试

训练结束后,使⽤transformer模型将⼀些英语句⼦翻译成法语,并且计算它们的BLEU分数。

  • BLEU的全名为:bilingual evaluation understudy,即:双语互译质量评估辅助工具。它是用来评估机器翻译质量的工具。
  • BLEU的设计思想:机器翻译结果越接近专业人工翻译的结果,则越好。BLEU算法实际上就是在判断两个句子的相似程度。
  • 想知道一个句子翻译前后的表示是否意思一致,直接的办法是拿这个句子的标准人工翻译与机器翻译的结果作比较,如果它们是很相似的,说明我的翻译很成功。
  • 因此,BLUE将机器翻译的结果与其相对应的几个参考翻译作比较,算出一个综合分数。这个分数越高说明机器翻译得越好。注意BLEU算法是句子之间的比较,不是词组,也不是段落。
  1. # # 测试
  2. engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
  3. fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
  4. # 多维列表
  5. dec_attention_weight_seq = []
  6. for eng, fra in zip(engs, fras):
  7. # 网络模型、单个英语词汇、英语词汇集、法语词汇集、num_steps,设备选取
  8. translation, dec_attention_weight_seq = d2l.predict_seq2seq(
  9. net, eng, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=True)
  10. print(f'{eng} => {translation}, ',
  11. f'bleu {d2l.bleu(translation, fra, k=2):.3f}') # 评分函数(翻译结果同数据集相比较)
  12. enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads, -1, num_steps))
  13. print(enc_attention_weights.shape)

输出:

go .                => va !,                         bleu 1.000
i lost .            => je vous en prie .,      bleu 1.000
he's calm .    => il court .,                   bleu 0.000
i'm home .     => je suis chez moi .,    bleu 1.000
torch.Size([2, 4, 10, 10]) 

在编码器的⾃注意⼒中,查询和键都来⾃相同的输⼊序列。因为填充词元是不携带信息的,因此通过指定输⼊序列的有效⻓度可以避免查询与使⽤填充词元的位置计算注意力。接下来,将逐⾏呈现两层多头注意力的权重。每个注意⼒头都根据查询、键和值的不同的表⽰子空间来表⽰不同的注意⼒。 

  1. d2l.show_heatmaps(
  2. enc_attention_weights.cpu(), xlabel='Key positions',
  3. ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
  4. figsize=(7, 3.5))
  5. plt.show()

 图4 最后⼀个英语到法语的句子翻译的可视化transformer的注意力权重

(10)可视化权重

为了可视化解码器的⾃注意⼒权重和“编码器-解码器”的注意⼒权重,我们需要完成更多的数据操作⼯作。 例如,我们⽤零填充被掩蔽住的注意⼒权重。值得注意的是,解码器的⾃注意⼒权重和“编码器-解码器” 的注意⼒权重都有相同的查询:即以序列开始词元(beginning-of-sequence,BOS)打头,再与后续输出的词元共同组成序列。

  1. dec_attention_weights_2d = [head[0].tolist()
  2. for step in dec_attention_weight_seq
  3. for attn in step for blk in attn for head in blk]
  4. dec_attention_weights_filled = torch.tensor(
  5. pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
  6. dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
  7. dec_self_attention_weights, dec_inter_attention_weights = \
  8. dec_attention_weights.permute(1, 2, 3, 0, 4)
  9. print(dec_self_attention_weights.shape, dec_inter_attention_weights.shape)
  10. # 与编码器的自注意力的情况类似,通过指定输入序列的有效长度,
  11. # [输出序列的查询不会与输入序列中填充位置的词元进行注意力计算]。
  12. # Plusonetoincludethebeginning-of-sequencetoken
  13. d2l.show_heatmaps(
  14. dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
  15. xlabel='Key positions', ylabel='Query positions',
  16. titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
  17. d2l.show_heatmaps(
  18. dec_inter_attention_weights, xlabel='Key positions',
  19. ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
  20. figsize=(7, 3.5))

 1. 由于解码器⾃注意⼒的⾃回归属性,查询不会对当前位置之后的“键-值”对进⾏注意⼒计算。

 由上图可以观察到:查询位置之后(主对角线上方)的注意力权重为零,即没有进行注意力计算!

2. 与编码器的⾃注意⼒的情况类似,通过指定输⼊序列的有效⻓度,输出序列的查询不会与输⼊序列中填充位置的词元进⾏注意⼒计算。

由上图可以观察到:指定序列长度后,大于序列长度的位置的注意力权重为零,没有进行注意力计算!

 (11)总结

  • transformer是编码器-解码器架构的⼀个实践,尽管在实际情况中编码器或解码器可以单独使⽤。
  • 在transformer中,多头⾃注意力⽤于表⽰输⼊序列和输出序列,不过解码器必须通过掩蔽机制来保留自回归属性。
  • transformer中的残差连接和层规范化是训练⾮常深度模型的重要⼯具。
  • transformer模型中基于位置的前馈⽹络使⽤同⼀个多层感知机,作用是对所有序列位置的表示进行转换。

 (12)完整代码

  1. import math
  2. import pandas as pd
  3. import torch
  4. from matplotlib import pyplot as plt
  5. from torch import nn
  6. from d2l import torch as d2l
  7. class PositionWiseFFN(nn.Module):
  8. """基于位置的前馈网络"""
  9. def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
  10. **kwargs):
  11. super(PositionWiseFFN, self).__init__(**kwargs)
  12. self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
  13. self.relu = nn.ReLU()
  14. self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)
  15. def forward(self, X):
  16. return self.dense2(self.relu(self.dense1(X)))
  17. # demo
  18. # [2, 3, 4] * [4, 4] * [4, 8] = [2, 3, 8]
  19. ffn = PositionWiseFFN(4, 4, 8)
  20. ffn.eval()
  21. print("基于位置的前馈网络:")
  22. print(ffn(torch.ones((2, 3, 4)))[0])
  23. ln = nn.LayerNorm(2) # 层规范化
  24. bn = nn.BatchNorm1d(2) # 批标准化
  25. X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
  26. # 在训练模式下计算 X 的均值和方差
  27. print('layer norm:', ln(X), '\nbatch norm:', bn(X))
  28. class AddNorm(nn.Module):
  29. """残差连接后进行层规范化"""
  30. def __init__(self, normalized_shape, dropout, **kwargs):
  31. super(AddNorm, self).__init__(**kwargs)
  32. self.dropout = nn.Dropout(dropout)
  33. self.ln = nn.LayerNorm(normalized_shape)
  34. def forward(self, X, Y):
  35. return self.ln(self.dropout(Y) + X)
  36. # demo
  37. add_norm = AddNorm([3, 4], 0.5)
  38. add_norm.eval()
  39. print("残差连接后进行层规范化:")
  40. print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)
  41. class EncoderBlock(nn.Module):
  42. """transformer编码器块"""
  43. def __init__(self, key_size, query_size, value_size, num_hiddens,
  44. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  45. dropout, use_bias=False, **kwargs):
  46. super(EncoderBlock, self).__init__(**kwargs)
  47. # 多头注意力
  48. self.attention = d2l.MultiHeadAttention(key_size, query_size, value_size,
  49. num_hiddens, num_heads, dropout, use_bias)
  50. # 加 & 规范化(包含残差连接)
  51. self.addnorm1 = AddNorm(norm_shape, dropout)
  52. # 逐位前馈网络
  53. self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
  54. # 加 & 规范化(包含残差连接)
  55. self.addnorm2 = AddNorm(norm_shape, dropout)
  56. # 多头注意力>>>加&规范化>>>逐位前馈网络>>>加&规范化
  57. def forward(self, X, valid_lens):
  58. Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
  59. return self.addnorm2(Y, self.ffn(Y))
  60. # demo
  61. X = torch.ones((2, 100, 24))
  62. valid_lens = torch.tensor([3, 2])
  63. encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
  64. encoder_blk.eval()
  65. print("transformer编码器块:")
  66. print(encoder_blk(X, valid_lens).shape)
  67. class TransformerEncoder(d2l.Encoder):
  68. """transformer编码器"""
  69. def __init__(self, vocab_size, key_size, query_size, value_size,
  70. num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
  71. num_heads, num_layers, dropout, use_bias=False, **kwargs):
  72. super(TransformerEncoder, self).__init__(**kwargs)
  73. self.num_hiddens = num_hiddens
  74. # 嵌入层
  75. self.embedding = nn.Embedding(vocab_size, num_hiddens)
  76. # 位置编码
  77. self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
  78. # 添加模块(num_layers个transformer编码器块)
  79. self.blks = nn.Sequential()
  80. for i in range(num_layers):
  81. self.blks.add_module("block"+str(i),
  82. EncoderBlock(key_size, query_size, value_size, num_hiddens,
  83. norm_shape, ffn_num_input, ffn_num_hiddens,
  84. num_heads, dropout, use_bias))
  85. def forward(self, X, valid_lens, *args):
  86. # 因为位置编码值在-1和1之间,因此嵌入值乘以嵌入维度的平方根进行缩放,然后再与位置编码相加。
  87. X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
  88. self.attention_weights = [None] * len(self.blks)
  89. for i, blk in enumerate(self.blks):
  90. X = blk(X, valid_lens)
  91. self.attention_weights[i] = blk.attention.attention.attention_weights
  92. return X
  93. # demo
  94. encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
  95. encoder.eval()
  96. print("transformer编码器:")
  97. print(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape)
  98. # =================================================================== #
  99. class DecoderBlock(nn.Module):
  100. """解码器中第i个块"""
  101. def __init__(self, key_size, query_size, value_size, num_hiddens,
  102. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  103. dropout, i, **kwargs):
  104. super(DecoderBlock, self).__init__(**kwargs)
  105. self.i = i
  106. # 多头注意力
  107. self.attention1 = d2l.MultiHeadAttention(
  108. key_size, query_size, value_size, num_hiddens, num_heads, dropout)
  109. # 加 & 规范化
  110. self.addnorm1 = AddNorm(norm_shape, dropout)
  111. # 多头注意力
  112. self.attention2 = d2l.MultiHeadAttention(
  113. key_size, query_size, value_size, num_hiddens, num_heads, dropout)
  114. self.addnorm2 = AddNorm(norm_shape, dropout)
  115. # 加 & 规范化
  116. self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
  117. num_hiddens)
  118. # 加 & 规范化
  119. self.addnorm3 = AddNorm(norm_shape, dropout)
  120. def forward(self, X, state):
  121. enc_outputs, enc_valid_lens = state[0], state[1]
  122. # 训练阶段,输出序列的所有词元都在同一时间处理,
  123. # 因此state[2][self.i]初始化为None。
  124. # 预测阶段,输出序列是通过词元一个接着一个解码的,
  125. # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
  126. if state[2][self.i] is None:
  127. key_values = X
  128. else:
  129. key_values = torch.cat((state[2][self.i], X), axis=1)
  130. state[2][self.i] = key_values
  131. if self.training:
  132. batch_size, num_steps, _ = X.shape
  133. # dec_valid_lens的开头:(batch_size,num_steps),
  134. # 其中每一行是[1,2,...,num_steps]
  135. dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
  136. else:
  137. dec_valid_lens = None
  138. # 多头注意力>>>加&规范化>>>多头注意力>>>加&规范化>>>逐位前馈网络>>>加&规范化
  139. X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
  140. Y = self.addnorm1(X, X2)
  141. # 编码器-解码器注意力。
  142. # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
  143. Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
  144. Z = self.addnorm2(Y, Y2)
  145. return self.addnorm3(Z, self.ffn(Z)), state
  146. # demo
  147. decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
  148. decoder_blk.eval()
  149. X = torch.ones((2, 100, 24))
  150. state = [encoder_blk(X, valid_lens), valid_lens, [None]]
  151. print("解码器中第i个块:")
  152. print(decoder_blk(X, state)[0].shape)
  153. class TransformerDecoder(d2l.AttentionDecoder):
  154. """Transformer解码器"""
  155. def __init__(self, vocab_size, key_size, query_size, value_size,
  156. num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
  157. num_heads, num_layers, dropout, **kwargs):
  158. super(TransformerDecoder, self).__init__(**kwargs)
  159. self.num_hiddens = num_hiddens
  160. self.num_layers = num_layers
  161. # 嵌入层
  162. self.embedding = nn.Embedding(vocab_size, num_hiddens)
  163. # 位置编码
  164. self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
  165. # 添加模块(num_layers个解码器块)
  166. self.blks = nn.Sequential()
  167. for i in range(num_layers):
  168. self.blks.add_module("block"+str(i),
  169. DecoderBlock(key_size, query_size, value_size, num_hiddens,
  170. norm_shape, ffn_num_input, ffn_num_hiddens,
  171. num_heads, dropout, i))
  172. # 全连接层
  173. self.dense = nn.Linear(num_hiddens, vocab_size)
  174. def init_state(self, enc_outputs, enc_valid_lens, *args):
  175. return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
  176. def forward(self, X, state):
  177. # 位置编码+嵌入层
  178. X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
  179. # 定义二维列表,存放 解码器自注意力权重 和 “编码器=解码器”自注意力权重
  180. self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
  181. for i, blk in enumerate(self.blks):
  182. X, state = blk(X, state)
  183. # 解码器自注意力权重
  184. self._attention_weights[0][i] = blk.attention1.attention.attention_weights
  185. # “编码器-解码器”自注意力权重
  186. self._attention_weights[1][i] = blk.attention2.attention.attention_weights
  187. return self.dense(X), state # 最后把参数传入到全连接层之中
  188. def attention_weights(self):
  189. return self._attention_weights
  190. # =================================================================== #
  191. # # 训练
  192. # 隐藏层层数,编码器/解码器块数量,舍弃比率,批量大小,num_steps
  193. num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
  194. # 学习率,迭代次数,设备选取
  195. lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
  196. # 前馈网络输入层,前馈网络隐藏层,注意力头的数量
  197. ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
  198. # 查询,键,值
  199. query_size, key_size, value_size = 32, 32, 32
  200. # AddNorm(规范化的维度)
  201. norm_shape = [32]
  202. # 加载文字数据集(英语翻译成法语)
  203. # train_iter:数据集信息
  204. # src_vocab:源(英语)词汇表
  205. # tgt_vocab:法语词汇表
  206. train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
  207. encoder = TransformerEncoder(
  208. len(src_vocab), key_size, query_size, value_size, num_hiddens,
  209. norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
  210. num_layers, dropout)
  211. # print("编码器返回的结果(Transformer编码器的结构):")
  212. # print(encoder)
  213. decoder = TransformerDecoder(
  214. len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
  215. ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
  216. # print("解码器返回的结果(Transformer解码器的结构):")
  217. # print(decoder)
  218. # Transformer“编码器-解码器”网络模型
  219. net = d2l.EncoderDecoder(encoder, decoder)
  220. print("Transformer“编码器-解码器”网络模型结构:")
  221. print(net)
  222. # 训练一个序列到序列的模型
  223. # 网络模型,数据集信息,学习率,迭代次数,法语词汇表(标志输出),设备选取
  224. d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
  225. # =================================================================== #
  226. # # 测试
  227. engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
  228. fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
  229. # 多维列表
  230. dec_attention_weight_seq = []
  231. for eng, fra in zip(engs, fras):
  232. # 网络模型、单个英语词汇、英语词汇集、法语词汇集、num_steps,设备选取
  233. translation, dec_attention_weight_seq = d2l.predict_seq2seq(
  234. net, eng, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=True)
  235. print(f'{eng} => {translation}, ',
  236. f'bleu {d2l.bleu(translation, fra, k=2):.3f}') # 评分函数(翻译结果同数据集相比较)
  237. enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape((num_layers, num_heads, -1, num_steps))
  238. print(enc_attention_weights.shape)
  239. d2l.show_heatmaps(
  240. enc_attention_weights.cpu(), xlabel='Key positions',
  241. ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
  242. figsize=(7, 3.5))
  243. plt.show()
  244. # =================================================================== #
  245. # # 可视化权重
  246. print(len(dec_attention_weight_seq))
  247. dec_attention_weights_2d = [head[0].tolist()
  248. for step in dec_attention_weight_seq
  249. for attn in step for blk in attn for head in blk]
  250. dec_attention_weights_filled = torch.tensor(
  251. pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
  252. dec_attention_weights = dec_attention_weights_filled.reshape((-1, 2, num_layers, num_heads, num_steps))
  253. dec_self_attention_weights, dec_inter_attention_weights = \
  254. dec_attention_weights.permute(1, 2, 3, 0, 4)
  255. print(dec_self_attention_weights.shape, dec_inter_attention_weights.shape)
  256. # 与编码器的自注意力的情况类似,通过指定输入序列的有效长度,
  257. # [输出序列的查询不会与输入序列中填充位置的词元进行注意力计算]。
  258. # Plusonetoincludethebeginning-of-sequencetoken
  259. d2l.show_heatmaps(
  260. dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
  261. xlabel='Key positions', ylabel='Query positions',
  262. titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
  263. d2l.show_heatmaps(
  264. dec_inter_attention_weights, xlabel='Key positions',
  265. ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
  266. figsize=(7, 3.5))

>>>如有疑问,欢迎评论区一起探讨

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/374708
推荐阅读
相关标签
  

闽ICP备14008679号