当前位置:   article > 正文

昇思MindSpore技术公开课-transformer(4)_transformer-mindspore

transformer-mindspore

学习总结

transformer解码器的mindspore实现

学习心得

解码器 (Decoder)

解码器将编码器输出的上下文序列转换为目标序列的预测结果\gamma(y帽),该输出将在模型训练中与真实目标输出Y进行比较,计算损失。

不同于编码器,每个Decoder层中包含两层多头注意力机制,并在最后多出一个线性层,输出对目标序列的预测结果。

  • 第一层:计算目标序列的注意力分数的掩码多头自注意力
  • 第二层:用于计算上下文序列与目标序列对应关系,其中Decoder掩码多头注意力的输出作为query,Encoder的输出(上下文序列)作为key和value;

带掩码的多头注意力

在处理目标序列的输入时,t时刻的模型只能“观察”直到t-1时刻的所有词元,后续的词语不应该一并输入Decoder中。

为了保证在t时刻,只有t-1个词元作为输入参与多头注意力分数的计算,我们需要在第一个多头注意力中额外增加一个时间掩码,使目标序列中的词随时间发展逐个被暴露出来。

该注意力掩码可通过三角矩阵实现,对角线以上的词元表示为不参与注意力计算的词元,标记为1。

经验分享

  1. def get_attn_subsequent_mask(seq_q, seq_k):
  2. """生成时间掩码,使decoder在第t时刻只能看到序列的前t-1个元素
  3. Args:
  4. seq_q (Tensor): query序列,shape = [batch size, len_q]
  5. seq_k (Tensor): key序列,shape = [batch size, len_k]
  6. """
  7. batch_size, len_q = seq_q.shape
  8. batch_size, len_k = seq_k.shape
  9. ones = ops.ones((batch_size, len_q, len_k), mindspore.float32)
  10. subsequent_mask = mnp.triu(ones, k=1)
  11. return subsequent_mask
  12. q = k = ops.ones((1, 4), mstype.float32)
  13. mask = get_attn_subsequent_mask(q, k)
  14. print(mask)

Decoder Layer

  1. # 首先实现Decoder中的一个层
  2. def __init__(self, d_model, n_heads, d_ff, dropout_p=0.):
  3. super().__init__()
  4. d_k = d_model // n_heads
  5. if d_k * n_heads != d_model:
  6. raise ValueError(f"The `d_model` {d_model} can not be divisible by `num_heads` {n_heads}.")
  7. self.dec_self_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)
  8. self.dec_enc_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)
  9. self.pos_ffn = PoswiseFeedForward(d_ff, d_model, dropout_p)
  10. self.add_norm1 = AddNorm(d_model, dropout_p)
  11. self.add_norm2 = AddNorm(d_model, dropout_p)
  12. self.add_norm3 = AddNorm(d_model, dropout_p)
  13. def construct(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
  14. """
  15. dec_inputs: [batch_size, trg_len, d_model]
  16. enc_outputs: [batch_size, src_len, d_model]
  17. dec_self_attn_mask: [batch_size, trg_len, trg_len]
  18. dec_enc_attn_mask: [batch_size, trg_len, src_len]
  19. """
  20. residual = dec_inputs
  21. dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
  22. dec_outputs = self.add_norm1(dec_outputs, residual)
  23. residual = dec_outputs
  24. dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
  25. dec_outputs = self.add_norm2(dec_outputs, residual)
  26. residual = dec_outputs
  27. dec_outputs = self.pos_ffn(dec_outputs)
  28. dec_outputs = self.add_norm3(dec_outputs, residual)
  29. return dec_outputs, dec_self_attn, dec_enc_attn
  30. x = y = ops.ones((1, 2, 4), mstype.float32)
  31. mask1 = mask2 = Tensor([False]).broadcast_to((1, 2, 2))
  32. decoder_layer = DecoderLayer(4, 1, 16)
  33. output, attn1, attn2 = decoder_layer(x, y, mask1, mask2)
  34. print(output.shape, attn1.shape, attn2.shape)
  35. # 将上面实现的DecoderLayer堆叠n_layer次,添加word embedding与positional encoding,以及最后的线性层。
  36. # 输出的dec_outputs为对目标序列的预测。
  37. class Decoder(nn.Cell):
  38. def __init__(self, trg_vocab_size, d_model, n_heads, d_ff, n_layers, dropout_p=0.):
  39. super().__init__()
  40. self.trg_emb = nn.Embedding(trg_vocab_size, d_model)
  41. self.pos_emb = PositionalEncoding(d_model, dropout_p)
  42. self.layers = nn.CellList([DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)])
  43. self.projection = nn.Dense(d_model, trg_vocab_size)
  44. self.scaling_factor = ops.Sqrt()(Tensor(d_model, mstype.float32))
  45. def construct(self, dec_inputs, enc_inputs, enc_outputs, src_pad_idx, trg_pad_idx):
  46. """
  47. dec_inputs: [batch_size, trg_len]
  48. enc_inputs: [batch_size, src_len]
  49. enc_outputs: [batch_size, src_len, d_model]
  50. """
  51. dec_outputs = self.trg_emb(dec_inputs.astype(mstype.int32))
  52. dec_outputs = self.pos_emb(dec_outputs * self.scaling_factor)
  53. dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, trg_pad_idx)
  54. dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs, dec_inputs)
  55. dec_self_attn_mask = ops.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
  56. dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, src_pad_idx)
  57. dec_self_attns, dec_enc_attns = [], []
  58. for layer in self.layers:
  59. dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
  60. dec_self_attns.append(dec_self_attn)
  61. dec_enc_attns.append(dec_enc_attn)
  62. dec_outputs = self.projection(dec_outputs)
  63. return dec_outputs, dec_self_attns, dec_enc_attns

课程反馈

学会使用mindspore实现transformer的decoder部分

使用MindSpore昇思的体验和反馈

使用mindspore实现decoder加深了对decoder的理解

未来展望

继续学习将decoder和encoder结合起来

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/883537
推荐阅读
相关标签
  

闽ICP备14008679号