当前位置:   article > 正文

“Datawhale AI夏令营第二期”-NLP方向 task3 笔记

“Datawhale AI夏令营第二期”-NLP方向 task3 笔记

班级群-NLP自然语言处理12—yujiarui-太原理工大学

Task3:基于Transformer解决机器翻译任务

一、Transformer 介绍

基于循环卷积神经网络的序列到序列建模方法是现存机器翻译任务中的经典方法。然而,它们在建模文本长程依赖方面都存在一定的局限性

  • 对于卷积神经网络来说,受限的上下文窗口在建模长文本方面天然地存在不足。如果要对长距离依赖进行描述,需要多层卷积操作,而且不同层之间信息传递也可能有损失,这些都限制了模型的能力。

  • 而对于循环神经网络来说,上下文的语义依赖是通过维护循环单元中的隐状态实现的。在编码过程中,每一个时间步的输入建模都涉及到对隐藏状态的修改。随着序列长度的增加,编码在隐藏状态中的序列早期的上下文信息被逐渐遗忘。尽管注意力机制的引入在一定程度上缓解了这个问题,但循环网络在编码效率方面仍存在很大的不足之处。由于编码端和解码端的每一个时间步的隐藏状态都依赖于前一时间步的计算结果,这就造成了在训练和推断阶段的低效。

  • 二、基于 task2 的 baseline 修改代码

    1. # 位置编码
    2. class PositionalEncoding(nn.Module):
    3. def __init__(self, d_model, dropout=0.1, max_len=5000):
    4. super(PositionalEncoding, self).__init__()
    5. self.dropout = nn.Dropout(p=dropout)
    6. pe = torch.zeros(max_len, d_model)
    7. position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    8. div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    9. pe[:, 0::2] = torch.sin(position * div_term)
    10. pe[:, 1::2] = torch.cos(position * div_term)
    11. pe = pe.unsqueeze(0).transpose(0, 1)
    12. self.register_buffer('pe', pe)
    13. def forward(self, x):
    14. x = x + self.pe[:x.size(0), :]
    15. return self.dropout(x)
    16. # Transformer
    17. class TransformerModel(nn.Module):
    18. def __init__(self, src_vocab, tgt_vocab, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
    19. super(TransformerModel, self).__init__()
    20. self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
    21. self.src_embedding = nn.Embedding(len(src_vocab), d_model)
    22. self.tgt_embedding = nn.Embedding(len(tgt_vocab), d_model)
    23. self.positional_encoding = PositionalEncoding(d_model, dropout)
    24. self.fc_out = nn.Linear(d_model, len(tgt_vocab))
    25. self.src_vocab = src_vocab
    26. self.tgt_vocab = tgt_vocab
    27. self.d_model = d_model
    28. def forward(self, src, tgt):
    29. # 调整src和tgt的维度
    30. src = src.transpose(0, 1) # (seq_len, batch_size)
    31. tgt = tgt.transpose(0, 1) # (seq_len, batch_size)
    32. src_mask = self.transformer.generate_square_subsequent_mask(src.size(0)).to(src.device)
    33. tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
    34. src_padding_mask = (src == self.src_vocab['<pad>']).transpose(0, 1)
    35. tgt_padding_mask = (tgt == self.tgt_vocab['<pad>']).transpose(0, 1)
    36. src_embedded = self.positional_encoding(self.src_embedding(src) * math.sqrt(self.d_model))
    37. tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt) * math.sqrt(self.d_model))
    38. output = self.transformer(src_embedded, tgt_embedded,
    39. src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, src_padding_mask)
    40. return self.fc_out(output).transpose(0, 1)

    三、其他上分技巧

  • 最简单的就是调参,将 epochs 调大一点,使用全部训练集,以及调整模型的参数,如head、layers等。如果数据量允许,增加模型的深度(更多的编码器/解码器层)或宽度(更大的隐藏层尺寸),这通常可以提高模型的表达能力和翻译质量,尤其是在处理复杂或专业内容时。

  • 加入术语词典,这是在此竞赛中比较有效的方法,加入术语词典的方法策略也有很多,如:

    • 模型生成的翻译输出中替换术语,这是最简单的方法

    • 整合到数据预处理流程,确保它们在翻译中保持一致

    • 在模型内部动态地调整术语的嵌入,这涉及到在模型中加入一个额外的层,该层负责查找术语词典中的术语,并为其生成专门的嵌入向量,然后将这些向量与常规的词嵌入结合使用

  • 认真做数据清洗,我们在 Task2 已经提到过当前训练集存在脏数据的问题,会影响我们的模型训练

  • 数据扩增

    • 回译(back-translation):将源语言文本先翻译成目标语言,再将目标语言文本翻译回源语言,生成的新文本作为额外的训练数据

    • 同义词替换:随机选择句子中的词,并用其同义词替换

    • 使用句法分析和语义解析技术重新表述句子,保持原意不变

    • 将文本翻译成多种语言后再翻译回原语言,以获得多样化翻译

  • 采用更精细的学习率调度策略(baseline我们使用的是固定学习率):

    • Noam Scheduler:结合了warmup阶段和衰减阶段

    • Step Decay:最简单的一种学习率衰减策略,每隔一定数量的epoch,学习率按固定比例衰减

    • Cosine Annealing:学习率随周期性变化,通常从初始值下降到接近零,然后再逐渐上升

  • 自己训练一个小的预训练模型,尽量选择 1B 以下小模型,对 GPU 资源要求比较高,仅仅使用魔搭平台可能就满足不了

  • 将训练集上训练出来的模型拿到开发集(dev dataset)上 finetune 可以提高测试集(test dataset)的得分,因为开发集与测试集的分布比较相近

  • 在开发集和测试集上训一个语言模型,用这个语言模型给训练集中的句子打分,选出一些高分句子

  • 集成学习:训练多个不同初始化或架构的模型,并使用集成方法(如投票或平均)来产生最终翻译。这可以减少单一模型的过拟合风险,提高翻译的稳定性。

task3 学习后:

N_EPOCHS=5时

运行训练时间:24m10s

输出翻译用时:8m10s

第三次 task3 得分:6.5828

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号