Implementing the Transformer Decoder in MindSpore
The decoder converts the context sequence produced by the encoder into predictions for the target sequence (ŷ). During training this output is compared against the ground-truth target sequence Y to compute the loss.
Unlike the encoder, each decoder layer contains two multi-head attention blocks, and the decoder ends with an extra linear layer that outputs the predictions for the target sequence.
When processing the target-sequence input, the model at time step t may only "observe" the tokens up to step t-1; later tokens must not be fed into the decoder all at once.
To guarantee that at step t only the first t-1 tokens take part in computing the multi-head attention scores, we add an extra temporal mask to the first multi-head attention block, so that the tokens of the target sequence are revealed one at a time as decoding advances.
This attention mask can be implemented with a triangular matrix: positions above the diagonal correspond to tokens excluded from the attention computation and are marked as 1.
```python
import mindspore.numpy as mnp
from mindspore import nn, ops, Tensor
from mindspore.common import dtype as mstype

def get_attn_subsequent_mask(seq_q, seq_k):
    """Build the temporal mask so that at step t the decoder can only
    see the first t-1 elements of the sequence.

    Args:
        seq_q (Tensor): query sequence, shape = [batch_size, len_q]
        seq_k (Tensor): key sequence, shape = [batch_size, len_k]
    """
    batch_size, len_q = seq_q.shape
    batch_size, len_k = seq_k.shape
    # mark every position above the main diagonal with 1 (excluded from attention)
    ones = ops.ones((batch_size, len_q, len_k), mstype.float32)
    subsequent_mask = mnp.triu(ones, k=1)
    return subsequent_mask

q = k = ops.ones((1, 4), mstype.float32)
mask = get_attn_subsequent_mask(q, k)
print(mask)
```
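For a length-4 sequence this should print a 4×4 upper-triangular mask, where 1 marks the future positions hidden from attention:

```text
[[[0. 1. 1. 1.]
  [0. 0. 1. 1.]
  [0. 0. 0. 1.]
  [0. 0. 0. 0.]]]
```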
First, implement a single decoder layer. `MultiHeadAttention`, `PoswiseFeedForward` and `AddNorm` below are the building blocks defined in the encoder section.

```python
class DecoderLayer(nn.Cell):
    def __init__(self, d_model, n_heads, d_ff, dropout_p=0.):
        super().__init__()
        d_k = d_model // n_heads
        if d_k * n_heads != d_model:
            raise ValueError(f"The `d_model` {d_model} is not divisible by `num_heads` {n_heads}.")
        self.dec_self_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)  # masked self-attention
        self.dec_enc_attn = MultiHeadAttention(d_model, d_k, n_heads, dropout_p)   # encoder-decoder cross-attention
        self.pos_ffn = PoswiseFeedForward(d_ff, d_model, dropout_p)
        self.add_norm1 = AddNorm(d_model, dropout_p)
        self.add_norm2 = AddNorm(d_model, dropout_p)
        self.add_norm3 = AddNorm(d_model, dropout_p)

    def construct(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        """
        dec_inputs: [batch_size, trg_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, trg_len, trg_len]
        dec_enc_attn_mask: [batch_size, trg_len, src_len]
        """
        residual = dec_inputs

        # 1. masked multi-head self-attention over the target sequence
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs = self.add_norm1(dec_outputs, residual)
        residual = dec_outputs

        # 2. cross-attention: queries from the decoder, keys/values from the encoder output
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.add_norm2(dec_outputs, residual)
        residual = dec_outputs

        # 3. position-wise feed-forward network
        dec_outputs = self.pos_ffn(dec_outputs)
        dec_outputs = self.add_norm3(dec_outputs, residual)

        return dec_outputs, dec_self_attn, dec_enc_attn

x = y = ops.ones((1, 2, 4), mstype.float32)
mask1 = mask2 = Tensor([False]).broadcast_to((1, 2, 2))
decoder_layer = DecoderLayer(4, 1, 16)
output, attn1, attn2 = decoder_layer(x, y, mask1, mask2)
print(output.shape, attn1.shape, attn2.shape)
```
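`output` keeps the input shape `(1, 2, 4)`; the shapes of the two returned attention maps depend on how `MultiHeadAttention` from the encoder section lays out its heads (with `n_heads=1` here, most likely `(1, 2, 2)` each).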
-
Now stack the `DecoderLayer` implemented above `n_layers` times, and add the word embedding, positional encoding, and a final linear layer. The returned `dec_outputs` are the predictions for the target sequence. `get_attn_pad_mask` and `PositionalEncoding` are the helpers defined in the encoder section; note that `dropout_p` is forwarded to every decoder layer so that dropout is applied consistently.

```python
class Decoder(nn.Cell):
    def __init__(self, trg_vocab_size, d_model, n_heads, d_ff, n_layers, dropout_p=0.):
        super().__init__()
        self.trg_emb = nn.Embedding(trg_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, dropout_p)
        self.layers = nn.CellList([DecoderLayer(d_model, n_heads, d_ff, dropout_p) for _ in range(n_layers)])
        self.projection = nn.Dense(d_model, trg_vocab_size)
        self.scaling_factor = ops.Sqrt()(Tensor(d_model, mstype.float32))

    def construct(self, dec_inputs, enc_inputs, enc_outputs, src_pad_idx, trg_pad_idx):
        """
        dec_inputs: [batch_size, trg_len]
        enc_inputs: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        """
        # scaled word embedding plus positional encoding
        dec_outputs = self.trg_emb(dec_inputs.astype(mstype.int32))
        dec_outputs = self.pos_emb(dec_outputs * self.scaling_factor)

        # self-attention mask: a position is masked if it is padding
        # or lies in the future (padding mask OR subsequent mask)
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, trg_pad_idx)
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs, dec_inputs)
        dec_self_attn_mask = ops.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)

        # cross-attention mask: hide the padding of the source sequence
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, src_pad_idx)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)

        # project to the target vocabulary: one logit per token
        dec_outputs = self.projection(dec_outputs)
        return dec_outputs, dec_self_attns, dec_enc_attns
```
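A minimal smoke test, assuming `get_attn_pad_mask` and `PositionalEncoding` from the encoder section are in scope (all parameter values below are arbitrary toy choices):

```python
decoder = Decoder(trg_vocab_size=10, d_model=4, n_heads=1, d_ff=16, n_layers=2)
dec_inputs = Tensor([[1, 2, 3, 0]], mstype.int32)  # 0 serves as the padding index
enc_inputs = Tensor([[4, 5, 0, 0]], mstype.int32)
enc_outputs = ops.ones((1, 4, 4), mstype.float32)  # stand-in for a real encoder output
logits, self_attns, cross_attns = decoder(dec_inputs, enc_inputs, enc_outputs, 0, 0)
print(logits.shape)  # expected: (1, 4, 10), one score per target-vocabulary token
```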
Takeaways: I learned how to implement the decoder part of the Transformer with MindSpore, and building it by hand deepened my understanding of the decoder. The next step is to combine the decoder with the encoder.