当前位置:   article > 正文

【PyTorch】11 聊天机器人实战——Cornell Movie-Dialogs Corpus电影剧本数据集处理、利用Global attention实现Seq2Seq模型

cornell movie-dialogs corpus

此为官方PyTorch之文本篇的最后一个教程

在本教程中,我们探索一个好玩有趣的循环的序列到序列(sequence-to-sequence)的模型用例。我们将用Cornell Movie-Dialogs Corpus 处的电影剧本来训练一个简单的聊天机器人

在人工智能研究领域中,对话模型是一个非常热门的话题。聊天机器人可以在各种设置中找到,包括客户服务应用和在线帮助。这些机器人通常由基于检索的模型提供支持,这些模型的输出是某些形式问题预先定义的响应。在像公司IT服务台这样高度受限制的领域中,这些模型可能足够了,但是,对于更一般的用例它们还不够健壮。让一台机器与多领域的人进行有意义的对话是一个远未解决的研究问题。最近,深度学习热潮已经允许强大的生成模型,如谷歌的神经对话模型Neural Conversational Model,这标志着向多领域生成对话模型迈出了一大步。 在本教程中,我们将在PyTorch中实现这种模型

要点:

1. 下载数据文件

下载数据文件点击这里并将其放入到目标目录下
在这里插入图片描述

2. 加载和预处理数据

下一步就是格式化处理我们的数据文件并将数据加载到我们可以使用的结构中。 Cornell Movie-Dialogs Corpus是一个丰富的电影角色对话数据集: * 10,292 对电影角色之间的220,579次对话 * 617部电影中的9,035个电影角色 * 总共304,713发言量

这个数据集庞大而多样,在语言形式、时间段、情感上等都有很大的变化。我们希望这种多样性使我们的模型能够适应多种形式的输入和查询

首先,我们通过数据文件的某些行来查看原始数据的格式

with open(file, 'rb') as datafile:
    lines = datafile.readlines()
    for line in lines[:10]:
        print(line)
  • 1
  • 2
  • 3
  • 4

结果:

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.1 创建格式化数据文件

为了方便起见,我们将创建一个格式良好的数据文件,其中每一行包含一个由tab制表符分隔的查询语句和响应语句对

以下函数便于解析原始 movie_lines.txt 数据文件:

  • loadLines:将文件的每一行拆分为字段(lineID, characterID, movieID, character, text)组合的字典
  • loadConversations :根据movie_conversations.txt将loadLines中的每一行数据进行归类
  • extractSentencePairs: 从对话中提取句子对

"movie_lines.txt"文件第一行:

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
  • 1

loadLines处理完结果:

print(lines['L1044'])
  • 1
{'lineID': 'L1044', 'characterID': 'u2', 'movieID': 'm0', 'character': 'CAMERON', 'text': 'They do to!\n'}
  • 1

"movie_conversations.txt"文件第一行:

u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']
  • 1

loadConversations处理完结果:

