当前位置:   article > 正文

Datawhale AI 夏令营 NLP方向 Task2笔记

Datawhale AI 夏令营 NLP方向 Task2笔记

该笔记比较baseline2与1相比进行了哪些改进,以及对后续优化的启发。

赛题回顾

利用神经网络实现机器翻译,先编码再解码,将英文翻译为中文。

task2任务内容 

 如何更好地实现文本到编码的转换从而提升模型性能? 

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. from torch.nn.utils import clip_grad_norm_
  6. from torchtext.data.metrics import bleu_score
  7. from torch.utils.data import Dataset, DataLoader
  8. from torchtext.data.utils import get_tokenizer
  9. from torchtext.vocab import build_vocab_from_iterator
  10. from typing import List, Tuple
  11. import jieba
  12. import random
  13. from torch.nn.utils.rnn import pad_sequence
  14. import sacrebleu
  15. import time
  16. import math

数据预处理

分词器tokenizer

中文采用jieba库分词,英文采用spacy库分词,结果更加准确

模型构建

与baseline1相比,新增attention类,引入注意力机制。 

该机制使神经网络模型能够关注编码器每一层的隐藏输出,都作为解码器的输入,而不只考虑编码器的最后一层隐藏输出。

  1. class Attention(nn.Module):
  2. def __init__(self, hid_dim):
  3. super().__init__()
  4. self.attn = nn.Linear(hid_dim * 2, hid_dim)
  5. self.v = nn.Linear(hid_dim, 1, bias=False)
  6. def forward(self, hidden, encoder_outputs):
  7. # hidden = [1, batch size, hid dim]
  8. # encoder_outputs = [batch size, src len, hid dim]
  9. batch_size = encoder_outputs.shape[0]
  10. src_len = encoder_outputs.shape[1]
  11. hidden = hidden.repeat(src_len, 1, 1).transpose(0, 1)
  12. # hidden = [batch size, src len, hid dim]
  13. energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
  14. # energy = [batch size, src len, hid dim]
  15. attention = self.v(energy).squeeze(2)
  16. # attention = [batch size, src len]
  17. return F.softmax(attention, dim=1)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/862363
推荐阅读
相关标签
  

闽ICP备14008679号