当前位置:   article > 正文

基于注意力机制的seq2seq模型_基于点积的注意力机制 seq2seq

基于点积的注意力机制 seq2seq

一、前言

在此之前,我们实现了最普通的seq2seq模型,该模型的编码器和解码器均采用的是两层单向的GRU。本篇文章将基于注意力机制改进之前的seq2seq模型,其中编码器采用两层双向的LSTM,解码器采用含有注意力机制的两层单向LSTM。由于数据预处理部分相同,因此本文不再赘述,详情可参考之前的文章。

二、模型搭建

本文接下来的叙述将沿用这篇文章中的符号。

2.1 编码器

编码器我们采用两层双向LSTM。编码器的输入形状为 ( N , L ) (N,L) (N,L),输出 output 的形状为 ( L , N , 2 h ) (L,N,2h) (L,N,2h),它是正向LSTM和反向LSTM输出进行了concat后的结果,包含了正反向的信息。编码器输出的 h_nc_n 的形状均为 ( 2 n , N , h ) (2n,N,h) (2n,N,h),需要将其形状改变为 ( n , N , 2 h ) (n,N,2h) (n,N,2h) 后才可作为解码器的初始隐状态。

至于为什么要改变 h_nc_n 的形状以及为什么不能直接用 reshape 去改变会在后面提到。

编码器的实现如下:

class Seq2SeqEncoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)
        self.rnn = nn.LSTM(emb_size, hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=True)

    def forward(self, encoder_inputs):
        encoder_inputs = self.embedding(encoder_inputs).permute(1, 0, 2)
        output, (h_n, c_n) = self.rnn(encoder_inputs)  # output shape: (seq_len, batch_size, 2 * hidden_size)
        h_n = torch.cat((h_n[::2], h_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)
        c_n = torch.cat((c_n[::2], c_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)
        return output, h_n, c_n
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

2.2 注意力机制

在原先的seq2seq模型中,解码器在每一个时间步所使用的上下文向量均相同。现在我们希望解码器在不同的时间步上能够注意到源序列中不同的信息,因此考虑采用注意力机制。

解码器的核心架构为两层单向的LSTM(只能是单向),在 t t t 时刻,我们采用解码器在 t − 1 t-1 t1 时刻最后一个隐层的输出作为查询,每个 output[t] 既作为键也作为值,相应的计算上下文向量的公式如下:

context [ t ] = ∑ i = 1 L α ( decoder_state [ t − 1 ] , output [ i ] ) ⋅ output [ i ] \text{context}[t]=\sum_{i=1}^L \alpha(\text{decoder\_state}[t-1], \text{output}[i])\cdot \text{output}[i] context[t]=i=1Lα(decoder_state[t1],output[i])output[i]

其中 α ( q , k ) \alpha(q,k) α(q,k) 是注意力权重。

假设编码器所采用的LSTM的隐层大小为 h h h,解码器所采用的LSTM的隐层大小为 h ′ h' h。因 output[t] 的形状为 ( N , 2 h ) (N,2h) (N,2h)decoder_state[t - 1] 的形状为 ( N , h ′ ) (N,h') (N,h),要使用缩放点积注意力,则必须有 h ′ = 2 h h'=2h h=2h,否则无法进行内积操作,所以可以得出:解码器隐层大小是编码器的两倍

注意力机制实现如下:

class AttentionMechanism(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, decoder_state, encoder_output):
        # 解码器的隐藏层大小必须是编码器的两倍,否则无法进行接下来的内积操作
        # decoder_state shape: (batch_size, 2 * hidden_size)
        # encoder_output shape: (seq_len, batch_size, 2 * hidden_size)
        decoder_state = decoder_state.unsqueeze(1)  # (batch_size, 1, 2 * hidden_size)
        encoder_output = encoder_output.transpose(0, 1)  # (batch_size, seq_len, 2 * hidden_size)
        # scores shape: (batch_size, seq_len)
        scores = torch.sum(decoder_state * encoder_output, dim=-1) / math.sqrt(decoder_state.shape[2])  # 广播机制
        attn_weights = F.softmax(scores, dim=-1)
        # context shape: (batch_size, 2 * hidden_size)
        context = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)  # 广播机制
        return context
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

2.3 解码器