print(conversations[0])
  • 1
{'character1ID': 'u0', 'character2ID': 'u2', 'movieID': 'm0', 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n", 'lines': [{'lineID': 'L194', 'characterID': 'u0', 'movieID': 'm0', 'character': 'BIANCA', 'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'}, {'lineID': 'L195', 'characterID': 'u2', 'movieID': 'm0', 'character': 'CAMERON', 'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"}, {'lineID': 'L196', 'characterID': 'u0', 'movieID': 'm0', 'character': 'BIANCA', 'text': 'Not the hacking and gagging and spitting part.  Please.\n'}, {'lineID': 'L197', 'characterID': 'u2', 'movieID': 'm0', 'character': 'CAMERON', 'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}
  • 1

extractSentencePairs处理完结果:

print(extractSentencePairs(conversations)[0])
  • 1
['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.', "Well, I thought we'd start with pronunciation, if that's okay with you."]
  • 1

现在我们将调用这些函数来创建文件,我们命名为formatted_movie_lines.txt

print("Sample lines from file:")	
with open(datafile, 'rb') as datafile:		# 打印样本的十行
    lines = datafile.readlines()
    for line in lines[:10]:
        print(line)
  • 1
  • 2
  • 3
  • 4
  • 5
Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"
b'Why?\tUnsolved mystery.  She used to be really popular when she started high school, then it was just like she got sick of it or something.\r\n'
b"Unsolved mystery.  She used to be really popular when she started high school, then it was just like she got sick of it or something.\tThat's a shame.\r\n"
b'Gosh, if only we could find Kat a boyfriend...\tLet me see what I can do.\r\n'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

可以发现比教程多了一个’\r’,意思是将当前位置移至本行开头,问题应该不大

2.2 加载和清洗数据

我们下一个任务是创建词汇表并将查询/响应句子对(对话)加载到内存

注意我们正在处理词序,这些词序没有映射到离散数值空间。因此,我们必须通过数据集中的单词来创建一个索引

为此我们创建了一个Voc类,它会存储从单词到索引的映射、索引到单词的反向映射、每个单词的计数和总单词量。这个类提供向词汇表中添加单词的方法(addWord)、添加所有单词到句子中的方法 (addSentence) 和清洗不常见的单词方法(trim)。更多的数据清洗在后面进行

现在我们可以组装词汇表和查询/响应语句对。在使用数据之前,我们必须做一些预处理

首先,我们必须使用unicodeToAscii将 unicode 字符串转换为 ASCII。然后,我们应该将所有字母转换为小写字母并清洗掉除基本标点之外的所有非字母字符 (normalizeString)。最后,为了帮助训练收敛,我们将过滤掉长度大于MAX_LENGTH 的句子 (filterPairs)

关于os.path.join(path[,path2[,...])函数,其功能是组合多个路径并返回,可忽视’\\’

打印一下结果:

for pair in pairs[:10]:
    print(pair)
  • 1
  • 2
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

另一种有利于让训练更快收敛的策略是去除词汇表中很少使用的单词。减少特征空间也会降低模型学习目标函数的难度。我们通过以下两个步骤完成这个操作:

  1. 使用voc.trim函数去除 MIN_COUNT 阈值以下单词
  2. 如果句子中包含词频过小的单词,那么整个句子也被过滤掉

3.为模型准备数据

尽管我们已经投入了大量精力来准备和清洗我们的数据,将它变成一个很好的词汇对象和一系列的句子对,但我们的模型最终希望数据以 numerical torch张量作为输入。可以在seq2seq translation tutorial 或者之前的Blog中找到为模型准备处理数据的一种方法。 在该教程中,我们使用batch size大小为1,这意味着我们所要做的就是将句子对中的单词转换为词汇表中的相应索引,并将其提供给模型

但是,如果你想要加速训练或者想要利用GPU并行计算能力,则需要使用小批量mini-batches来训练

使用小批量mini-batches也意味着我们必须注意批量处理中句子长度的变化。为了容纳同一batch中不同大小的句子,我们将使我们的批量输入张量大小(max_length,batch_size),其中短于max_length的句子在EOS_token之后进行零填充(zero padded)

如果我们简单地将我们的英文句子转换为张量,通过将单词转换为索引indicesFromSentence和零填充zero-pad,我们的张量的大小将是 (batch_size,max_length),并且索引第一维将在所有时间步骤中返回完整序列。但是,我们需要沿着时间对我们批量数据进行索引并且包括批量数据中所有序列。因此,我们将输入批处理大小转换为(max_length,batch_size),以便跨第一维的索引返回批处理中所有句子的时间步长。 我们在zeroPadding函数中隐式处理这个转置,此操作与之前的Blog相似

在这里插入图片描述
关于List sort()函数:

list.sort(cmp=None, key=None, reverse=False)
  • 1
  • key – 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。
  • reverse – 排序规则,reverse = True 降序, reverse = False 升序(默认)

其和lambda组合起来使用:

y = [['there .', 'where ?'],
    ['you have my word . as a gentleman', 'you re sweet .'],
    ['hi .', 'looks like things worked out tonight huh ?'],
    ['you know chastity ?', 'i believe we share an art instructor'],
    ['have fun tonight ?', 'tons']]
y.sort(key=lambda x: len(x[0].split(' ')), reverse=True)
print(y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
[['you have my word . as a gentleman', 'you re sweet .'], ['you know chastity ?', 'i believe we share an art instructor'], ['have fun tonight ?', 'tons'], ['there .', 'where ?'], ['hi .', 'looks like things worked out tonight huh ?']]
  • 1

应该就是对每个元素的第一个字符串的长度进行从高到低排序

结果:

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)
  • 1
  • 2
  • 3
  • 4
  • 5
input_variable: tensor([[  34,  101,   91,   16, 6489],
        [ 383, 1823,    4, 2080,    6],
        [   7,  191,    4,   66,    2],
        [ 572,  117,    4,    2,    0],
        [  27,   12,    2,    0,    0],
        [  14, 4188,    0,    0,    0],
        [ 274,    4,    0,    0,    0],
        [   4,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([9, 8, 5, 4, 3])
target_variable: tensor([[  64,   62,  967,   50, 2238],
        [1000,    4,    4,   68,    5],
        [  50,   64,    4,    7,   92],
        [  37,  101,    4,   47,    7],
        [  64,  102,   25,   40,  144],
        [3762, 2268,    4,   70, 3200],
        [   6,    6,    4,   71,    6],
        [   2,    2,    4,    6,    2],
        [   0,    0,    2,    2,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 0, 1, 1, 0]], dtype=torch.uint8)
max_target_len: 9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

4.定义模型

4.1 Seq2Seq模型

Seq2Seq模型我们聊天机器人的大脑是序列到序列(seq2seq)模型。seq2seq模型的目标是将可变长度序列作为输入,并使用固定大小的模型将可变长度序列作为输出返回

Sutskever et al.发现通过一起使用两个独立的RNN,我们可以完成这项任务。第一个RNN充当编码器,其将可变长度输入序列编码为固定长度上下文向量。理论上,该上下文向量(RNN的最终隐藏层)将包含关于输入到机器人的查询语句的语义信息。第二个RNN是一个解码器,它接收输入文字和上下文矢量,并返回序列中下一句文字的概率和在下一次迭代中使用的隐藏状态

在这里插入图片描述
图片来源: https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/

4.2 编码器

编码器RNN每次迭代中输入一个语句输出一个token(例如,一个单词),同时在这时间内输出“输出”向量和“隐藏状态”向量。然后将隐藏 状态向量传递到下一步,并记录输出向量。编码器将其在序列中的每一点处看到的上下文转换为高维空间中的一系列点,解码器将使用这些点 为给定任务生成有意义的输出

我们的编码器的核心是由Cho et al.等人发明的多层门循环单元。在2014年,我们将使用 GRU的双向变体,这意味着基本上有两个独立的RNN:一个以正常的顺序输入输入序列,另一个以相反的顺序输入输入序列。每个网络的输出在 每个时间步骤求和。使用双向GRU将为我们提供编码过去和未来上下文的优势
在这里插入图片描述
图片来源: https://colah.github.io/posts/2015-09-NN-Types-FP/

注意:embedding层用于在任意大小的特征空间中对我们的单词索引进行编码。对于我们的模型,此图层会将每个单词映射到大小为 hidden_size的特征空间。训练后,这些值会被编码成和他们相似的有意义词语

最后,如果将填充的一批序列传递给RNN模块,我们必须分别使用torch.nn.utils.rnn.pack_padded_sequencetorch.nn.utils.rnn.pad_packed_sequence 在RNN传递时分别进行填充和反填充

计算图
1.将单词索引转换为词嵌入 embeddings。
2.为RNN模块打包填充batch序列。
3.通过GRU进行前向传播。
4.反填充。
5.对双向GRU输出求和。
6.返回输出和最终隐藏状态

输入
input_seq:一批输入句子; shape =(max_length,batch_size
input_lengths:batch中每个句子对应的句子长度列表;shape=(batch_size)
hidden:隐藏状态;shape =(n_layers x num_directions,batch_size,hidden_size)
输出
outputs:GRU最后一个隐藏层的输出特征(双向输出之和);shape =(max_length,batch_size,hidden_size)
hidden:从GRU更新隐藏状态;shape =(n_layers x num_directions,batch_size,hidden_size)

nn.utils.rnn.pack_padded_sequence和之前使用的from torch.nn.utils.rnn import pad_sequence还不太一样,具体可见官方中文文档

其lengths (Tensor or list(int)) – list of sequence lengths of each batch element (must be on the CPU if provided as a tensor).

4.3 解码器

解码器RNN以token-by-token的方式生成响应语句。它使用编码器的上下文向量和内部隐藏状态来生成序列中的下一个单词。它持续生成单词, 直到输出是EOS_token,这个表示句子的结尾。一个 vanilla seq2seq 解码器的常见问题是,如果我们只依赖于上下文向量来编码整个输入序列的含义,那么我们很可能会丢失信息。尤其是在处理长输入序列时,这极大地限制了我们的解码器的能力

为了解决这个问题,Bahdanau et al.等人创建了一种“attention mechanism”,允许解码器关注输入序列的某些部分,而不是在每一步都使用完全固定的上下文

在一个高的层级中,用解码器的当前隐藏状态和编码器输出来计算注意力。输出注意力的权重与输入序列具有相同的大小,允许我们将它们乘以编码器输出,给出一个加权和,表示要注意的编码器输出部分。Sean Robertson的图片很好地描述了这一点:

在这里插入图片描述
Luong et al.通过创造“Global attention”,改善了Bahdanau et al. 的基础工作。关键的区别在于,对于“Global attention”,我们考虑所有编码器的隐藏状态,而不是 Bahdanau 等人的“Local attention”, 它只考虑当前步中编码器的隐藏状态。另一个区别在于,通过“Global attention”,我们仅使用当前步的解码器的隐藏状态来计算注意力权重 (或者能量)。Bahdanau 等人的注意力计算需要知道前一步中解码器的状态。 此外,Luong等人提供各种方法来计算编码器输出和解码器输出 之间的注意权重(能量),称之为“score functions”:
在这里插入图片描述
其中 h t h_t ht=当前目标解码器状态, h s h_s hs= 所有编码器状态

总体而言,Global attention机制可以通过下图进行总结。请注意,我们将“Attention Layer”用一个名为Attn的nn.Module来单独实现。 该模块的输出是经过softmax标准化后权重张量的大小(batch_size,1,max_length)
在这里插入图片描述
现在我们已经定义了注意力子模块,我们可以实现真实的解码器模型。对于解码器,我们将每次手动进行一批次的输入。这意味着我们的词嵌 入张量和GRU输出都将具有相同大小(1,batch_size,hidden_size)

计算图

  1. 获取当前输入的词嵌入
  2. 通过单向GRU进行前向传播
  3. 通过2输出的当前GRU计算注意力权重
  4. 将注意力权重乘以编码器输出以获得新的“weighted sum”上下文向量
  5. 使用Luong eq.5连接加权上下文向量和GRU输出
  6. 使用Luong eq.6预测下一个单词(没有softmax)
  7. 返回输出和最终隐藏状态

输入

  • input_step:每一步输入序列batch(一个单词);shape =(1,batch_size)
  • last_hidden:GRU的最终隐藏层;shape =(n_layers x num_directions,batch_size,hidden_size)
  • encoder_outputs:编码器模型的输出;shape =(max_length,batch_size,hidden_size)

输出

  • output: 一个softmax标准化后的张量, 代表了每个单词在解码序列中是下一个输出单词的概率;shape =(batch_size,voc.num_words)
  • hidden: GRU的最终隐藏状态;shape =(n_layers x num_directions,batch_size,hidden_size
 x = torch.rand(1,5,500)
 y = torch.rand(9,5,500)
 z = x * y
 print(z.size())     # torch.Size([9, 5, 500])
 print(torch.sum(z, dim=2).size())   # torch.Size([9, 5])
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述

5.定义训练步骤

5.1 Masked 损失

由于我们处理的是批量填充序列,因此在计算损失时我们不能简单地考虑张量的所有元素。我们定义maskNLLLoss可以根据解码器的输出张量、 描述目标张量填充的binary mask张量来计算损失。该损失函数计算与mask tensor中的1对应的元素的平均负对数似然

关于torch.gather函数,官方文档见

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • 1
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
  • 1

例如:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
  • 1
  • 2
  • 3
  • 4
x = torch.rand(6, 5)
y = torch.arange(5).view(-1, 1)
print(x)
print(y)
print(torch.gather(x,1,y))
  • 1
  • 2
  • 3
  • 4
  • 5

相当于取对角线的值:

tensor([[0.4736, 0.4092, 0.3242, 0.0021, 0.2569],
        [0.9570, 0.1728, 0.6498, 0.0388, 0.1231],
        [0.1590, 0.0062, 0.3644, 0.4779, 0.4086],
        [0.0151, 0.1323, 0.7863, 0.7310, 0.9738],
        [0.6469, 0.8148, 0.4508, 0.0107, 0.2141],
        [0.7299, 0.0607, 0.4077, 0.5305, 0.0558]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
tensor([[0.4736],
        [0.1728],
        [0.3644],
        [0.7310],
        [0.2141]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

5.2 单次训练迭代

train函数包含单次训练迭代的算法(单批输入)
我们将使用一些巧妙的技巧来帮助融合:

  • 第一个技巧是使用teacher forcing。 这意味着在一些概率是由teacher_forcing_ratio设置,我们使用当前目标单词作为解码器 的下一个输入,而不是使用解码器的当前推测。该技巧充当解码器的training wheels,有助于更有效的训练。然而,teacher forcing 可能导致推导中的模型不稳定,因为解码器可能没有足够的机会在训练期间真正地制作自己的输出序列。因此,我们必须注意我们如何设置teacher_forcing_ratio, 同时不要被快速的收敛所迷惑。
  • 第二个技巧是梯度裁剪(gradient clipping)。这是一种用于对抗“爆炸梯度(exploding gradient)”问题的常用技术。本质上, 通过将梯度剪切或阈值化到最大值,我们可以防止在损失函数中梯度以指数方式增长并发生溢出(NaN)或者越过梯度

在这里插入图片描述
图片来源: Goodfellow et al. Deep Learning. 2016. https://www.deeplearningbook.org/

操作顺序

  1. 通过编码器前向计算整个批次输入。
  2. 将解码器输入初始化为SOS_token,将隐藏状态初始化为编码器的最终隐藏状态。
  3. 通过解码器一次一步地前向计算输入一批序列。
  4. 如果是 teacher forcing 算法:将下一个解码器输入设置为当前目标;如果是 no teacher forcing 算法:将下一个解码器输入设置为当前解码器输出。
  5. 计算并累积损失。
  6. 执行反向传播。
  7. 裁剪梯度。
  8. 更新编码器和解码器模型参数

PyTorch的RNN模块(RNN,LSTM,GRU)可以像任何其他非重复层一样使用,只需将整个输入序列(或一批序列)传递给它们。 我们在编码器中使用GRU层就是这样的。实际情况是,在计算中有一个迭代过程循环计算隐藏状态的每一步。或者,你每次只运行一个模块。在这种情况下,我们在训练过程中手动循环遍历序列就像我们必须为解码器模型做的那样。只要你正确的维护这些模型的模块,就可以非常简单的实现顺序模型

5.3 训练迭代

现在终于将完整的训练步骤与数据结合在一起了。给定传递的模型、优化器、数据等,trainIters函数负责运行n_iterations的训练。 这个功能显而易见,因为我们通过train函数的完成了繁重工作

报警告:

[W IndexKernel.cu:401] Warning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (function masked_scatter__cuda)
  • 1

这里需要把mask改成bool类型

cuda is available!
Processing corpus...
Loading conversations...
Writing newly formatted file...
Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008
keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total
Building encoder and decoder ...
Building optimizers ...
Starting Training!
Initializing ...
Training...
200 Finished Training. Loss:37.443223
400 Finished Training. Loss:31.082634
600 Finished Training. Loss:29.710671
800 Finished Training. Loss:28.579337
1000 Finished Training. Loss:27.827667
1200 Finished Training. Loss:26.942233
1400 Finished Training. Loss:26.375093
1600 Finished Training. Loss:25.854567
1800 Finished Training. Loss:25.230953
2000 Finished Training. Loss:24.758790
2200 Finished Training. Loss:24.113004
2400 Finished Training. Loss:23.766075
2600 Finished Training. Loss:23.205067
2800 Finished Training. Loss:22.746353
3000 Finished Training. Loss:22.181855
3200 Finished Training. Loss:21.720691
3400 Finished Training. Loss:21.308844
3600 Finished Training. Loss:20.722166
3800 Finished Training. Loss:20.189193
4000 Finished Training. Loss:19.957068
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

训练每一次打印平均Loss变化曲线:
在这里插入图片描述

6.评估定义

在训练模型后,我们希望能够自己与机器人交谈。首先,我们必须定义我们希望模型如何解码编码输入

6.1 贪婪解码

贪婪解码是我们在不使用 teacher forcing 时在训练期间使用的解码方法。换句话说,对于每一步,我们只需从具有最高 softmax 值的 decoder_output 中选择单词。该解码方法在单步长级别上是最佳的。

为了便于贪婪解码操作,我们定义了一个GreedySearchDecoder类。当运行时,类的实例化对象输入序列(input_seq)的大小是(input_seq length,1), 标量输入(input_length)长度的张量和 max_length 来约束响应句子长度。使用以下计算图来评估输入句子:

计算图

  1. 通过编码器模型前向计算。
  2. 准备编码器的最终隐藏层,作为解码器的第一个隐藏输入。
  3. 将解码器的第一个输入初始化为 SOS_token。
  4. 将初始化张量追加到解码后的单词中。
  5. 一次迭代解码一个单词token:
     (i) 通过解码器进行前向计算。
     (ii) 获得最可能的单词token及其softmax分数。
     (iii) 记录token和分数。
     (iv) 准备当前token作为下一个解码器的输入。
  6. 返回收集到的词 tokens 和分数

6.2 评估我们的文本

现在我们已经定义了解码方法,我们可以编写用于评估字符串输入句子的函数。evaluate函数管理输入句子的低层级处理过程。我们首先使 用batch_size == 1将句子格式化为输入batch的单词索引。我们通过将句子的单词转换为相应的索引,并通过转换维度来为我们的模型准备张量。我们还创建了一个lengths张量,其中包含输入句子的长度。在这种情况下,lengths是标量,因为我们一次只评估一个句子(batch_size == 1)。 接下来,我们使用我们的GreedySearchDecoder实例化后的对象(searcher)获得解码响应句子的张量。最后,我们将响应的索引转换为单词并返回已解码单词的列表。

evaluateInput充当聊天机器人的用户接口。调用时,将生成一个输入文本字段,我们可以在其中输入查询语句。在输入我们的输入句子并按 Enter 后,我们的文本以与训练数据相同的方式标准化,并最终被输入到评估函数以获得解码的输出句子。我们循环这个过程,这样我们可以继续与我们的机器人聊天直到我们输入“q”或“quit”。

最后,如果输入的句子包含一个不在词汇表中的单词,我们会通过打印错误消息并提示用户输入另一个句子来优雅地处理

结果:

> Hello!
Bot: hello . . . . .
> What are you doing?
Bot: i m going to kill you . .
> what?
Bot: you know what i m saying . .
> You must be kidding!
Bot: i m not . . . .
> Oops...
Bot: what ? . . . . .
> You serious?
Bot: yeah . s a little while .
> I love you!
Bot: i know . . . .
> Do you know Siri?
Error: Encountered unknown word.
> q
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
> hello?
Bot: hi . . . . .
> hello!
Bot: hi . . . .
> hello
Bot: hi . . . .
> where am I?
Bot: in the parking way . . .
> who are you?
Bot: i m the bowler . . .
> how are you doing?
Bot: i m fine . . .
> are you my friend?
Bot: yes . . . . .
> you're under arrest
Bot: i m not . . .
> i'm just kidding
Bot: no . . . . .
> where are you from?
Bot: in the bathroom . . .
> it's time for me to leave
Bot: i m sorry to do that . !
> goodbye
Bot: goodbye . . . .
> q
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

7. 全部代码

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print('cuda is available!')
else:
    device = torch.device("cpu")

import os
path = '... your path\\cornell_movie_dialogs_corpus\\cornell movie-dialogs corpus'

MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
def loadLines(fileName, fields):        # 将文件的每一行拆分为字段字典(lineID, characterID, movieID, character, text)
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:      # 对于每一句话
            values = line.split(' +++$+++ ')
            lineObj = {}        # 提取字段
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines

file = os.path.join(path, "movie_lines.txt")
print("Processing corpus...")       # 加载行和进程对话
lines = loadLines(file, MOVIE_LINES_FIELDS)

MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
def loadConversations(fileName, lines, fields):     # 将 `loadLines` 中的行字段分组为基于 *movie_conversations.txt* 的对话
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:     # 对于每一行
            values = line.split(" +++$+++ ")
            convObj = {}    # 提取字段
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            lineIds = eval(convObj["utteranceIDs"])     # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            convObj["lines"] = []
            for lineId in lineIds:
                convObj['lines'].append(lines[lineId])
            conversations.append(convObj)
    return conversations

file = os.path.join(path, "movie_conversations.txt")
print("Loading conversations...")
conversations = loadConversations(file, lines, MOVIE_CONVERSATIONS_FIELDS)

def extractSentencePairs(conversations):    # 从对话中提取一对句子
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i + 1]["text"].strip()
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

delimiter = '\t'
import codecs
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

datafile = os.path.join(path, "formatted_movie_lines.txt")
print("Writing newly formatted file...")
import csv
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):       # 删除低于特定计数阈值的单词
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        self.word2index = {}    # 重初始化字典
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens

        for word in keep_words:
            self.addWord(word)

import unicodedata
# def unicodeToAscii(s):
#     return ''.join(
#         c for c in unicodedata.normalize('NFD', s)
#         if unicodedata.category(c) != 'Mn'
#     )

import re
def normalizeString(s):
    # 将 Unicode 字符转换为 ASCII
    Ascii = []
    for c in unicodedata.normalize('NFD', s):
        if unicodedata.category(c) != 'Mn':
            Ascii.append(c)
    # 将ASCII列表转化为字符串,并将所有内容都转换为小写,并修剪大多数标点符号
    s = ''.join(Ascii).lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)       # 在.!?前面加上空格
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)       # 只保留a-zA-Z.!?并在其后加空格
    return s.strip(' ')

def readVocs(datafile, corpus_name):
    print("Reading lines...")
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')     # 读取文件并分成几行
    pairs = [[normalizeString(s) for s in l.strip('\r').split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

MAX_LENGTH = 10  # Maximum sentence length to consider
def filterPair(p):      # 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值,则返回True
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH      # # Input sequences need to preserve the last word for EOS token

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs

save_dir = os.path.join(path, "save")
voc, pairs = loadPrepareData(path, "cornell movie-dialogs corpus", datafile, save_dir)

MIN_COUNT = 3    # 修剪的最小字数阈值
def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)     # 修剪来自voc的MIN_COUNT下使用的单词
    keep_pairs = []     # 过滤掉带有修剪词的pair
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split(' '):      # 检查输入句子
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):     # 检查输出句子
            if word not in voc.word2index:
                keep_output = False
                break
        if keep_input and keep_output:      # 只保留输入或输出句子中不包含修剪单词的对
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs),
                                                                len(keep_pairs) / len(pairs)))
    return keep_pairs

pairs = trimRareWords(voc, pairs, MIN_COUNT)        # 修剪voc和对

def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

import itertools
def zeroPadding(l, fillvalue=PAD_token):        # zip 对数据进行合并了,相当于行列转置了
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):       # 记录 PAD_token的位置为0, 其他的为1
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

def inputVar(l, voc):       # 返回填充前(加入结束index EOS_token做标记)的长度 和 填充后的输入序列张量, ['you have my word .','','','','']
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]      # [[1,2,3],[],[],[],[]]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])     # [1,2,3,4,5]
    padList = zeroPadding(indexes_batch)        # seq_len * 5
    padVar = torch.LongTensor(padList)
    return padVar, lengths

def outputVar(l, voc):  # 返回填充前(加入结束index EOS_token做标记)最长的一个长度 和 填充后的输入序列张量, 和 填充后的标记 mask
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    return padVar, mask, max_target_len

def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x:len(x[0].split(' ')), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)       # seq_len * 5, [seq_len,2,3,4,5]
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

import random
# small_batch_size = 5    # 验证例子
# batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
# input_variable, lengths, target_variable, mask, max_target_len = batches

import torch.nn as nn

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)  # seq_len * 5 * embed_dim
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)     # 为RNN模块打包填充batch序列
        outputs, hidden = self.gru(packed, hidden)  # seq_len * 5 * (2*500)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)      # 打开填充
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden      # seq_len * 5 * 500, (1*2) * 5 * 500

