
Building a Chatbot from the Official PyTorch Tutorial (Annotated Code)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
  24. """
  25. 加载和预处理数据
  26. 下一步就是格式化处理我们的数据文件并加载到我们可以使用的结构中
  27. Cornell Movie-Dialogs Corpus 是一个丰富的电影角色对话数据集:
  28. 10,292 对电影角色的220,579 次对话
  29. 617部电影中的9,035电影角色
  30. 总共304,713中语调
  31. 这个数据集庞大而多样,在语言形式、时间段、情感上等都有很大的变化。
  32. 我们希望这种多样性使我们的模型能够适应多种形式的输入和查询。
  33. 首先,我们通过数据文件的某些行来查看原始数据的格式
  34. """
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("data", corpus_name)

def printlines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)
  42. """
  43. printlines(os.path.join(corpus,"movie_lines.txt"))
  44. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  45. printlines(os.path.join(corpus,"movie_characters_metadata.txt"))
  46. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  47. printlines(os.path.join(corpus,"movie_titles_metadata.txt"))
  48. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  49. printlines(os.path.join(corpus,"movie_conversations.txt"))
  50. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  51. printlines(os.path.join(corpus,"chameleons.pdf"))
  52. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  53. printlines(os.path.join(corpus,"chameleons.pdf"))
  54. print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
  55. """
  56. """
  57. 创建格式化数据文件
  58. 为了方便起见,我们将创建一个格式良好的数据文件,其中每一行包含一个由 tab 制表符分隔的查询语句和响应语句对。
  59. 以下函数便于解析原始 movie_lines.txt 数据文件。
  60. loadLines 将文件的每一行拆分为字段(lineID, characterID, movieID, character, text)组合的字典
  61. loadConversations 根据 movie_conversations.txt 将 loadLines 中的每一行数据进行归类
  62. extractSentencePairs 从对话中提取一对句子
  63. """
  64. """
  65. 格式化并且加载数据
  66. 将文件的每一行拆分为字段字典
  67. line = {
  68. 'L183198': {
  69. 'lineID': 'L183198',
  70. 'characterID': 'u5022',
  71. 'movieID': 'm333',
  72. 'character': 'FRANKIE',
  73. 'text': "Well we'd sure like to help you.\n"
  74. }, {...}
  75. }
  76. """
def loadlines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields into a dictionary for this line
            lineobj = {}
            for i, field in enumerate(fields):
                lineobj[field] = values[i]
            lines[lineobj["lineID"]] = lineobj
    return lines
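
# A small, hypothetical illustration (the sample line mirrors the dictionary example above):
# each raw line of movie_lines.txt is separated by " +++$+++ ", so splitting yields the five fields.
sample_line = "L183198 +++$+++ u5022 +++$+++ m333 +++$+++ FRANKIE +++$+++ Well we'd sure like to help you.\n"
print(sample_line.split(" +++$+++ "))
# ['L183198', 'u5022', 'm333', 'FRANKIE', "Well we'd sure like to help you.\n"]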
  87. """
  88. # 将 `loadLines` 中的行字段分组为基于 *movie_conversations.txt* 的对话
  89. # [{
  90. # 'character1ID': 'u0',
  91. # 'character2ID': 'u2',
  92. # 'movieID': 'm0',
  93. # 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
  94. # 'lines': [{
  95. # 'lineID': 'L194',
  96. # 'characterID': 'u0',
  97. # 'movieID': 'm0',
  98. # 'character': 'BIANCA',
  99. # 'text': 'Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.\n'
  100. # }, {
  101. # 'lineID': 'L195',
  102. # 'characterID': 'u2',
  103. # 'movieID': 'm0',
  104. # 'character': 'CAMERON',
  105. # 'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"
  106. # }, {
  107. # 'lineID': 'L196',
  108. # 'characterID': 'u0',
  109. # 'movieID': 'm0',
  110. # 'character': 'BIANCA',
  111. # 'text': 'Not the hacking and gagging and spitting part. Please.\n'
  112. # }, {
  113. # 'lineID': 'L197',
  114. # 'characterID': 'u2',
  115. # 'movieID': 'm0',
  116. # 'character': 'CAMERON',
  117. # 'text': "Okay... then how 'bout we try out some French cuisine. Saturday? Night?\n"
  118. # }]
  119. # }, {...}]
  120. """
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields into a dictionary for this conversation
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert the string representation of the line ID list (e.g. "['L194', 'L195']") into a Python list
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble the actual line objects for this conversation
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations
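
# Aside (not part of the original tutorial): eval() is used above to parse the utteranceIDs string.
# For untrusted input, ast.literal_eval is a safer drop-in that handles the same literal syntax.
import ast
print(ast.literal_eval("['L194', 'L195', 'L196', 'L197']"))   # ['L194', 'L195', 'L196', 'L197']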
# Extract pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for con in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(con['lines']) - 1):
            inputline = con['lines'][i]['text'].strip()
            outputline = con['lines'][i + 1]['text'].strip()
            # Filter out pairs where either line is empty
            if inputline and outputline:
                qa_pairs.append([inputline, outputline])
    return qa_pairs
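
# A minimal, hypothetical illustration (not from the corpus): three consecutive utterances
# in one conversation yield two (query, response) pairs.
demo_conversation = [{'lines': [{'text': 'Hi.\n'}, {'text': 'Hello!\n'}, {'text': 'How are you?\n'}]}]
print(extractSentencePairs(demo_conversation))
# [['Hi.', 'Hello!'], ['Hello!', 'How are you?']]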
  145. """
  146. 现在我们将调用这些函数来创建文件,我们命名为 formatted_movie_lines.txt.
  147. Processing corpus...
  148. Loading conversations...
  149. Writing newly formatted file...
  150. Sample lines from file:
  151. b"Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
  152. b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part. Please.\n"
  153. b"Not the hacking and gagging and spitting part. Please.\tOkay... then how 'bout we try out some French cuisine. Saturday? Night?\n"
  154. b"You're asking me out. That's so cute. What's your name again?\tForget it.\n"
  155. b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
  156. b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser. My sister. I can't date until she does.\n"
  157. b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser. My sister. I can't date until she does.\tSeems like she could get a date easy enough...\n"
  158. b'Why?\tUnsolved mystery. She used to be really popular when she started high school, then it was just like she got sick of it or something.\n'
  159. b"Unsolved mystery. She used to be really popular when she started high school, then it was just like she got sick of it or something.\tThat's a shame.\n"
  160. b'Gosh, if only we could find Kat a boyfriend...\tLet me see what I can do.\n'
  161. """
# Define path to the new formatted file
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize the lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadlines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines
print("\nSample lines from file:")
printlines(datafile)
  185. """
  186. 加载和清洗数据
  187. 我们下一个任务是创建词汇表并将查询/响应句子对(对话)加载到内存。
  188. 注意我们正在处理词序,这些词序没有映射到离散数值空间。因此,
  189. 我们必须通过数据集中的单词来创建一个索引。
  190. 为此我们创建了一个Voc类,它会存储从单词到索引的映射、索引到单词的反向映射、每个单词的计数和总单词量。
  191. 这个类提供向词汇表中添加单词的方法(addWord)、添加所有单词到句子中的方法 (addSentence) 和清洗不常见的单词方法(trim)。
  192. 更多的数据清洗在后面进行。
  193. """
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.word2count = {}
        self.num_words = 3  # Count PAD, SOS, EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))
        # Reinitialize the dictionaries and re-add the kept words
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
        for word in keep_words:
            self.addWord(word)
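
# A small usage sketch (not part of the tutorial): build a tiny vocabulary,
# then trim words that appear fewer than 2 times.
demo_voc = Voc("demo")
demo_voc.addSentence("hello world hello")
print(demo_voc.word2count)    # {'hello': 2, 'world': 1}
demo_voc.trim(min_count=2)    # prints 'keep_words 1 / 2 = 0.5000'
print(demo_voc.word2index)    # {'hello': 3} -- indexes 0-2 are reserved for PAD/SOS/EOS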
  234. """
  235. 现在我们可以组装词汇表和查询/响应语句对。在使用数据之前,我们必须做一些预处理。
  236. 首先,我们必须使用unicodeToAscii将unicode字符串转换为ASCII。
  237. 然后,我们应该将所有字母转换为小写字母并清洗掉除基本标点之外的所有非字母字符 (normalizeString)。
  238. 最后,为了帮助训练收敛,我们将过滤掉长度大于MAX_LENGTH 的句子 (filterPairs)。
  239. """
MAX_LENGTH = 10  # Maximum sentence length to consider

# Turn a Unicode string into plain ASCII by dropping combining marks
def unicodeToAscii(s):
    return "".join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
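
# A quick sanity check with an assumed example string (not part of the tutorial):
# accented characters decompose under NFD and their combining marks ('Mn') are dropped.
print(unicodeToAscii("Montréal, ça va?"))   # -> Montreal, ca va?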
# Initialize a Voc object and a list of formatted query/response sentence pairs
def readVocs(datafile, corpus_name):
    print("reading lines ...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    # Split every line into a pair and convert to ASCII
    pairs = [[unicodeToAscii(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs
# Returns True iff both sentences in pair 'p' are below the MAX_LENGTH threshold
def filterPair(p):
    return len(p[0].split(" ")) < MAX_LENGTH and len(p[1].split(" ")) < MAX_LENGTH

# Filter pairs using the filterPair condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
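
# A small illustration with assumed example sentences (not from the corpus):
print(filterPair(["hello there .", "general kenobi ."]))   # True: both sides have fewer than MAX_LENGTH words
print(filterPair(["this sentence has far too many words to pass the filter", "ok"]))   # False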
# Using the functions defined above, return a populated Voc object and a list of sentence pairs
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
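    # A minimal completion sketch (assumption: the rest of the function follows the original
    # PyTorch tutorial's behavior, as described in the comment above): filter over-long pairs,
    # add every word of the remaining sentences to the vocabulary, and return both.
    pairs = filterPairs(pairs)
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    return voc, pairs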