解码器的原始输入的形状为 ( N , L ) (N,L) (N,L),通过嵌入以及 permute 操作后其形状变为 ( L , N , d ) (L,N,d) (L,N,d),而上下文的形状为 ( N , 2 h ) (N,2h) (N,2h) ( L , N , d ) (L,N,d) (L,N,d) ( N , 2 h ) (N,2h) (N,2h) 进行concat后才会作为其内部LSTM的输入。因此解码器采用的LSTM的 input_size d + 2 h d+2h d+2h。为了保证注意力机制正常运作,其隐藏层大小也应为编码器的两倍,即 2 h 2h 2h

从编码器我们得到了形状为 ( 2 n , N , h ) (2n,N,h) (2n,N,h)h_n,而解码器采用的LSTM是单向的,从而其接受的 h_0 的形状应为 ( n , N , 2 h ) (n,N,2h) (n,N,2h)。一个很自然的想法是直接使用 reshape 完成形状的转化,但这样做会带来一个问题,即无法保证 h_0[-1] 对应的是正反向编码器在最后一个时间步最后一个隐层的输出的拼接,为此可考虑采用如下方式解决:

h 0 = Concat ( ( h n [ :   : 2 ] , h n [ 1 :   : 2 ] ) ,    dim = 2 ) h_0 =\text{Concat}((h_n [ : \, : 2],h_n [ 1: \, : 2]),\;\text{dim} = 2) h0=Concat((hn[::2],hn[1::2]),dim=2)

至于为什么这样做,可以参考这篇文章

在评估阶段中,我们往往需要利用模型的解码器一步一步地输出,每一时刻都会利用上一时刻解码器输出的隐状态,类似于下面的伪代码:

decoder_output, hidden_state = decoder(decoder_input, hidden_state)
  • 1

这要求输入到解码器中的隐状态和解码器输出的隐状态的形状必须相同因此 h_nc_n 的形状转化必须在编码器中完成

解码器的实现:

class Seq2SeqDecoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)
        self.attn = AttentionMechanism()
        self.rnn = nn.LSTM(emb_size + 2 * hidden_size, 2 * hidden_size, num_layers=num_layers, dropout=dropout)
        self.fc = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, decoder_inputs, encoder_output, h_n, c_n):
        decoder_inputs = self.embedding(decoder_inputs).permute(1, 0, 2)  # (seq_len, batch_size, emb_size)
        # 注意将其移动到GPU上
        decoder_output = torch.zeros(decoder_inputs.shape[0], *h_n.shape[1:]).to(device)  # (seq_len, batch_size, 2 * hidden_size)
        for i in range(len(decoder_inputs)):
            context = self.attn(h_n[-1], encoder_output)  # (batch_size, 2 * hidden_size)
            # single_step_output shape: (1, batch_size, 2 * hidden_size)
            single_step_output, (h_n, c_n) = self.rnn(torch.cat((decoder_inputs[i], context), -1).unsqueeze(0), (h_n, c_n))
            decoder_output[i] = single_step_output.squeeze()
        logits = self.fc(decoder_output)  # (seq_len, batch_size, vocab_size)
        return logits, h_n, c_n
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

2.4 Seq2Seq模型

整体架构如下:

只需要将编码器和解码器封装在一起即可:

class Seq2SeqModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_inputs, decoder_inputs):
        return self.decoder(decoder_inputs, *self.encoder(encoder_inputs))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

三、模型的训练与评估

因为输入输出发生了一些变化,我们只需要对原先的 train 函数和 evaluate 函数稍作修改