import torch.nn.functional as F
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):        # 1 * 5 * 500, seq_len * 5 * 500
        energy = self.attn(encoder_output)      # seq_len * 5 * 500
        return torch.sum(hidden * energy, dim=2)        # seq_len * 5

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):     # 根据给定的方法计算注意力(能量)
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        attn_energies = attn_energies.t()       # Transpose max_length and batch_size dimensions, 5 * seq_len

        return F.softmax(attn_energies, dim=1).unsqueeze(1)     # 5 * 1 * seq_len

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # 定义层
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):        # 注意:我们一次运行这一步(单词)
        embedded = self.embedding(input_step)       # 获取当前输入字的嵌入, 1 * 5 * 500
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)        # 通过单向GRU转发1 * 5 * 500, 1 * 5 * 500
        # 这里为什么是rnn_output,不应该是hidden吗???
        # print(hidden.size(),encoder_outputs.size())  torch.Size([2, 64, 500]) torch.Size([10, 64, 500])
        attn_weights = self.attn(rnn_output, encoder_outputs)       # 从当前GRU输出计算注意力1 * 5 * 500, seq_len * 5 * 500  # 5 * 1 * seq_len
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))     # 5 * 1 * seq_len bmm 5 * seq_len * 500 → 5 * 1 * 500
        rnn_output = rnn_output.squeeze(0)      # 5 * 500
        context = context.squeeze(1)        # 5 * 500
        concat_input = torch.cat((rnn_output, context), 1)      # 5 * 1000
        concat_output = torch.tanh(self.concat(concat_input))       # 5 * 500
        output = self.out(concat_output)        # 5 * output_size
        output = F.softmax(output, dim=1)
        return output, hidden

