赞
踩
目录
Attention is all you need!
实验环境:
GPU:RTX4090
来源:作者报考院校课题组
查看你的环境是否有如下库,如果没有请自行安装。
- import math
- import torchtext
- import torch
- import torch.nn as nn
- from torch import Tensor
- from torch.nn.utils.rnn import pad_sequence
- from torch.utils.data import DataLoader
- from collections import Counter
- from torchtext.vocab import Vocab
- from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
- import io
- import time
- import pandas as pd
- import numpy as np
- import pickle
- import tqdm
- import sentencepiece as spm
- torch.manual_seed(0)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- # print(torch.cuda.get_device_name(0)) ## 如果你有GPU,请在你自己的电脑上尝试运行这一套代码
在本次博客,我们将使用从 JParaCrawl 下载的日英平行数据集!http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl。该数据集被称为为目前最大公开可用的英日平行语料库。它主要通过抓取网络并自动对齐平行句子创建而成。给大家瞅一眼这个数据集具体长什么样子:
- # 使用 pandas 库读取文件 'zh-ja/zh-ja.bicleaner05.txt',文件使用制表符(\t)作为分隔符。
- # 使用 Python 引擎来解析文件,并且不使用文件的第一行作为列名。
- df = pd.read_csv('zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)
- # 将数据框 df 的第3列(索引为2)的所有值转换为列表,并赋值给变量 trainen
- trainen = df[2].values.tolist()
- # 将数据框 df 的第4列(索引为3)的所有值转换为列表,并赋值给变量 trainja
- trainja = df[3].values.tolist()
- # trainen.pop(5972)
- # trainja.pop(5972)
在导入所有的日语和对应的英语句子后,这里删除了数据集中最后一条有缺失值的数据。总的来说,trainen 和 trainja 中的句子数是 5,973,071。然而,为了学习目的,通常建议对数据进行采样,并确保一切正常运行后,再一次性使用所有数据,以节省时间。
以下是数据集中包含的句子的示例。
- print(trainen[500])
- print(trainja[500])
打印结果如下所示:
我们也可以使用不同的平行数据集来配合本文,只需确保我们能够将数据处理成上面所示的两个字符串列表,分别包含日语和英语句子。
与英语或其他字母语言不同,日语句子中没有空格来分隔单词。我们可以使用 JParaCrawl 提供的分词器,这些分词器是使用 SentencePiece 为日语和英语创建的。你可以访问 JParaCrawl 网站下载它们,或者点击这里进行下载。
- # 使用 SentencePieceProcessor 加载英文的分词模型 'spm.en.nopretok.model',
- # 并将其实例化为 en_tokenizer
- en_tokenizer = spm.SentencePieceProcessor(model_file='spm.en.nopretok.model')
-
- # 使用 SentencePieceProcessor 加载日文的分词模型 'spm.ja.nopretok.model',
- # 并将其实例化为 ja_tokenizer
- ja_tokenizer = spm.SentencePieceProcessor(model_file='spm.ja.nopretok.model')
我们举个例子看一看~
en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type='str')
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')
使用分词器和原始句子,我们可以构建从 TorchText 导入的 Vocab 对象。这个过程可能需要几秒钟或几分钟,具体取决于数据集的大小和计算能力。不同的分词器也会影响构建词汇表所需的时间。
- # 定义函数 build_vocab,用于基于给定的句子和分词器构建词汇表
- def build_vocab(sentences, tokenizer):
- # 初始化一个 Counter 对象,用于统计词频
- counter = Counter()
- # 遍历每一个句子
- for sentence in sentences:
- # 使用分词器将句子编码为词汇单元,并将这些单元更新到 Counter 中
- counter.update(tokenizer.encode(sentence, out_type=str))
- # 返回一个 Vocab 对象,使用 Counter 生成词汇表,并添加特殊符号
- return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
- # 使用 build_vocab 函数构建日文词汇表,输入为训练集的日文句子和日文分词器
- ja_vocab = build_vocab(trainja, ja_tokenizer)
- # 使用 build_vocab 函数构建英文词汇表,输入为训练集的英文句子和英文分词器
- en_vocab = build_vocab(trainen, en_tokenizer)
在获得词汇表对象后,我们可以使用词汇表和分词器对象为训练数据构建张量。
- # 定义函数 data_process,用于将日文和英文句子转换为张量,并构建训练数据集
- def data_process(ja, en):
- # 初始化一个空列表,用于存储处理后的数据
- data = []
- # 使用 zip 函数将日文和英文句子配对,并遍历每一对句子
- for (raw_ja, raw_en) in zip(ja, en):
- # 对日文句子进行分词,并将每个词汇转换为词汇表中的索引,最终转换为长整型张量
- ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
- dtype=torch.long)
- # 对英文句子进行分词,并将每个词汇转换为词汇表中的索引,最终转换为长整型张量
- en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
- dtype=torch.long)
- # 将处理后的日文张量和英文张量作为一个元组,添加到数据列表中
- data.append((ja_tensor_, en_tensor_))
- # 返回处理后的数据列表
- return data
- # 调用 data_process 函数,处理训练集中的日文和英文句子,生成训练数据
- train_data = data_process(trainja, trainen)
DataLoader
对象我将批量大小设置为16,以防止“cuda内存不足”的错误,但这取决于多个因素,如您的机器内存容量、数据大小等等。因此,请根据您的需求自由调整批量大小(注意:PyTorch教程中使用Multi30k德英数据集将批量大小设置为128)
在这里笔者使用的是RTX4090显卡,批量设置为32一张卡可以承受,如果你有四张卡可以把批次设为128分到四张卡上去训练。
- # 定义批处理大小为8
- BATCH_SIZE = 8
- # 定义填充标记、句子开始标记和句子结束标记在词汇表中的索引
- PAD_IDX = ja_vocab['<pad>']
- BOS_IDX = ja_vocab['<bos>']
- EOS_IDX = ja_vocab['<eos>']
- # 定义生成批处理的函数
- def generate_batch(data_batch):
- # 初始化存储日语和英语句子的列表
- ja_batch, en_batch = [], []
- # 遍历每个数据对(日语句子和英语句子)
- for (ja_item, en_item) in data_batch:
- # 在每个日语句子的开始添加句子开始标记,末尾添加句子结束标记
- ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
- # 在每个英语句子的开始添加句子开始标记,末尾添加句子结束标记
- en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
- # 对日语句子进行填充,使其长度相同
- ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
- # 对英语句子进行填充,使其长度相同
- en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
- # 返回填充后的日语和英语句子批处理
- return ja_batch, en_batch
- # 定义训练数据的DataLoader
- # 设置批处理大小,打乱数据顺序,使用generate_batch函数进行批处理的生成
- train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
- shuffle=True, collate_fn=generate_batch)
这里做一点对transformer的简单介绍:
如下图所示,transformer模型同样是Encoder-Decoder结构,左边部分可以看作一个大的Encoder,同理右边可以看作一个大的Decoder,左边部分输入是一个句子(序列),经过Encoder部分编码得到的语义向量送入到右边Decoder部分。
先来说输入部分,这里的示例是“我有一只猫”。
第一步我们要对输入做如上处理,即
输入=token embedding+position embedding
这里你简单的记住就好,想知道为什么可以看看前言部分推荐的视频或者其他资料~
第二步将得到的单词表示向量矩阵 (如上图所示,每一行是一个单词的表示 x) 传入 Encoder 中,经过 6 个 Encoder block 后可以得到句子所有单词的编码信息矩阵 C,如下图。每一个 Encoder block 输出的矩阵维度与输入完全一致。
第三步:将 Encoder 输出的编码信息矩阵 C传递到 Decoder 中,Decoder 依次会根据当前翻译过的单词 1~ i 翻译下一个单词 i+1,如下图所示。在使用的过程中,翻译到单词 i+1 的时候需要通过 Mask (掩盖) 操作遮盖住 i+1 之后的单词。
上图 Decoder 接收了 Encoder 的编码矩阵 C,然后首先输入一个翻译开始符 "",预测第一个单词 "I";然后输入翻译开始符 "" 和单词 "I",预测单词 "have",以此类推。这是 Transformer 使用时候的大致流程。
这里只介绍了一个大致的流程,如果想要具体了解还请自行查阅资料。
接下来的几段代码和文本解释(用斜体标记)来自原始的PyTorch教程[https://pytorch.org/tutorials/beginner/translation_transformer.html]。
Transformer是一种Seq2Seq模型,最初在“Attention is all you need”论文中提出,用于解决机器翻译任务。Transformer模型由编码器和解码器模块组成,每个模块包含固定数量的层。
编码器通过一系列的多头注意力和前馈网络层处理输入序列。编码器的输出被称为“记忆”,将其与目标张量一起馈送给解码器。编码器和解码器通过教师强制(teacher forcing)技术进行端到端训练。
Transformer模型通过其强大的自注意力机制和前馈网络层,在处理序列到序列任务时表现出色,广泛应用于自然语言处理和机器翻译领域。
- # 定义一个序列到序列的 Transformer 模型
- class Seq2SeqTransformer(nn.Module):
- def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
- emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
- dim_feedforward: int = 512, dropout: float = 0.1):
- super(Seq2SeqTransformer, self).__init__()
- # 定义 Transformer 编码器层
- encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
- dim_feedforward=dim_feedforward)
- # 由多层编码器层组成的 Transformer 编码器
- self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
-
- # 定义 Transformer 解码器层
- decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
- dim_feedforward=dim_feedforward)
- # 由多层解码器层组成的 Transformer 解码器
- self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
-
- # 用于生成目标词汇的线性层
- self.generator = nn.Linear(emb_size, tgt_vocab_size)
-
- # 源语言和目标语言的词嵌入层
- self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
- self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
-
- # 位置编码层
- self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
-
- def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
- tgt_mask: Tensor, src_padding_mask: Tensor,
- tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
- # 将源语言嵌入并添加位置编码
- src_emb = self.positional_encoding(self.src_tok_emb(src))
- # 将目标语言嵌入并添加位置编码
- tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
-
- # 编码源语言
- memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
-
- # 解码目标语言
- outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
- tgt_padding_mask, memory_key_padding_mask)
- # 生成最终的输出
- return self.generator(outs)
-
- def encode(self, src: Tensor, src_mask: Tensor):
- # 仅编码源语言(用于推理)
- return self.transformer_encoder(self.positional_encoding(
- self.src_tok_emb(src)), src_mask)
-
- def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
- # 仅解码目标语言(用于推理)
- return self.transformer_decoder(self.positional_encoding(
- self.tgt_tok_emb(tgt)), memory,
- tgt_mask)
文本标记通过使用标记嵌入(token embeddings)来表示。在标记嵌入中加入位置编码(positional encoding),以引入单词顺序的概念。
- # 位置编码类,负责将位置信息添加到词嵌入中
- class PositionalEncoding(nn.Module):
- def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
- super(PositionalEncoding, self).__init__()
- # 计算位置编码的分母部分
- den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
- # 生成位置索引
- pos = torch.arange(0, maxlen).reshape(maxlen, 1)
- # 初始化位置编码矩阵
- pos_embedding = torch.zeros((maxlen, emb_size))
- # 为偶数索引位置计算 sin 值
- pos_embedding[:, 0::2] = torch.sin(pos * den)
- # 为奇数索引位置计算 cos 值
- pos_embedding[:, 1::2] = torch.cos(pos * den)
- # 添加一个新的维度以便与输入的词嵌入相加
- pos_embedding = pos_embedding.unsqueeze(-2)
- # 定义 dropout 层
- self.dropout = nn.Dropout(dropout)
- # 将位置编码矩阵注册为 buffer,避免其被训练
- self.register_buffer('pos_embedding', pos_embedding)
- def forward(self, token_embedding: Tensor):
- # 将位置编码添加到词嵌入并应用 dropout
- return self.dropout(token_embedding +
- self.pos_embedding[:token_embedding.size(0),:])
- # 词嵌入类,负责将输入的 tokens 转换为对应的词嵌入
- class TokenEmbedding(nn.Module):
- def __init__(self, vocab_size: int, emb_size: int):
- super(TokenEmbedding, self).__init__()
-
- # 定义词嵌入层
- self.embedding = nn.Embedding(vocab_size, emb_size)
- self.emb_size = emb_size
-
- def forward(self, tokens: Tensor):
- # 获取词嵌入并乘以嵌入维度的平方根
- return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
我们创建一个后续词掩码(subsequent word mask),用于阻止目标词关注其后续词。我们还创建了用于掩盖源和目标填充标记的掩码。
- # 生成后续位置掩码,用于解码器的自注意力机制,防止模型看到未来的信息
- def generate_square_subsequent_mask(sz):
- # 创建一个上三角矩阵,并将对角线及其上方的位置设为1,其余位置为0
- mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
- # 将掩码转换为 float 类型,将 0 的位置填充为负无穷(-inf),1 的位置填充为0.0
- mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
- return mask
-
- # 创建源和目标的掩码,用于处理填充标记和自注意力
- def create_mask(src, tgt):
- # 获取源和目标序列的长度
- src_seq_len = src.shape[0]
- tgt_seq_len = tgt.shape[0]
-
- # 为目标序列生成后续位置掩码
- tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
- # 为源序列生成全零掩码(源序列不需要后续位置掩码)
- src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
-
- # 创建源和目标的填充掩码,将填充值位置设为 True,其余位置为 False
- src_padding_mask = (src == PAD_IDX).transpose(0, 1)
- tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
-
- return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
下面的超参数可以视情况更改,如果你的设备足够出色,编码器和解码器的层数你可以变得更多,多头注意力的头数你也可以试试设置更大,看看这些更改对你的结果会产生什么不一样的影响。
- # 定义一些超参数
- SRC_VOCAB_SIZE = len(ja_vocab) # 源语言词汇表大小
- TGT_VOCAB_SIZE = len(en_vocab) # 目标语言词汇表大小
- EMB_SIZE = 512 # 词嵌入维度
- NHEAD = 8 # 多头注意力的头数
- FFN_HID_DIM = 512 # 前馈网络的维度
- BATCH_SIZE = 16 # 批量大小
- NUM_ENCODER_LAYERS = 3 # 编码器层数
- NUM_DECODER_LAYERS = 3 # 解码器层数
- NUM_EPOCHS = 16 # 训练的轮数
-
- # 初始化 Seq2SeqTransformer 模型
- transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
- EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
- FFN_HID_DIM)
-
- # 初始化模型参数
- for p in transformer.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
- # 将模型移到设备(例如 GPU)上
- transformer = transformer.to(device)
- # 定义损失函数,忽略填充标记的损失
- loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
- # 定义优化器,使用 Adam 优化算法
- optimizer = torch.optim.Adam(
- transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
- )
- # 训练一个 epoch 的函数
- def train_epoch(model, train_iter, optimizer):
- model.train() # 设置模型为训练模式
- losses = 0
- for idx, (src, tgt) in enumerate(train_iter):
- src = src.to(device) # 将源序列移到设备上
- tgt = tgt.to(device) # 将目标序列移到设备上
- tgt_input = tgt[:-1, :] # 目标输入序列,不包括最后一个标记
- # 创建掩码
- src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
- # 前向传播
- logits = model(src, tgt_input, src_mask, tgt_mask,
- src_padding_mask, tgt_padding_mask, src_padding_mask)
- optimizer.zero_grad() # 清除梯度
- tgt_out = tgt[1:, :] # 目标输出序列,不包括第一个标记
- loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)) # 计算损失
- loss.backward() # 反向传播
- optimizer.step() # 更新模型参数
- losses += loss.item() # 累加损失
- return losses / len(train_iter) # 返回平均损失
-
- # 评估模型的函数
- def evaluate(model, val_iter):
- model.eval() # 设置模型为评估模式
- losses = 0
- for idx, (src, tgt) in enumerate(val_iter):
- src = src.to(device) # 将源序列移到设备上
- tgt = tgt.to(device) # 将目标序列移到设备上
-
- tgt_input = tgt[:-1, :] # 目标输入序列,不包括最后一个标记
-
- # 创建掩码
- src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
-
- # 前向传播
- logits = model(src, tgt_input, src_mask, tgt_mask,
- src_padding_mask, tgt_padding_mask, src_padding_mask)
-
- tgt_out = tgt[1:, :] # 目标输出序列,不包括第一个标记
- loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)) # 计算损失
- losses += loss.item() # 累加损失
- return losses / len(val_iter) # 返回平均损失
最后,在准备好必要的类和函数之后,我们就可以开始训练我们的模型了。毫无疑问,训练完成所需的时间可能会因计算能力、模型参数以及数据集大小等因素而大不相同。
当我使用来自JParaCrawl的完整句子列表进行训练时,每种语言大约有590万句子,我使用单个RTX4090显卡,训练16轮使用了36分钟。
以下是代码示例:
- # tqdm是一个快速,可扩展的Python进度条库,通常用于长循环
- for epoch in tqdm.tqdm(range(1, NUM_EPOCHS + 1)):
- # 记录每个epoch开始的时间
- start_time = time.time()
- # 执行训练过程,并计算训练损失
- # train_epoch 是一个函数,用于进行一个epoch的训练,并返回训练损失
- # transformer 是模型,train_iter 是训练数据迭代器,optimizer 是优化器
- train_loss = train_epoch(transformer, train_iter, optimizer)
- # 记录每个epoch结束的时间
- end_time = time.time()
- # 打印当前epoch的编号,训练损失和耗时
- print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
- f"Epoch time = {(end_time - start_time):.3f}s"))
训练过程如下所示:
首先,我们创建函数来翻译新的句子,包括以下步骤:获取日语句子、分词、转换为张量、推理,然后将结果解码成英语句子。
- def greedy_decode(model, src, src_mask, max_len, start_symbol):
- # 将输入数据和掩码移到设备上(通常是GPU)
- src = src.to(device)
- src_mask = src_mask.to(device)
- # 编码阶段:通过模型的编码器得到编码后的表示
- memory = model.encode(src, src_mask)
- # 初始化目标序列,以start_symbol(通常是<BOS>)开始
- ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
-
- # 循环生成每个单词,最多生成max_len - 1个单词
- for i in range(max_len - 1):
- # 将memory移到设备上
- memory = memory.to(device)
-
- # 创建memory_mask,大小为(当前生成序列长度, memory长度),初始化为全零
- memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
-
- # 生成目标序列的掩码
- tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)
-
- # 通过模型的解码器得到输出
- out = model.decode(ys, memory, tgt_mask)
-
- # 调整输出的维度
- out = out.transpose(0, 1)
-
- # 通过生成器得到下一个单词的概率分布
- prob = model.generator(out[:, -1])
-
- # 选择概率最大的单词作为下一个单词
- _, next_word = torch.max(prob, dim=1)
- next_word = next_word.item()
-
- # 将下一个单词添加到目标序列中
- ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
-
- # 如果下一个单词是结束符(EOS),则停止生成
- if next_word == EOS_IDX:
- break
-
- # 返回生成的目标序列
- return ys
-
- def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
- # 设置模型为评估模式
- model.eval()
-
- # 将输入句子用开始符(BOS)和结束符(EOS)包裹起来,并将其转换为token序列
- tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)] + [EOS_IDX]
- num_tokens = len(tokens)
-
- # 将token序列转换为张量,并调整维度
- src = torch.LongTensor(tokens).reshape(num_tokens, 1)
-
- # 创建源序列的掩码,全零
- src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
-
- # 调用贪婪解码器生成目标序列
- tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
-
- # 将生成的token序列转换为对应的单词,并去掉<BOS>和<EOS>
- return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
然后,我们只需调用翻译函数并传递所需的参数。
- translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)
- trainen.pop(5)
- trainja.pop(5)
最后,在训练完成后,我们将首先使用Pickle保存词汇表对象(en_vocab和ja_vocab)。
- # 打开一个文件,以二进制写模式准备存储数据
- file = open('en_vocab.pkl', 'wb')
- # 使用pickle模块将en_vocab数据对象序列化并存储到文件中
- pickle.dump(en_vocab, file)
- # 关闭文件
- file.close()
-
- # 再次打开一个文件,以二进制写模式准备存储数据
- file = open('ja_vocab.pkl', 'wb')
- # 使用pickle模块将ja_vocab数据对象序列化并存储到文件中
- pickle.dump(ja_vocab, file)
- # 关闭文件
- file.close()
最后,我们还可以使用PyTorch的保存和加载函数将模型保存起来以便日后使用。一般来说,根据日后使用模型的目的,有两种保存模型的方式。第一种适用于仅推理的情况,我们可以加载模型并用它来进行日语到英语的翻译。
- # save model for inference
- torch.save(transformer.state_dict(), 'inference_model')
第二种方式也适用于推理,但同时也适用于当我们想要加载模型并恢复训练时的情况。
- # 保存模型和检查点,以便以后恢复训练
- torch.save({
- # 当前训练的epoch数
- 'epoch': NUM_EPOCHS,
- # 模型的状态字典,包含模型的所有参数
- 'model_state_dict': transformer.state_dict(),
- # 优化器的状态字典,包含优化器的所有参数
- 'optimizer_state_dict': optimizer.state_dict(),
- # 当前训练的损失值
- 'loss': train_loss,
- # 保存到文件'model_checkpoint.tar'中
- }, 'model_checkpoint.tar')
以下是笔者保存的文件及路径:
笔者能力有限,如果有不理解的地方请见谅。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。