
How to Implement an Encoder-Decoder Model in PyTorch for English-French Translation

Dataset download link:

https://download.pytorch.org/tutorial/data.zip
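
If you prefer to fetch and unpack the data from a script, here is a minimal sketch (it assumes the archive extracts to a data/ directory containing eng-fra.txt, which is the path the code below expects):

    import urllib.request
    import zipfile

    # download the archive and extract it next to this script
    urllib.request.urlretrieve("https://download.pytorch.org/tutorial/data.zip", "data.zip")
    with zipfile.ZipFile("data.zip") as zf:
        zf.extractall(".")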

The dataset is in the eng-fra.txt file; each line contains one English-French sentence pair.
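
A quick way to check the format once the file is unpacked (the two sentences on each line are separated by a tab, which is what the code below splits on):

    # Print the first few raw lines of the dataset; every line should look like
    #   <English sentence>\t<French translation>
    with open('data/eng-fra.txt', encoding='utf-8') as f:
        for _ in range(3):
            print(f.readline().rstrip('\n'))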

To run the code below, make sure the following versions are installed (e.g. pip install torch==1.9.0 torchtext==0.10.0):

PyTorch == 1.9.0

torchtext == 0.10.0

Data flow through the Encoder:

Data flow through the Decoder:

Data flow through the Decoder with attention:
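
A minimal shape walkthrough of these three data flows, runnable once the classes below have been defined, shows how a single word index moves through each module. The vocabulary sizes 10 and 12 and the word index 3 are made-up illustration values; hidden_size=256, device and MAX_LENGTH match the code below.

    enc = EncoderRNN(input_size=10, hidden_size=256).to(device)
    dec = DecoderRNN(hidden_size=256, output_size=12).to(device)
    attn_dec = AttnDecoderRNN(hidden_size=256, output_size=12).to(device)

    word = torch.tensor([[3]], device=device)          # a single word index, shape [1, 1]

    # Encoder: index -> embedding [1, 1, 256] -> GRU -> output/hidden both [1, 1, 256]
    enc_hidden = enc.initHidden()
    enc_out, enc_hidden = enc(word, enc_hidden)

    # Decoder: index -> embedding -> ReLU -> GRU -> Linear + LogSoftmax -> [1, 12]
    dec_out, dec_hidden = dec(word, enc_hidden)

    # Attention decoder: additionally takes one encoder output per source position
    encoder_outputs = torch.zeros(MAX_LENGTH, 256, device=device)
    attn_out, attn_hidden, attn_weights = attn_dec(word, enc_hidden, encoder_outputs)

    print(enc_out.shape, dec_out.shape, attn_out.shape, attn_weights.shape)
    # torch.Size([1, 1, 256]) torch.Size([1, 12]) torch.Size([1, 12]) torch.Size([1, 10])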

# Encoder-Decoder English-French translation
from __future__ import unicode_literals, print_function, division
import random
import re
from io import open
import torch
import torch.nn as nn
import torch.nn.functional as F
import unicodedata
from torch import optim

# pick the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# start-of-string token
SOS_token = 0
# end-of-string token
EOS_token = 1
  17. # 要翻译的语言的包装类,包含了常用工具
  18. class Lang:
  19. def __init__(self, name):
  20. # 名称
  21. self.name = name
  22. # 词语->索引
  23. self.word2index = {}
  24. # 词语->计数
  25. self.word2count = {}
  26. # 索引->词语
  27. # 默认添加SOS,EOS
  28. self.index2word = {0: "SOS", 1: "EOS"}
  29. # 词语数
  30. # 因为现在已经有 SOS,EOS 所以=2
  31. self.n_words = 2 # Count SOS and EOS
  32. # 添加一句话
  33. def addSentence(self, sentence):
  34. # 以空格分割这句话
  35. # 然后取出每一个词语
  36. for word in sentence.split(' '):
  37. # 添加词语
  38. self.addWord(word)
  39. def addWord(self, word):
  40. # 如果以前没有添加过这个词语
  41. if word not in self.word2index:
  42. # 索引从0开始
  43. # 所以先赋值
  44. # 最后self.n_words+=1
  45. self.word2index[word] = self.n_words
  46. self.word2count[word] = 1
  47. self.index2word[self.n_words] = word
  48. self.n_words += 1
  49. else:
  50. # 已经存在则计数+=1
  51. self.word2count[word] += 1
# Convert a Unicode string (as stored in the dataset) to plain ASCII
# (what we feed into the model). The same character can have several
# Unicode representations, so we normalize everything to one canonical
# form, which makes the data easier for the model to handle.
def unicodeToAscii(s):
    return ''.join(
        # normalize(): the first argument selects the normalization form.
        # NFC prefers composed (single code point) characters, while NFD
        # decomposes characters into base characters plus combining marks.
        # Decompose first, then keep only the non-mark characters:
        # category 'Mn' means Mark, Nonspacing (i.e. combining accents).
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Normalize a string
def normalizeString(s):
    # lowercase, strip leading/trailing whitespace, convert to ASCII
    s = unicodeToAscii(s.lower().strip())
    # put a space before ., ! and ? so they become separate tokens
    s = re.sub(r"([.!?])", r" \1", s)
    # replace everything that is not a letter or ., !, ? with a space
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s
# Read the dataset and split it into sentence pairs
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")
    # Read the file and split into lines:
    # open the dataset as utf-8, read() the whole file,
    # strip() surrounding whitespace and split on '\n'
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8'). \
        read().strip().split('\n')
    # Split every line into pairs and normalize:
    # for every line l, split it on '\t' and normalize each half s
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    # Reverse pairs, make Lang instances.
    # The file stores English -> French; if reverse is set,
    # flip each pair so we translate French -> English instead.
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
# maximum sentence length
MAX_LENGTH = 10

# To keep training manageable we only use part of the dataset:
# sentences that start with one of the prefixes below.
eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

# filter the dataset by maximum length and by prefix
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH and (
        p[0].startswith(eng_prefixes) or p[1].startswith(eng_prefixes))

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]
# load and prepare the dataset
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))
# Encoder
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        # word embedding
        self.embedding = nn.Embedding(input_size, hidden_size)
        # GRU; the input has just been embedded, so its input size is hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size)

    # forward pass, builds the computation graph
    def forward(self, input, hidden):
        # reshape to [seq_len, batch_size, embedding_dim].
        # Both seq_len and batch_size are 1 here because we feed the words
        # of a sentence into the Encoder one at a time
        # (the Decoder works the same way).
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
# Decoder
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
# Decoder with attention
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # The attention layer takes the embedded input and the hidden state,
        # so its input size is self.hidden_size * 2.
        # The length of the decoded sentence is not known in advance,
        # so the output size is simply max_length.
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # weight the encoder outputs by the attention weights
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
# sentence -> list of word indices
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

# sentence -> tensor of indices (with EOS appended)
def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

# turn one dataset sample into an (input, target) tensor pair
def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return input_tensor, target_tensor
teacher_forcing_ratio = 0.5

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion,
          max_length=MAX_LENGTH):
    # initialize the Encoder hidden state
    encoder_hidden = encoder.initHidden()
    # clear the gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    # input and target lengths
    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    loss = 0
    # feed the words of the sentence into the Encoder one by one
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        # keep the output of every step
        encoder_outputs[ei] = encoder_output[0, 0]
    # the first Decoder input is the SOS token
    decoder_input = torch.tensor([[SOS_token]], device=device)
    # its initial hidden state is the final Encoder hidden state
    decoder_hidden = encoder_hidden
    # decide whether to use teacher forcing for this sample
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    if use_teacher_forcing:
        # with teacher forcing, the next Decoder input
        # is the ground-truth target word
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
    else:
        # without teacher forcing, the next Decoder input
        # is the Decoder's own previous prediction
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += criterion(decoder_output, target_tensor[di])
            # stop once the translation is finished
            if decoder_input.item() == EOS_token:
                break
    # backpropagation
    loss.backward()
    # parameter update
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / target_length
import time
import math

# seconds -> "Xm Ys" string
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

# elapsed time and estimated remaining time
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0
    # stochastic gradient descent optimizers
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    # sample the training pairs
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    # NLLLoss() on top of LogSoftmax() is equivalent to CrossEntropyLoss()
    criterion = nn.NLLLoss()
    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]
        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss
        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))
        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
    showPlot(plot_losses)
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker

# plot the loss curve
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # put a tick mark every 0.2 on the y axis
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
# evaluation
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                     encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]
        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = encoder_hidden
        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()
        return decoded_words, decoder_attentions[:di + 1]
# translate n random sentences from the dataset and print the results
def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')
hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

evaluateRandomly(encoder1, attn_decoder1)

output_words, attentions = evaluate(
    encoder1, attn_decoder1, "je suis trop froid .")
plt.matshow(attentions.numpy())
# visualize the attention weights
def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)
    # input words on the x axis, output words on the y axis
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(
        encoder1, attn_decoder1, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions)

evaluateAndShowAttention("elle a cinq ans de moins que moi .")
evaluateAndShowAttention("elle est trop petit .")
evaluateAndShowAttention("je ne crains pas de mourir .")
evaluateAndShowAttention("c est un jeune directeur plein de talent .")