def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))      # seq_len * 5, 5 * 1 → 5 * 1 → [5]
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

def train(input_variable, lengths, target_variable, mask, max_target_len):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    input_variable = input_variable.to(device)
    # lengths = lengths.to(device)
    target_variable = target_variable.to(device)

    # for i in range(len(mask)):
    #     for j in range(len(mask[0])):
    #         if mask[i][j] == 1: mask[i][j] = True
    #         elif mask[i][j] == 0: mask[i][j] = False

    mask = mask.bool().to(device)

    loss = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)      # 正向传递编码器
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])      # 创建初始解码器输入(从每个句子的SOS令牌开始) 1 * 5
    decoder_input = decoder_input.to(device)
    decoder_hidden = encoder_hidden[:decoder.n_layers]      # 将初始解码器隐藏状态设置为编码器的最终隐藏状态

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False    # 确定我们是否此次迭代使用`teacher forcing`
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_input = target_variable[t].view(1, -1)       # Teacher forcing: 下一个输入是当前的目标
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            _, topi = decoder_output.topk(1)        # # No teacher forcing: 下一个输入是解码器自己的当前输出
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            # 计算并累计损失
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss

    loss.backward()

    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)      # # 剪辑梯度:梯度被修改到位
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item()


import matplotlib.pyplot as plt
def trainIters():
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                        for _ in range(n_iteration)]
    print('Initializing ...')
    start_iteration = 1
    print_loss = []
    losses = 0
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len)
        print_loss.append(loss)
        losses += loss
        if iteration % print_every == 0:
            print('%d Finished Training. Loss:%f' % (iteration, losses/print_every))
            losses = 0

    torch.save(encoder.state_dict(), '... your path\\model_encoder2.pth')
    torch.save(decoder.state_dict(), '... your path\\model_decoder2.pth')
    plt.plot(print_loss)
    plt.show()

