当前位置:   article > 正文

Datawhale AI夏令营NLP自然语言处理-Task1学习笔记

Datawhale AI夏令营NLP自然语言处理-Task1学习笔记

        从零入门NLP竞赛 是 Datawhale 2024 年 AI 夏令营第二期的学习活动(“NLP”方向),基于讯飞开放平台“基于术语词典干预的机器翻译挑战赛”开展的实践学习——适合想 入门并实践 深度学习、解决NLP问题、机器翻译 的学习者参与

赛题解析

        本赛题是一个经典的NLP问题,赛事目标为——【通过术语词典构建机器翻译模型】

        在特定领域或行业中,由于机器翻译难以保证术语的一致性,导致翻译效果还不够理想。对于术语名词、人名地名等机器翻译不准确的结果,可以通过术语词典进行纠正,避免了混淆或歧义,最大限度提高翻译质量。

        自然语言处理(Natural Language Processing,NLP)是语言学与人工智能的分支,试图让计算机能够完成处理语言、理解语言和生成语言等任务。

        大致可以将NLP 任务分为四类:

  1. 序列标注:比如中文分词,词性标注,命名实体识别,语义角色标注等都可以归入这一类问题。这类任务的共同点是句子中每个单词要求模型根据上下文都要给出一个分类类别;

  2. 分类任务:比如我们常见的文本分类,情感计算等都可以归入这一类。这类任务特点是不管文章有多长,总体给出一个分类类别即可;

  3. 句子关系判断:比如问答推理,语义改写,自然语言推理等任务都是这个模式,它的特点是给定两个句子,模型判断出两个句子是否具备某种语义关系;

  4. 生成式任务:比如机器翻译,文本摘要,写诗造句,看图说话等都属于这一类。它的特点是输入文本内容后,需要自主生成另外一段文字。

  • 训练集(Training Set)

    1. 作用:训练集用于训练模型,使模型能够学习输入数据与输出结果之间的映射关系。模型会根据训练集中的样本调整其参数,以最小化预测误差。

    2. 目标:让模型在训练数据上尽可能地拟合好,学习到数据的内在规律。

  • 开发集/验证集(Development/Validation Set)

    1. 作用:开发集用于在模型训练过程中调整超参数、选择模型架构以及防止过拟合。它作为独立于训练集的数据,用于评估模型在未见过的数据上的表现。

    2. 目标:通过在开发集上的性能评估,选择最佳的模型配置,避免模型在训练集上过度拟合,确保模型的泛化能力。

  • 测试集(Test Set)

    1. 作用:测试集用于最终评估模型的性能,是在模型训练和调参完全完成后,用来衡量模型实际应用效果的一组数据。它是最接近真实世界数据的评估标准。

    2. 目标:提供一个公正、无偏见的性能估计,反映模型在未知数据上的泛化能力。 

Datawhale有一个NLP开源项目

  • 评估指标

        对于参赛队伍提交的测试集翻译结果文件,采用自动评价指标 BLEU-4 进行评价,具体工具使用 sacrebleu开源版本

        在机器翻译领域,BLEU(Bilingual Evaluation Understudy)是一种常用的自动评价指标,用于衡量计算机生成的翻译与一组参考译文之间的相似度。这个指标特别关注 n-grams(连续的n个词)的精确匹配,可以被认为是对翻译准确性和流利度的一种统计估计。计算BLEU分数时,首先会统计生成文本中n-grams的频率,然后将这些频率与参考文本中的n-grams进行比较。如果生成的翻译中包含的n-grams与参考译文中出现的相同,则认为是匹配的。最终的BLEU分数是一个介于0到1之间的数值,其中1表示与参考译文完美匹配,而0则表示完全没有匹配。

        BLEU-4 特别指的是在计算时考虑四元组(即连续四个词)的匹配情况。

        BLEU 评估指标的特点:

  • 优点:计算速度快、计算成本低、容易理解、与具体语言无关、和人类给的评估高度相关。

  • 缺点:不考虑语言表达(语法)上的准确性;测评精度会受常用词的干扰;短译句的测评精度有时会较高;没有考虑同义词或相似表达的情况,可能会导致合理翻译被否定。

        除了翻译之外,BLEU评分结合深度学习方法可应用于其他的语言生成问题,例如:语言生成、图片标题生成、文本摘要、语音识别。

