赞
踩
Transformer模型是一种深度学习模型,由Vaswani等人在2017年提出,主要用于自然语言处理(NLP)任务。它的核心思想是通过自注意力(Self-Attention)机制来捕捉输入数据之间的全局依赖关系,从而能够处理序列数据,如文本。
首先,让我们确保我们的系统中安装了以下软件包,如果发现缺少某些软件包,请确保安装它们。
- # 导入数学库
- import math
- # 导入torchtext库
- import torchtext
- # 导入torch库
- import torch
- # 导入torch.nn库
- import torch.nn as nn
- # 从torch中导入Tensor类
- from torch import Tensor
- # 从torch.nn.utils.rnn中导入pad_sequence函数
- from torch.nn.utils.rnn import pad_sequence
- # 从torch.utils.data中导入DataLoader类
- from torch.utils.data import DataLoader
- # 导入collections库中的Counter类
- from collections import Counter
- # 从torchtext.vocab中导入Vocab类
- from torchtext.vocab import Vocab
- # 从torch.nn中导入TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer类
- from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
- # 导入io库
- import io
- # 导入time库
- import time
- # 导入pandas库
- import pandas as pd
- # 导入numpy库
- import numpy as np
- # 导入pickle库
- import pickle
- # 导入tqdm库
- import tqdm
- # 导入sentencepiece库
- import sentencepiece as spm
- # 设置随机种子
- torch.manual_seed(0)
- # 判断是否有可用的GPU,如果有则使用GPU,否则使用CPU
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- # print(torch.cuda.get_device_name(0)) ## 如果你有GPU,请在你自己的电脑上尝试运行这一套代码
device #在编程中,可以使用 device 变量来指定使用哪个设备进行计算。例如,在 PyTorch 中,可以使用 device 变量将张量移动到 GPU 或 CPU 上进行计算。
运行结果:
在本教程中,我们将使用从JParaCrawl[http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl]下载的日英平行数据集,该数据集被描述为“由NTT创建的最大的可公开获取的英日平行语料库。它是通过在网上大规模爬取并自动对齐平行句子而创建的。”
- # 读取名为'zh-ja.bicleaner05.txt'的文件,使用制表符作为分隔符,指定引擎为Python,不包含表头信息
- df = pd.read_csv('./zh-ja/zh-ja.bicleaner05.txt', sep='\\t', engine='python', header=None)
- # 将第3列的数据转换为列表,赋值给trainen变量
- trainen = df[2].values.tolist()#[:10000]
- # 将第4列的数据转换为列表,赋值给trainja变量
- trainja = df[3].values.tolist()#[:10000]
- # trainen.pop(5972)
- # trainja.pop(5972)
导入所有日语及其英文对应文本后,我删除了数据集中的最后一条数据,因为它有一个缺失值。总共,在trainen和trainja中的句子数量为5,973,071条,然而,为了学习目的,通常建议对数据进行抽样,并确保一切按预期工作正常,然后再一次使用所有数据,以节省时间。
以下是数据集中包含的一句话的示例。
- print(trainen[500])
- print(trainja[500])
运行结果:
我们还可以使用不同的平行数据集来跟随本文,只需确保我们可以将数据处理成上面所示的两个字符串列表,其中包含日语和英语句子。
与英语或其他字母语言不同,日语句子中没有空格来分隔单词。我们可以使用由JParaCrawl提供的标记工具,该工具使用SentencePiece为日语和英语创建,您可以访问JParaCrawl网站下载它们,或点击这里。
- en_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.en.nopretok.model') #en_tokenizer用于处理英文文本
- ja_tokenizer = spm.SentencePieceProcessor(model_file='enja_spm_models/spm.ja.nopretok.model') #ja_tokenizer用于处理日文文本
- #这两个分词器都使用了预训练的模型文件,分别为enja_spm_models/spm.en.nopretok.model和enja_spm_models/spm.ja.nopretok.model
载入分词器之后,你可以通过执行以下代码来测试它们。
en_tokenizer.encode("All residents aged 20 to 59 years who live in Japan must enroll in public pension system.", out_type='str')
ja_tokenizer.encode("年金 日本に住んでいる20歳~60歳の全ての人は、公的年金制度に加入しなければなりません。", out_type='str')
使用标记器和原始句子,然后构建从TorchText导入的词汇对象。这个过程可能需要几秒钟或几分钟,这取决于我们数据集的大小和计算能力。不同的标记器也会影响构建词汇所需的时间,我尝试了几种其他的日语标记器,但SentencePiece 似乎对我来说运行得很好且足够快。
- # 定义一个函数,用于构建词汇表
- def build_vocab(sentences, tokenizer):
- # 创建一个计数器对象
- counter = Counter()
- # 遍历输入的句子列表
- for sentence in sentences:
- # 使用分词器对句子进行编码,并将结果更新到计数器中
- counter.update(tokenizer.encode(sentence, out_type=str))
- # 返回一个词汇表对象,其中包含了计数器中的词汇,以及特殊符号:'<unk>', '<pad>', '<bos>', '<eos>'
- return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
- # 使用训练集和分词器构建日语词汇表
- ja_vocab = build_vocab(trainja, ja_tokenizer)
- # 使用训练集和分词器构建英语词汇表
- en_vocab = build_vocab(trainen, en_tokenizer)
有了词汇对象之后,我们就能够利用词汇表和分词器对象来构建训练数据的张量。
- # 定义一个名为data_process的函数,接收两个参数:ja和en
- def data_process(ja, en):
- # 初始化一个空列表data
- data = []
- # 使用zip函数将ja和en中的元素一一对应地组合在一起,然后遍历这些组合
- for (raw_ja, raw_en) in zip(ja, en):
- # 对原始的日语文本进行分词,并将分词结果转换为词汇表中对应的索引值,然后将这些索引值转换为LongTensor类型的张量
- ja_tensor_ = torch.tensor([ja_vocab[token] for token in ja_tokenizer.encode(raw_ja.rstrip("\n"), out_type=str)],
- dtype=torch.long)
- # 对原始的英语文本进行分词,并将分词结果转换为词汇表中对应的索引值,然后将这些索引值转换为LongTensor类型的张量
- en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer.encode(raw_en.rstrip("\n"), out_type=str)],
- dtype=torch.long)
- # 将处理好的日语和英语张量作为元组添加到data列表中
- data.append((ja_tensor_, en_tensor_))
- # 返回处理后的数据列表
- return data
- # 调用data_process函数处理训练数据,并将结果赋值给train_data变量
- train_data = data_process(trainja, trainen)
这里,我将BATCH_SIZE设置为16,以防止“cuda内存不足”的问题,但这取决于诸如您的机器内存容量、数据大小等各种因素,因此根据您的需要随意更改批处理大小(注:PyTorch的教程使用Multi30k德英数据集将批处理大小设置为128)。
- # 设置批处理大小为8
- BATCH_SIZE = 8
- # 设置填充索引为日语词汇表中的<pad>对应的索引
- PAD_IDX = ja_vocab['<pad>']
- # 设置开始符号索引为日语词汇表中的<bos>对应的索引
- BOS_IDX = ja_vocab['<bos>']
- # 设置结束符号索引为日语词汇表中的<eos>对应的索引
- EOS_IDX = ja_vocab['<eos>']
- # 定义一个生成批处理数据的函数,输入为一个数据批次
- def generate_batch(data_batch):
- # 初始化两个空列表,分别用于存储日语和英语的数据
- ja_batch, en_batch = [], []
- # 遍历输入的数据批次中的每个元素(日语和英语的句子对)
- for (ja_item, en_item) in data_batch:
- # 将开始符号、日语句子、结束符号拼接起来,并添加到日语数据列表中
- ja_batch.append(torch.cat([torch.tensor([BOS_IDX]), ja_item, torch.tensor([EOS_IDX])], dim=0))
- # 将开始符号、英语句子、结束符号拼接起来,并添加到英语数据列表中
- en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
- # 对日语数据列表进行填充,使其长度相同,并用填充值替换短于最长句子的部分
- ja_batch = pad_sequence(ja_batch, padding_value=PAD_IDX)
- # 对英语数据列表进行填充,使其长度相同,并用填充值替换短于最长句子的部分
- en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
- # 返回填充后的日语和英语数据列表
- return ja_batch, en_batch
- # 使用DataLoader加载训练数据,设置批处理大小为8,打乱数据顺序,并使用generate_batch函数作为collate_fn参数
- train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
- shuffle=True, collate_fn=generate_batch)
下面几行代码和文本解释(用斜体)取自原始的PyTorch教程[https://pytorch.org/tutorials/beginner/translation_transformer.html]。我除了将BATCH_SIZE和单词de_vocab更改为ja_vocab外,没有做任何修改。
Transformer是一种Seq2Seq模型,介绍了“Attention is all you need”文档,用于解决机器翻译任务。Transformer模型包括一个编码器和一个解码器块,每个块包含固定数量的层。
编码器通过将输入序列传播到一系列多头注意力和前馈网络层来处理输入序列。编码器的输出称为内存,将其与目标张量一起馈送到解码器中。编码器和解码器使用教师强制技术进行端到端训练。
- from torch.nn import (TransformerEncoder, TransformerDecoder,
- TransformerEncoderLayer, TransformerDecoderLayer)
-
-
- class Seq2SeqTransformer(nn.Module):
- def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
- emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
- dim_feedforward:int = 512, dropout:float = 0.1):
- # 初始化Seq2SeqTransformer类,设置编码器和解码器层数、词嵌入大小、源语言词汇表大小、目标语言词汇表大小、前馈神经网络维度和dropout比例
- super(Seq2SeqTransformer, self).__init__()
- # 创建编码器层
- encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
- dim_feedforward=dim_feedforward)
- # 创建编码器
- self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
- # 创建解码器层
- decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
- dim_feedforward=dim_feedforward)
- # 创建解码器
- self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
-
- # 创建生成器,用于将解码器的输出映射到目标语言词汇表大小
- self.generator = nn.Linear(emb_size, tgt_vocab_size)
- # 创建源语言和目标语言的词嵌入层
- self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
- self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
- # 创建位置编码层
- self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)
-
- def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
- tgt_mask: Tensor, src_padding_mask: Tensor,
- tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
- # 前向传播函数,输入源语言序列、目标语言序列、源语言掩码、目标语言掩码、源语言填充掩码、目标语言填充掩码和记忆键填充掩码
- # 对源语言序列进行词嵌入和位置编码
- src_emb = self.positional_encoding(self.src_tok_emb(src))
- # 对目标语言序列进行词嵌入和位置编码
- tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
- # 使用编码器对源语言序列进行编码,得到记忆
- memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
- # 使用解码器对目标语言序列进行解码,得到输出
- outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
- tgt_padding_mask, memory_key_padding_mask)
- # 通过生成器将解码器的输出映射到目标语言词汇表大小,得到最终输出
- return self.generator(outs)
-
- def encode(self, src: Tensor, src_mask: Tensor):
- # 编码函数,输入源语言序列和源语言掩码,返回编码后的记忆
- return self.transformer_encoder(self.positional_encoding(
- self.src_tok_emb(src)), src_mask)
-
- def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
- # 解码函数,输入目标语言序列、记忆和目标语言掩码,返回解码后的输出
- return self.transformer_decoder(self.positional_encoding(
- self.tgt_tok_emb(tgt)), memory,
- tgt_mask)
文本标记通过使用令牌嵌入表示。为了引入单词顺序的概念,位置编码被添加到令牌嵌入中。
- class PositionalEncoding(nn.Module):
- def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
- # 初始化位置编码类,输入参数为词嵌入维度、dropout比例和最大序列长度
- super(PositionalEncoding, self).__init__()
- # 计算位置编码的分母部分
- den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
- # 生成位置序列
- pos = torch.arange(0, maxlen).reshape(maxlen, 1)
- # 初始化位置嵌入矩阵
- pos_embedding = torch.zeros((maxlen, emb_size))
- # 计算位置嵌入矩阵的偶数列
- pos_embedding[:, 0::2] = torch.sin(pos * den)
- # 计算位置嵌入矩阵的奇数列
- pos_embedding[:, 1::2] = torch.cos(pos * den)
- # 增加一个维度
- pos_embedding = pos_embedding.unsqueeze(-2)
-
- # 定义dropout层
- self.dropout = nn.Dropout(dropout)
- # 注册位置嵌入矩阵为buffer,不需要梯度更新
- self.register_buffer('pos_embedding', pos_embedding)
-
- def forward(self, token_embedding: Tensor):
- # 前向传播,将位置嵌入矩阵与词嵌入相加,然后应用dropout
- return self.dropout(token_embedding +
- self.pos_embedding[:token_embedding.size(0),:])
-
- class TokenEmbedding(nn.Module):
- def __init__(self, vocab_size: int, emb_size):
- # 初始化词嵌入类,输入参数为词汇表大小和词嵌入维度
- super(TokenEmbedding, self).__init__()
- # 定义词嵌入层
- self.embedding = nn.Embedding(vocab_size, emb_size)
- # 记录词嵌入维度
- self.emb_size = emb_size
- def forward(self, tokens: Tensor):
- # 前向传播,将输入的token转换为对应的词嵌入向量,并进行归一化处理
- return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
我们创建一个后续的单词掩码,防止目标单词关注其后续的单词。我们还创建掩码,用于掩盖源和目标的填充标记。
- # 定义一个函数,用于生成一个上三角矩阵,其中对角线以下的元素为负无穷,对角线以上的元素为0
- def generate_square_subsequent_mask(sz):
- # 创建一个全1的矩阵,然后使用torch.triu()函数将其转换为上三角矩阵
- mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
- # 将矩阵中的元素类型转换为float,并将0替换为负无穷,1替换为0
- mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
- return mask
- # 定义一个函数,用于创建源序列和目标序列的掩码
- def create_mask(src, tgt):
- # 获取源序列和目标序列的长度
- src_seq_len = src.shape[0]
- tgt_seq_len = tgt.shape[0]
- # 生成目标序列的掩码
- tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
- # 创建一个全0的矩阵,用于表示源序列的掩码
- src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
- # 获取源序列和目标序列的填充掩码
- src_padding_mask = (src == PAD_IDX).transpose(0, 1)
- tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
- return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
Define model parameters and instantiate model. 这里我们服务器实在是计算能力有限,按照以下配置可以训练但是效果应该是不行的。如果想要看到训练的效果请使用你自己的带GPU的电脑运行这一套代码。
当你使用自己的GPU的时候,NUM_ENCODER_LAYERS 和 NUM_DECODER_LAYERS 设置为3或者更高,NHEAD设置8,EMB_SIZE设置为512。
- # 设置源语言和目标语言的词汇表大小
- SRC_VOCAB_SIZE = len(ja_vocab)
- TGT_VOCAB_SIZE = len(en_vocab)
- # 设置模型参数
- EMB_SIZE = 512
- NHEAD = 8
- FFN_HID_DIM = 512
- BATCH_SIZE = 16
- NUM_ENCODER_LAYERS = 3
- NUM_DECODER_LAYERS = 3
- NUM_EPOCHS = 16
- # 创建Transformer模型
- transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
- EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
- FFN_HID_DIM)
- # 初始化模型参数
- for p in transformer.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
- # 将模型放到设备上(GPU或CPU)
- transformer = transformer.to(device)
- # 定义损失函数
- loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
- # 定义优化器
- optimizer = torch.optim.Adam(
- transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
- )
-
- # 训练一个epoch
- def train_epoch(model, train_iter, optimizer):
- model.train()
- losses = 0
- for idx, (src, tgt) in enumerate(train_iter):
- src = src.to(device)
- tgt = tgt.to(device)
-
- tgt_input = tgt[:-1, :]
-
- src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
-
- logits = model(src, tgt_input, src_mask, tgt_mask,
- src_padding_mask, tgt_padding_mask, src_padding_mask)
-
- optimizer.zero_grad()
-
- tgt_out = tgt[1:,:]
- loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
- loss.backward()
-
- optimizer.step()
- losses += loss.item()
- return losses / len(train_iter)
-
- # 评估模型
- def evaluate(model, val_iter):
- model.eval()
- losses = 0
- for idx, (src, tgt) in (enumerate(valid_iter)):
- src = src.to(device)
- tgt = tgt.to(device)
-
- tgt_input = tgt[:-1, :]
-
- src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
-
- logits = model(src, tgt_input, src_mask, tgt_mask,
- src_padding_mask, tgt_padding_mask, src_padding_mask)
- tgt_out = tgt[1:,:]
- loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
- losses += loss.item()
- return losses / len(val_iter)
终于,在准备好必要的课程和功能之后,我们准备好训练我们的模型了。毋庸置疑,但训练所需的时间可能会因许多因素而大不相同,比如计算能力、参数和数据集的大小。
我使用了来自JParaCrawl的完整句子列表来训练模型,该列表每种语言约有590万个句子,使用单个NVIDIA GeForce RTX 3070 GPU每个时代大约需要5个小时。
以下是代码:
- # 使用tqdm库显示训练进度条
- for epoch in tqdm.tqdm(range(1, NUM_EPOCHS+1)):
- # 记录当前时间,用于计算每个epoch的耗时
- start_time = time.time()
-
- # 调用train_epoch函数进行一轮训练,返回训练损失值
- train_loss = train_epoch(transformer, train_iter, optimizer)
- # 记录当前时间,用于计算每个epoch的耗时
- end_time = time.time()
- # 打印当前轮次、训练损失值和耗时信息
- print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
- f"Epoch time = {(end_time - start_time):.3f}s"))
尝试使用训练模型翻译一个日语句子,首先,我们创建函数来翻译一句新的句子,包括获取日语句子、分词、转换为张量、推理,然后将结果解码回一句句子,但这次是用英语。
- # 定义贪婪解码函数,输入模型、源序列、源序列掩码、最大长度和起始符号
- def greedy_decode(model, src, src_mask, max_len, start_symbol):
- # 将源序列和源序列掩码放到设备上
- src = src.to(device)
- src_mask = src_mask.to(device)
- # 对源序列进行编码,得到内存表示
- memory = model.encode(src, src_mask)
- # 初始化目标序列,将起始符号作为第一个元素
- ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
- # 循环生成目标序列
- for i in range(max_len-1):
- # 将内存表示放到设备上
- memory = memory.to(device)
- # 创建目标序列掩码
- memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
- tgt_mask = (generate_square_subsequent_mask(ys.size(0))
- .type(torch.bool)).to(device)
- # 对目标序列进行解码,得到输出
- out = model.decode(ys, memory, tgt_mask)
- # 转置输出
- out = out.transpose(0, 1)
- # 通过生成器计算概率分布
- prob = model.generator(out[:, -1])
- # 选择概率最大的单词作为下一个单词
- _, next_word = torch.max(prob, dim = 1)
- next_word = next_word.item()
- # 将下一个单词添加到目标序列中
- ys = torch.cat([ys,
- torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
- # 如果遇到结束符,则停止生成
- if next_word == EOS_IDX:
- break
- return ys
- # 定义翻译函数,输入模型、源序列、源词汇表、目标词汇表和源分词器
- def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
- # 对源序列进行分词,并添加起始符和结束符
- model.eval()
- tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer.encode(src, out_type=str)]+ [EOS_IDX]
- # 将分词后的序列转换为张量,并创建源序列掩码
- num_tokens = len(tokens)
- src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
- src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
- # 使用贪婪解码生成目标序列
- tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
- # 将目标序列转换为字符串,并去掉起始符和结束符
- return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
例1:
translate(transformer, "HSコード 8515 はんだ付け用、ろう付け用又は溶接用の機器(電気式(電気加熱ガス式を含む。)", ja_vocab, en_vocab, ja_tokenizer)
运行结果:
例2:
trainen.pop(5)
运行结果:
例3:
trainja.pop(5)
运行结果:
最后, 培训完成后, 我们将首先使用Pickle保存Vocab对象(en_vocab和ja_vocab)
- # 导入pickle模块
- import pickle
- # 以二进制写入模式打开一个名为'en_vocab.pkl'的文件
- file = open('en_vocab.pkl', 'wb')
- # 将en_vocab变量的数据存储到文件中
- pickle.dump(en_vocab, file)
- # 关闭文件
- file.close()
- # 以二进制写入模式打开一个名为'ja_vocab.pkl'的文件
- file = open('ja_vocab.pkl', 'wb')
- # 将ja_vocab变量的数据存储到文件中
- pickle.dump(ja_vocab, file)
- # 关闭文件
- file.close()
最后,我们还可以使用PyTorch的保存和加载函数来保存模型以备以后使用。通常,根据我们以后想要使用模型的方式,有两种保存模型的方法。第一种是仅用于推断,可以稍后加载模型并将其用于从日语翻译成英语。
- # save model for inference
- torch.save(transformer.state_dict(), 'inference_model')
第二个参数也用于推断,但是还用于当我们想要稍后加载模型并恢复训练时。
- # save model + checkpoint to resume training later
- torch.save({
- 'epoch': NUM_EPOCHS,
- 'model_state_dict': transformer.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': train_loss,
- }, 'model_checkpoint.tar')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。