def train(train_loader, model, criterion, optimizer, num_epochs):
    train_loss = []
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (encoder_inputs, decoder_targets) in enumerate(train_loader):
            encoder_inputs, decoder_targets = encoder_inputs.to(device), decoder_targets.to(device)
            bos_column = torch.tensor([tgt_vocab['<bos>']] * decoder_targets.shape[0]).reshape(-1, 1).to(device)
            decoder_inputs = torch.cat((bos_column, decoder_targets[:, :-1]), dim=1)
            pred, _, _ = model(encoder_inputs, decoder_inputs)
            loss = criterion(pred.permute(1, 2, 0), decoder_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            if (batch_idx + 1) % 50 == 0:
                print(
                    f'[Epoch {epoch + 1}] [{(batch_idx + 1) * len(encoder_inputs)}/{len(train_loader.dataset)}] loss: {loss:.4f}')
        print()
    return train_loss


def evaluate(test_loader, model, bleu_k):
    bleu_scores = []
    translation_results = []
    model.eval()
    for src_seq, tgt_seq in test_loader:
        encoder_inputs = src_seq.to(device)
        encoder_output, h_n, c_n = model.encoder(encoder_inputs)
        pred_seq = [tgt_vocab['<bos>']]
        for _ in range(SEQ_LEN):
            decoder_inputs = torch.tensor(pred_seq[-1]).reshape(1, 1).to(device)
            pred, h_n, c_n = model.decoder(decoder_inputs, encoder_output, h_n, c_n)
            next_token_idx = pred.squeeze().argmax().item()
            if next_token_idx == tgt_vocab['<eos>']:
                break
            pred_seq.append(next_token_idx)
        pred_seq = tgt_vocab[pred_seq[1:]]
        tgt_seq = tgt_seq.squeeze().tolist()
        tgt_seq = tgt_vocab[tgt_seq[:tgt_seq.index(tgt_vocab['<eos>'])]] if tgt_vocab['<eos>'] in tgt_seq else tgt_vocab[tgt_seq]
        translation_results.append((' '.join(tgt_seq), ' '.join(pred_seq)))
        if len(pred_seq) >= bleu_k:
            bleu_scores.append(bleu(tgt_seq, pred_seq, k=bleu_k))

    return bleu_scores, translation_results
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

保持其他超参数不变,使用 NVIDIA A40 进行训练(亲测 RTX 3090 会爆掉显存),大概需要6个小时,损失函数曲线如下:

在这里插入图片描述

与之前不同的是,在评估阶段,我们会分别计算平均BLEU-{2,3,4}分数并与原先的模型进行比较

bleu_2_scores, _ = evaluate(test_loader, net, bleu_k=2)
bleu_3_scores, _ = evaluate(test_loader, net, bleu_k=3)
bleu_4_scores, _ = evaluate(test_loader, net, bleu_k=4)
print(f"BLEU-2: {np.mean(bleu_2_scores)} | BLEU-3: {np.mean(bleu_3_scores)} | BLEU-4: {np.mean(bleu_4_scores)}")
  • 1
  • 2
  • 3
  • 4

比较结果列在下表中

模型平均BLEU-2平均BLEU-3平均BLEU-4
Vanilla Seq2Seq(链接0.47990.32290.2144
Attention-based Seq2Seq(本文)0.57110.41950.3036

可以看出加入了注意力机制后,BLEU得分提升了约十个百分点

一些可以改进的地方:

  • 完全可以先将 translation_results 计算出来再计算每种BLEU得分,这样做可以大大节省时间;
  • 训练过程中Teacher Forcing的比率为100%,可以尝试降低此比率以达到更好的效果;
  • BLEU无法理解同义词,导致一些合理的翻译会被否定,可以尝试换用其他的度量来更准确地评估模型。

附录一、翻译效果比较

translation_results 中随机抽取十个。

target:     je suis plutôt occupée .
vanilla:    je suis plutôt occupé .
attn-based: je suis plutôt occupé .

target:     ça t'arrive de dormir ?
vanilla:    t'arrive-t-il de dormir ?
attn-based: t'arrive-t-il de dormir ?

target:     je ne partirai probablement pas demain .
vanilla:    je ne vais probablement pas vouloir demain .
attn-based: je ne serai probablement pas demain .

target:     je suis prudent .
vanilla:    je suis prudente .
attn-based: je suis prudente .

target:     je suis sure que c'était juste un malentendu .
vanilla:    je suis sûr que c'était un malentendu .
attn-based: je suis sûr que ce fut un malentendu .

target:     je me demandais ce qui t'avait fait changer d'avis .
vanilla:    je me demandais ce que tu ressens .
attn-based: je me demandais ce qui aurait réussi à ce sujet .

target:     il me jeta un regard sévère .
vanilla:    il me fit une robe bleue .
attn-based: il m'a donné un grand regard .

target:     te fies-tu à qui que ce soit ?
vanilla:    vous fiez-vous à quiconque ?
attn-based: te fies-tu à quiconque ?

target:     es-tu sûre d'avoir assez chaud ?
vanilla:    es-tu sûr que tu es allé ?
attn-based: êtes-vous sûr d'avoir assez chaud ?

target:     je commençais à me faire du souci à ton sujet .
vanilla:    je commençais à m'inquiéter pour toi .
attn-based: je commençais à m'inquiéter à votre sujet .
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

附录二、完整代码

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/420594
推荐阅读
相关标签