Baseline代码

        1、速通baseline代码

  1. !pip install torchtext
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import Dataset, DataLoader
  6. from torchtext.data.utils import get_tokenizer
  7. from collections import Counter
  8. import random
  9. from torch.utils.data import Subset, DataLoader
  10. import time
  11. # 定义数据集类
  12. # 修改TranslationDataset类以处理术语
  13. class TranslationDataset(Dataset):
  14. def __init__(self, filename, terminology):
  15. self.data = []
  16. with open(filename, 'r', encoding='utf-8') as f:
  17. for line in f:
  18. en, zh = line.strip().split('\t')
  19. self.data.append((en, zh))
  20. self.terminology = terminology
  21. # 创建词汇表,注意这里需要确保术语词典中的词也被包含在词汇表中
  22. self.en_tokenizer = get_tokenizer('basic_english')
  23. self.zh_tokenizer = list # 使用字符级分词
  24. en_vocab = Counter(self.terminology.keys()) # 确保术语在词汇表中
  25. zh_vocab = Counter()
  26. for en, zh in self.data:
  27. en_vocab.update(self.en_tokenizer(en))
  28. zh_vocab.update(self.zh_tokenizer(zh))
  29. # 添加术语到词汇表
  30. self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)]
  31. self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)]
  32. self.en_word2idx = {word: idx for idx, word in enumerate(self.en_vocab)}
  33. self.zh_word2idx = {word: idx for idx, word in enumerate(self.zh_vocab)}
  34. def __len__(self):
  35. return len(self.data)
  36. def __getitem__(self, idx):
  37. en, zh = self.data[idx]
  38. en_tensor = torch.tensor([self.en_word2idx.get(word, self.en_word2idx['<sos>']) for word in self.en_tokenizer(en)] + [self.en_word2idx['<eos>']])
  39. zh_tensor = torch.tensor([self.zh_word2idx.get(word, self.zh_word2idx['<sos>']) for word in self.zh_tokenizer(zh)] + [self.zh_word2idx['<eos>']])
  40. return en_tensor, zh_tensor
  41. def collate_fn(batch):
  42. en_batch, zh_batch = [], []
  43. for en_item, zh_item in batch:
  44. en_batch.append(en_item)
  45. zh_batch.append(zh_item)
  46. # 对英文和中文序列分别进行填充
  47. en_batch = nn.utils.rnn.pad_sequence(en_batch, padding_value=0, batch_first=True)
  48. zh_batch = nn.utils.rnn.pad_sequence(zh_batch, padding_value=0, batch_first=True)
  49. return en_batch, zh_batch
  50. class Encoder(nn.Module):
  51. def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
  52. super().__init__()
  53. self.embedding = nn.Embedding(input_dim, emb_dim)
  54. self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
  55. self.dropout = nn.Dropout(dropout)
  56. def forward(self, src):
  57. # src shape: [batch_size, src_len]
  58. embedded = self.dropout(self.embedding(src))
  59. # embedded shape: [batch_size, src_len, emb_dim]
  60. outputs, hidden = self.rnn(embedded)
  61. # outputs shape: [batch_size, src_len, hid_dim]
  62. # hidden shape: [n_layers, batch_size, hid_dim]
  63. return outputs, hidden
  64. class Decoder(nn.Module):
  65. def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
  66. super().__init__()
  67. self.output_dim = output_dim
  68. self.embedding = nn.Embedding(output_dim, emb_dim)
  69. self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
  70. self.fc_out = nn.Linear(hid_dim, output_dim)
  71. self.dropout = nn.Dropout(dropout)
  72. def forward(self, input, hidden):
  73. # input shape: [batch_size, 1]
  74. # hidden shape: [n_layers, batch_size, hid_dim]
  75. embedded = self.dropout(self.embedding(input))
  76. # embedded shape: [batch_size, 1, emb_dim]
  77. output, hidden = self.rnn(embedded, hidden)
  78. # output shape: [batch_size, 1, hid_dim]
  79. # hidden shape: [n_layers, batch_size, hid_dim]
  80. prediction = self.fc_out(output.squeeze(1))
  81. # prediction shape: [batch_size, output_dim]
  82. return prediction, hidden
  83. class Seq2Seq(nn.Module):
  84. def __init__(self, encoder, decoder, device):
  85. super().__init__()
  86. self.encoder = encoder
  87. self.decoder = decoder
  88. self.device = device
  89. def forward(self, src, trg, teacher_forcing_ratio=0.5):
  90. # src shape: [batch_size, src_len]
  91. # trg shape: [batch_size, trg_len]
  92. batch_size = src.shape[0]
  93. trg_len = trg.shape[1]
  94. trg_vocab_size = self.decoder.output_dim
  95. outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
  96. _, hidden = self.encoder(src)
  97. input = trg[:, 0].unsqueeze(1) # Start token
  98. for t in range(1, trg_len):
  99. output, hidden = self.decoder(input, hidden)
  100. outputs[:, t, :] = output
  101. teacher_force = random.random() < teacher_forcing_ratio
  102. top1 = output.argmax(1)
  103. input = trg[:, t].unsqueeze(1) if teacher_force else top1.unsqueeze(1)
  104. return outputs
  105. # 新增术语词典加载部分
  106. def load_terminology_dictionary(dict_file):
  107. terminology = {}
  108. with open(dict_file, 'r', encoding='utf-8') as f:
  109. for line in f:
  110. en_term, ch_term = line.strip().split('\t')
  111. terminology[en_term] = ch_term
  112. return terminology
  113. def train(model, iterator, optimizer, criterion, clip):
  114. model.train()
  115. epoch_loss = 0
  116. for i, (src, trg) in enumerate(iterator):
  117. src, trg = src.to(device), trg.to(device)
  118. optimizer.zero_grad()
  119. output = model(src, trg)
  120. output_dim = output.shape[-1]
  121. output = output[:, 1:].contiguous().view(-1, output_dim)
  122. trg = trg[:, 1:].contiguous().view(-1)
  123. loss = criterion(output, trg)
  124. loss.backward()
  125. torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
  126. optimizer.step()
  127. epoch_loss += loss.item()
  128. return epoch_loss / len(iterator)
  129. # 主函数
  130. if __name__ == '__main__':
  131. start_time = time.time() # 开始计时
  132. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  133. #terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  134. terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  135. # 加载数据
  136. dataset = TranslationDataset('../dataset/train.txt',terminology = terminology)
  137. # 选择数据集的前N个样本进行训练
  138. N = 1000 #int(len(dataset) * 1) # 或者你可以设置为数据集大小的一定比例,如 int(len(dataset) * 0.1)
  139. subset_indices = list(range(N))
  140. subset_dataset = Subset(dataset, subset_indices)
  141. train_loader = DataLoader(subset_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
  142. # 定义模型参数
  143. INPUT_DIM = len(dataset.en_vocab)
  144. OUTPUT_DIM = len(dataset.zh_vocab)
  145. ENC_EMB_DIM = 256
  146. DEC_EMB_DIM = 256
  147. HID_DIM = 512
  148. N_LAYERS = 2
  149. ENC_DROPOUT = 0.5
  150. DEC_DROPOUT = 0.5
  151. # 初始化模型
  152. enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
  153. dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
  154. model = Seq2Seq(enc, dec, device).to(device)
  155. # 定义优化器和损失函数
  156. optimizer = optim.Adam(model.parameters())
  157. criterion = nn.CrossEntropyLoss(ignore_index=dataset.zh_word2idx['<pad>'])
  158. # 训练模型
  159. N_EPOCHS = 10
  160. CLIP = 1
  161. for epoch in range(N_EPOCHS):
  162. train_loss = train(model, train_loader, optimizer, criterion, CLIP)
  163. print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')
  164. # 在训练循环结束后保存模型
  165. torch.save(model.state_dict(), './translation_model_GRU.pth')
  166. end_time = time.time() # 结束计时
  167. # 计算并打印运行时间
  168. elapsed_time_minute = (end_time - start_time)/60
  169. print(f"Total running time: {elapsed_time_minute:.2f} minutes")

        2、在开发集上进行模型评价

  1. import torch
  2. from sacrebleu.metrics import BLEU
  3. from typing import List
  4. # 假设我们已经定义了TranslationDataset, Encoder, Decoder, Seq2Seq类
  5. def load_sentences(file_path: str) -> List[str]:
  6. with open(file_path, 'r', encoding='utf-8') as f:
  7. return [line.strip() for line in f]
  8. # 更新translate_sentence函数以考虑术语词典
  9. def translate_sentence(sentence: str, model: Seq2Seq, dataset: TranslationDataset, terminology, device: torch.device, max_length: int = 50):
  10. model.eval()
  11. tokens = dataset.en_tokenizer(sentence)
  12. tensor = torch.LongTensor([dataset.en_word2idx.get(token, dataset.en_word2idx['<sos>']) for token in tokens]).unsqueeze(0).to(device) # [1, seq_len]
  13. with torch.no_grad():
  14. _, hidden = model.encoder(tensor)
  15. translated_tokens = []
  16. input_token = torch.LongTensor([[dataset.zh_word2idx['<sos>']]]).to(device) # [1, 1]
  17. for _ in range(max_length):
  18. output, hidden = model.decoder(input_token, hidden)
  19. top_token = output.argmax(1)
  20. translated_token = dataset.zh_vocab[top_token.item()]
  21. if translated_token == '<eos>':
  22. break
  23. # 如果翻译的词在术语词典中,则使用术语词典中的词
  24. if translated_token in terminology.values():
  25. for en_term, ch_term in terminology.items():
  26. if translated_token == ch_term:
  27. translated_token = en_term
  28. break
  29. translated_tokens.append(translated_token)
  30. input_token = top_token.unsqueeze(1) # [1, 1]
  31. return ''.join(translated_tokens)
  32. def evaluate_bleu(model: Seq2Seq, dataset: TranslationDataset, src_file: str, ref_file: str, terminology,device: torch.device):
  33. model.eval()
  34. src_sentences = load_sentences(src_file)
  35. ref_sentences = load_sentences(ref_file)
  36. translated_sentences = []
  37. for src in src_sentences:
  38. translated = translate_sentence(src, model, dataset, terminology, device)
  39. translated_sentences.append(translated)
  40. bleu = BLEU()
  41. score = bleu.corpus_score(translated_sentences, [ref_sentences])
  42. return score
  43. # 主函数
  44. if __name__ == '__main__':
  45. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  46. # 加载术语词典
  47. terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  48. # 创建数据集实例时传递术语词典
  49. dataset = TranslationDataset('../dataset/train.txt', terminology)
  50. # 定义模型参数
  51. INPUT_DIM = len(dataset.en_vocab)
  52. OUTPUT_DIM = len(dataset.zh_vocab)
  53. ENC_EMB_DIM = 256
  54. DEC_EMB_DIM = 256
  55. HID_DIM = 512
  56. N_LAYERS = 2
  57. ENC_DROPOUT = 0.5
  58. DEC_DROPOUT = 0.5
  59. # 初始化模型
  60. enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
  61. dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
  62. model = Seq2Seq(enc, dec, device).to(device)
  63. # 加载训练好的模型
  64. model.load_state_dict(torch.load('./translation_model_GRU.pth'))
  65. # 评估BLEU分数
  66. bleu_score = evaluate_bleu(model, dataset, '../dataset/dev_en.txt', '../dataset/dev_zh.txt', terminology = terminology,device = device)
  67. print(f'BLEU-4 score: {bleu_score.score:.2f}')

        3、在测试集上进行推理

  1. def inference(model: Seq2Seq, dataset: TranslationDataset, src_file: str, save_dir:str, terminology, device: torch.device):
  2. model.eval()
  3. src_sentences = load_sentences(src_file)
  4. translated_sentences = []
  5. for src in src_sentences:
  6. translated = translate_sentence(src, model, dataset, terminology, device)
  7. #print(translated)
  8. translated_sentences.append(translated)
  9. #print(translated_sentences)
  10. # 将列表元素连接成一个字符串,每个元素后换行
  11. text = '\n'.join(translated_sentences)
  12. # 打开一个文件,如果不存在则创建,'w'表示写模式
  13. with open(save_dir, 'w', encoding='utf-8') as f:
  14. # 将字符串写入文件
  15. f.write(text)
  16. #return translated_sentences
  17. # 主函数
  18. if __name__ == '__main__':
  19. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  20. # 加载术语词典
  21. terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  22. # 加载数据集和模型
  23. dataset = TranslationDataset('../dataset/train.txt',terminology = terminology)
  24. # 定义模型参数
  25. INPUT_DIM = len(dataset.en_vocab)
  26. OUTPUT_DIM = len(dataset.zh_vocab)
  27. ENC_EMB_DIM = 256
  28. DEC_EMB_DIM = 256
  29. HID_DIM = 512
  30. N_LAYERS = 2
  31. ENC_DROPOUT = 0.5
  32. DEC_DROPOUT = 0.5
  33. # 初始化模型
  34. enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
  35. dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
  36. model = Seq2Seq(enc, dec, device).to(device)
  37. # 加载训练好的模型
  38. model.load_state_dict(torch.load('./translation_model_GRU.pth'))
  39. save_dir = '../dataset/submit.txt'
  40. inference(model, dataset, src_file="../dataset/test_en.txt", save_dir = save_dir, terminology = terminology, device = device)
  41. print(f"翻译完成!文件已保存到{save_dir}")

代码解析

        1、导入库和模块

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torch.utils.data import Dataset, DataLoader
  5. from torchtext.data.utils import get_tokenizer
  6. from collections import Counter
  7. import random
  8. from torch.utils.data import Subset, DataLoader
  9. import time
  10. from sacrebleu.metrics import BLEU
  11. from typing import List

重点使用的库:

  • torch:主要用于张量操作和神经网络模型的构建与训练
  • Subset:用于从数据集中选择子集
  • get_tokenizer:用于获取文本分词器的函数
  • Counter:用于统计词频

        2、数据集类TranslationDataset

  1. class TranslationDataset(Dataset):
  2. def __init__(self, filename, terminology):
  3. # 初始化数据集,包括加载数据、创建词汇表等
  4. pass
  5. def __len__(self):
  6. # 返回数据集大小
  7. pass
  8. def __getitem__(self, idx):
  9. # 按索引返回数据对的张量表示
  10. pass
  • 加载并处理包含源语言和目标语言句子的文本文件。
  • 创建词汇表,包括特殊标记、术语词典中的术语和最常见的单词。
  • 提供 __len__ 方法返回数据集大小,__getitem__ 方法按索引返回数据对的张量表示。

        3、数据加载和填充函数collate_fn

  1. def collate_fn(batch):
  2. # 对批次中的数据对进行填充
  3. pass
  • 用于处理一个批次的数据对,确保每个批次中的序列具有相同的长度。
  • 使用 nn.utils.rnn.pad_sequence 函数进行填充操作。

        4、编码器(Encoder)类

  1. class Encoder(nn.Module):
  2. def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
  3. # 初始化编码器,包括词嵌入、GRU层和Dropout
  4. pass
  5. def forward(self, src):
  6. # 编码器的前向传播过程
  7. pass
  • 将输入序列编码为上下文向量。
  • 使用 nn.Embedding 层将输入序列索引映射为词嵌入。
  • 使用 nn.GRU 层进行序列编码,并应用了 Dropout 进行正则化。

        5、解码器(Decoder)类

  1. class Decoder(nn.Module):
  2. def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
  3. # 初始化解码器,包括词嵌入、GRU层、线性层和Dropout
  4. pass
  5. def forward(self, input, hidden):
  6. # 解码器的前向传播过程
  7. pass
  • 根据编码器的上下文向量生成目标语言序列的预测。
  • 使用 nn.Embedding 层将目标语言序列索引映射为词嵌入。
  • 使用 nn.GRU 层进行序列解码,并应用了 Dropout 进行正则化。

        6、Seq2Seq模型类

  1. class Seq2Seq(nn.Module):
  2. def __init__(self, encoder, decoder, device):
  3. # 初始化Seq2Seq模型,包括编码器、解码器和设备
  4. pass
  5. def forward(self, src, trg, teacher_forcing_ratio=0.5):
  6. # Seq2Seq模型的前向传播过程
  7. pass
  • 整合编码器和解码器,执行完整的序列到序列转换。
  • 在前向传播中,接收源语言输入并预测目标语言输出序列。
  • 使用 teacher_forcing_ratio 控制是否使用教师强制(teacher forcing)。

        7、加载术语词典的函数load_terminology_dictionary

  1. def load_terminology_dictionary(dict_file):
  2. # 从文件加载术语词典,并返回一个字典
  3. pass
  • 从指定文件加载术语词典,以字典形式返回。

        8、训练函数train

  1. def train(model, iterator, optimizer, criterion, clip):
  2. # 模型训练函数,包括前向传播、损失计算、反向传播和参数更新
  3. pass
  • 对模型进行训练,计算并返回每个 epoch 的平均损失。

        9、翻译句子函数translate_sentence

  1. def translate_sentence(sentence, model, dataset, terminology, device, max_length=50):
  2. # 使用训练好的模型翻译单个句子
  3. pass
  • 将输入的源语言句子翻译为目标语言,考虑了术语词典的使用。

        10、BLEU评估函数evaluate_bleu

  1. def evaluate_bleu(model, dataset, src_file, ref_file, terminology, device):
  2. # 评估模型在指定数据集上的BLEU分数
  3. pass
  • 对模型在给定数据集上的翻译结果计算 BLEU 分数。

        11、推理函数inference

  1. def inference(model, dataset, src_file, save_dir, terminology, device):
  2. # 对指定文件中的所有句子进行翻译,并将结果保存到文件
  3. pass
  • 批量翻译源语言文件中的句子,并将结果保存到指定文件中。

        12、主函数

  1. if __name__ == '__main__':
  2. # 主函数,包括模型训练、加载、评估或推理等整体流程
  3. pass
  • 加载数据集和术语词典。
  • 定义模型及其参数。
  • 执行模型训练、加载预训练模型、评估BLEU分数或进行推理。
  • 主要控制整个程序的流程和调用各功能模块。

代码改进

  1. # 主函数
  2. if __name__ == '__main__':
  3. start_time = time.time() # 开始计时
  4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  5. #terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  6. terminology = load_terminology_dictionary('../dataset/en-zh.dic')
  7. # 加载数据
  8. dataset = TranslationDataset('../dataset/train.txt',terminology = terminology)
  9. # 选择数据集的前N个样本进行训练
  10. N = 2000 #int(len(dataset) * 1) # 或者你可以设置为数据集大小的一定比例,如 int(len(dataset) * 0.1)
  11. subset_indices = list(range(N))
  12. subset_dataset = Subset(dataset, subset_indices)
  13. train_loader = DataLoader(subset_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
  14. # 定义模型参数
  15. INPUT_DIM = len(dataset.en_vocab)
  16. OUTPUT_DIM = len(dataset.zh_vocab)
  17. ENC_EMB_DIM = 256
  18. DEC_EMB_DIM = 256
  19. HID_DIM = 512
  20. N_LAYERS = 2
  21. ENC_DROPOUT = 0.5
  22. DEC_DROPOUT = 0.5
  23. # 初始化模型
  24. enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
  25. dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
  26. model = Seq2Seq(enc, dec, device).to(device)
  27. # 定义优化器和损失函数
  28. optimizer = optim.Adam(model.parameters())
  29. criterion = nn.CrossEntropyLoss(ignore_index=dataset.zh_word2idx['<pad>'])
  30. # 训练模型
  31. N_EPOCHS = 50
  32. CLIP = 1
  33. for epoch in range(N_EPOCHS):
  34. train_loss = train(model, train_loader, optimizer, criterion, CLIP)
  35. print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')
  36. # 在训练循环结束后保存模型
  37. torch.save(model.state_dict(), './translation_model_GRU.pth')
  38. end_time = time.time() # 结束计时
  39. # 计算并打印运行时间
  40. elapsed_time_minute = (end_time - start_time)/60
  41. print(f"Total running time: {elapsed_time_minute:.2f} minutes")

        可以注意到修改了N和N_EPOCHS两个参数。

  • N:选择数据集的前N个样本进行训练。
  • N_EPOCHS:一次epoch是指将所有数据训练一遍的次数。

        两者作用是将数据集中前N个样本抓取训练了N_EPOCHS轮。


hahaha都看到这里了,要是觉得有用的话就辛苦动动小手点个赞吧!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/859797
推荐阅读
相关标签
  

闽ICP备14008679号