This note reviews the improvements baseline2 makes over baseline1, and what they suggest for later optimization.
Machine translation is implemented with a neural network in an encode-then-decode fashion, translating English into Chinese.
How can the conversion from text to encodings be done better so as to improve model performance?
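The note jumps straight to the imports, so for orientation here is a minimal sketch of what the encoder side of such a model could look like. This is an assumption, not the notebook's actual code: the class name Encoder, the single-layer batch-first GRU, and all hyperparameters are illustrative, chosen only to match the tensor shapes expected by the Attention class shown later ([batch size, src len, hid dim] encoder outputs and a [1, batch size, hid dim] hidden state).

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Minimal GRU encoder sketch (illustrative, not the original notebook code)."""
    def __init__(self, input_dim, emb_dim, hid_dim, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [batch size, src len] of token indices
        embedded = self.dropout(self.embedding(src))
        # embedded = [batch size, src len, emb dim]
        outputs, hidden = self.rnn(embedded)
        # outputs = [batch size, src len, hid dim] -- one hidden state per source token
        # hidden  = [1, batch size, hid dim]       -- final state, used to start the decoder
        return outputs, hidden
```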
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torchtext.data.metrics import bleu_score
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from typing import List, Tuple
import jieba
import random
from torch.nn.utils.rnn import pad_sequence
import sacrebleu
import time
import math
```
Chinese text is segmented with the jieba library and English text with the spaCy library, which yields more accurate tokenization.
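As a quick illustration of the two tokenizers (the example sentences are made up, and the spaCy model name en_core_web_sm is an assumption that must be downloaded separately with `python -m spacy download en_core_web_sm`):

```python
import jieba
from torchtext.data.utils import get_tokenizer

# English: spaCy tokenizer wrapped by torchtext
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
print(en_tokenizer("Attention is all you need."))
# ['Attention', 'is', 'all', 'you', 'need', '.']

# Chinese: jieba word segmentation (jieba.cut returns a generator)
print(list(jieba.cut("注意力机制提升了翻译质量")))
# e.g. ['注意力', '机制', '提升', '了', '翻译', '质量']
```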
Compared with baseline1, a new Attention class is added, introducing an attention mechanism.
This mechanism lets the model attend to the encoder's hidden outputs at every source position and feed all of them into the decoder, instead of relying only on the encoder's final hidden output.
```python
class Attention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden = [1, batch size, hid dim]
        # encoder_outputs = [batch size, src len, hid dim]

        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        hidden = hidden.repeat(src_len, 1, 1).transpose(0, 1)
        # hidden = [batch size, src len, hid dim]

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy = [batch size, src len, hid dim]

        attention = self.v(energy).squeeze(2)
        # attention = [batch size, src len]

        return F.softmax(attention, dim=1)
```
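The forward pass returns a set of weights that sum to 1 over the source positions. How the decoder consumes them is not shown in this note; a common pattern, sketched here with made-up tensor sizes, is to take a weighted sum of the encoder outputs as a context vector:

```python
import torch

attn = Attention(hid_dim=512)
encoder_outputs = torch.randn(32, 10, 512)   # [batch size, src len, hid dim]
hidden = torch.randn(1, 32, 512)             # [1, batch size, hid dim]

a = attn(hidden, encoder_outputs)            # [batch size, src len], each row sums to 1
context = torch.bmm(a.unsqueeze(1), encoder_outputs)
# context = [batch size, 1, hid dim] -- weighted sum of encoder states, typically
# concatenated with the embedded decoder input at each decoding step
```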