class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)     # 通过编码器模型转发输入
        decoder_hidden = encoder_hidden[:decoder.n_layers]      # 准备编码器的最终隐藏层作为解码器的第一个隐藏输入
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token   # 使用SOS_token初始化解码器输入
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)

        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)   # 正向通过解码器
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)        # 获得最可能的单词标记及其softmax分数
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)      # 记录token和分数
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)       # 准备当前令牌作为下一个解码器输入(添加维度)
        return all_tokens, all_scores       # 返回收集到的词tokens和分数

def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    indexes_batch = [indexesFromSentence(voc, sentence)]        ### 格式化输入句子作为batch
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])     # 创建lengths张量
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)       # 转置batch的维度以匹配模型的期望
    input_batch = input_batch.to(device)
    # lengths = lengths.to(device)
    tokens, scores = searcher(input_batch, lengths, max_length) # 用searcher解码句子
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words

def evaluateInput(encoder, decoder, searcher, voc):
    input_sentence = ''
    while (1):
        try:
            input_sentence = input('> ')        # 获取输入句子
            if input_sentence == 'q' or input_sentence == 'quit': break     # 检查是否退出
            input_sentence = normalizeString(input_sentence)        # 规范化句子
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]       # 格式化和打印回复句
            print('Bot:', ' '.join(output_words))
        except:
            print("Error: Encountered unknown word.")


