当前位置:   article > 正文

自然语言处理实验——基于Transformer实现机器翻译(日译中)_基于transformer的机器翻译

基于transformer的机器翻译

一、机器翻译

机器翻译是指将一段文本从一种语言自动翻译到另一种语言。因为一段文本序列在不同语言中的长度不一定相同,所以我们使用机器翻译为例来介绍编码器—解码器和注意力机制的应用。

二、读取和预处理数据

我们先定义一些特殊符号。其中“<pad>”(padding)符号用来添加在较短序列后,直到每个序列等长,而“<bos>”和“<eos>”符号分别表示序列的开始和结束。

  1. import collections
  2. import os
  3. import io
  4. import math
  5. import torch
  6. from torch import nn
  7. import torch.nn.functional as F
  8. import torchtext.vocab as Vocab
  9. import torch.utils.data as Data
  10. import sys
  11. # sys.path.append("..")
  12. import d2lzh_pytorch as d2l
  13. PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'
  14. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  15. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. print(torch.__version__, device)

接着定义两个辅助函数对后面读取的数据进行预处理。

  1. # 将一个序列中所有的词记录在all_tokens中以便之后构造词典,然后在该序列后面添加PAD直到序列
  2. # 长度变为max_seq_len,然后将序列保存在all_seqs中
  3. def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):
  4. all_tokens.extend(seq_tokens)
  5. seq_tokens += [EOS] + [PAD] * (max_seq_len - len(seq_tokens) - 1)
  6. all_seqs.append(seq_tokens)
  7. # 使用所有的词来构造词典。并将所有序列中的词变换为词索引后构造Tensor
  8. def build_data(all_tokens, all_seqs):
  9. vocab = Vocab.Vocab(collections.Counter(all_tokens),
  10. specials=[PAD, BOS, EOS])
  11. indices = [[vocab.stoi[w] for w in seq] for seq in all_seqs]
  12. return vocab, torch.tensor(indices)

为了演示方便,我们在这里使用一个很小的法语—英语数据集。在这个数据集里,每一行是一对法语句子和它对应的英语句子,中间使用'\t'隔开。在读取数据时,我们在句末附上“<eos>”符号,并可能通过添加“<pad>”符号使每个序列的长度均为max_seq_len。我们为法语词和英语词分别创建词典。法语词的索引和英语词的索引相互独立。

  1. def read_data(max_seq_len):
  2. # in和out分别是input和output的缩写
  3. in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []
  4. with io.open('fr-en-small.txt') as f:
  5. lines = f.readlines()
  6. for line in lines:
  7. in_seq, out_seq = line.rstrip().split('\t')
  8. in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')
  9. if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:
  10. continue # 如果加上EOS后长于max_seq_len,则忽略掉此样本
  11. process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)
  12. process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)
  13. in_vocab, in_data = build_data(in_tokens, in_seqs)
  14. out_vocab, out_data = build_data(out_tokens, out_seqs)
  15. return in_vocab, out_vocab, Data.TensorDataset(in_data, out_data)

将序列的最大长度设成7,然后查看读取到的第一个样本。该样本分别包含法语词索引序列和英语词索引序列。

  1. max_seq_len = 7
  2. in_vocab, out_vocab, dataset = read_data(max_seq_len)
  3. dataset[0]

运行结果:

(tensor([ 5,  4, 45,  3,  2,  0,  0]), tensor([ 8,  4, 27,  3,  2,  0,  0]))

三、 含注意力机制的编码器—解码器

我们将使用含注意力机制的编码器—解码器来将一段简短的法语翻译成英语。下面我们来介绍模型的实现。

3.1 编码器

在编码器中,我们将输入语言的词索引通过词嵌入层得到词的表征,然后输入到一个多层门控循环单元中。正如我们在6.5节(循环神经网络的简洁实现)中提到的,PyTorch的nn.GRU实例在前向计算后也会分别返回输出和最终时间步的多层隐藏状态。其中的输出指的是最后一层的隐藏层在各个时间步的隐藏状态,并不涉及输出层计算。注意力机制将这些输出作为键项和值项。

  1. class Encoder(nn.Module):
  2. def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
  3. drop_prob=0, **kwargs):
  4. super(Encoder, self).__init__(**kwargs)
  5. self.embedding = nn.Embedding(vocab_size, embed_size)
  6. self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob)
  7. def forward(self, inputs, state):
  8. # 输入形状是(批量大小, 时间步数)。将输出互换样本维和时间步维
  9. embedding = self.embedding(inputs.long()).permute(1, 0, 2) # (seq_len, batch, input_size)
  10. return self.rnn(embedding, state)
  11. def begin_state(self):
  12. return None

下面我们来创建一个批量大小为4、时间步数为7的小批量序列输入。设门控循环单元的隐藏层个数为2,隐藏单元个数为16。编码器对该输入执行前向计算后返回的输出形状为(时间步数, 批量大小, 隐藏单元个数)。门控循环单元在最终时间步的多层隐藏状态的形状为(隐藏层个数, 批量大小, 隐藏单元个数)。对于门控循环单元来说,state就是一个元素,即隐藏状态;如果使用长短期记忆,state是一个元组,包含两个元素即隐藏状态和记忆细胞。

  1. encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
  2. output, state = encoder(torch.zeros((4, 7)), encoder.begin_state())
  3. output.shape, state.shape # GRU的state是h, 而LSTM的是一个元组(h, c)

运行结果:

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

3.2 注意力机制

我们将实现注意力机制中定义的函数声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/795637

推荐阅读
相关标签