from torch import optim
if __name__ == '__main__':
    attn_model = 'dot'
    # attn_model = 'general'
    # attn_model = 'concat'
    hidden_size = 500
    encoder_n_layers = 2
    decoder_n_layers = 2
    dropout = 0.1
    batch_size = 64

    print('Building encoder and decoder ...')
    embedding = nn.Embedding(voc.num_words, hidden_size)
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    # 以下为训练
    # encoder.train()
    # decoder.train()
    #
    # clip = 50.0
    # teacher_forcing_ratio = 1.0
    # learning_rate = 0.0001
    # decoder_learning_ratio = 5.0
    # n_iteration = 4000
    # print_every = 200
    #
    # print('Building optimizers ...')
    # encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    # decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
    # print("Starting Training!")
    # trainIters()

    # 以下为测试
    encoder.load_state_dict(torch.load('... your path\\model_encoder2.pth'))
    decoder.load_state_dict(torch.load('... your path\\model_decoder2.pth'))
    encoder.eval()
    decoder.eval()
    searcher = GreedySearchDecoder(encoder, decoder)
    evaluateInput(encoder, decoder, searcher, voc)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500

小结

此篇文章与之前关于Attention模型(称为attention mechanism)不一样:

  1. 英语法语翻译Blog是Decoder的隐层和Embedded做Attention求权重矩阵,再对Encoder的Output做Weight Sum
  2. 英语德语翻译Blog是Decoder的隐层和Encoder的Output和做Attention求权重矩阵,再对Encoder的Output做Weight Sum
  3. 这次是Decoder的RNN的Output和Encoder的Output做Attention求权重矩阵,再对Encoder的Output做Weight Sum得到context,这和RNN的Output一起作为Decoder的输出,称为Global attention(dot运算)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/483972
推荐阅读
相关标签
  

闽ICP备14008679号