
PyTorch in Practice: Machine Translation with a Seq2seq Model (Model Prediction)

Introduction

As globalization deepens, the demand for translation keeps growing. Traditional human translation delivers high quality, but it is slow and expensive; machine translation offers a way to address this. English-to-Chinese machine translation is an important branch of the field, aiming to translate English text into Chinese automatically. This post builds on the Seq2seq English-to-Chinese translation task from Chapter 9 of 《PyTorch自然语言处理入门与实战》 and adds a model prediction (inference) module.

For a detailed walkthrough of the model training and validation modules, see the companion post PyTorch实战:基于Seq2seq模型处理机器翻译任务(模型训练及验证).

Data Preprocessing

Loading the dictionary objects en2id and zh2id

At prediction time, we need to load the dictionary objects en2id and zh2id that were saved during the training and validation stage.

The code is as follows:

import pickle

# Load the English and Chinese word-to-index dictionaries saved during training
with open("en2id.pkl", 'rb') as f:
    en2id = pickle.load(f)
with open("zh2id.pkl", 'rb') as f:
    zh2id = pickle.load(f)
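
For reference, these files are presumably written out at the end of the training/validation stage with something like the following. This is only a minimal sketch under that assumption; en2id and zh2id here stand for the word-to-index dictionaries built during training.

import pickle

# Hypothetical save step in the training/validation script:
# persist the word-to-index dictionaries so the prediction script can reload them.
with open("en2id.pkl", 'wb') as f:
    pickle.dump(en2id, f)
with open("zh2id.pkl", 'wb') as f:
    pickle.dump(zh2id, f)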

Text Tokenization

Before running prediction on an input sentence, the text first has to be tokenized. Reference code:

def extract_words(sentence):
    """
    Extract words from a given English sentence and strip trailing punctuation.

    Args:
        sentence (str): the English sentence to tokenize.

    Returns:
        List[str]: the cleaned list of words.
    """
    en_words = []
    for w in sentence.split(' '):  # split the English sentence on spaces
        w = w.replace('.', '').replace(',', '')  # strip punctuation attached to the word
        w = w.lower()  # normalize case
        if w:
            en_words.append(w)
    return en_words

# Test the function
sentence = 'I am Dave Gallo.'
print(extract_words(sentence))

Output: ['i', 'am', 'dave', 'gallo']
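
After tokenization, each word is mapped to its index through en2id. The complete code at the end of this post indexes en2id directly, which raises a KeyError for out-of-vocabulary words. Below is a hedged sketch of a safer conversion; the '<unk>' entry is an assumption and may not exist in the actual training vocabulary.

def words_to_ids(words, word2id, unk_token='<unk>'):
    """Map tokens to vocabulary ids, falling back to an assumed <unk> entry for OOV words."""
    unk_id = word2id.get(unk_token)  # None if the vocabulary has no such entry
    ids = []
    for w in words:
        if w in word2id:
            ids.append(word2id[w])
        elif unk_id is not None:
            ids.append(unk_id)
        else:
            raise KeyError(f"'{w}' is not in the vocabulary and no '{unk_token}' entry exists")
    return ids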

Loading the Trained Seq2Seq Model

The code is as follows:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # word embedding
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # the RNN outputs always come from the top hidden layer
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # Input shapes
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # the LSTM is unidirectional  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> (1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # Theoretical LSTM output shapes
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # in the decoder the sequence length seq len == 1
        # the decoder's LSTM is unidirectional, n directions == 1, so in practice
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # the encoder's final hidden states serve as the decoder's initial hidden states
        hidden, cell = self.encoder(src)

        # the decoder's first input should be the start-of-sequence token <sos>
        input = src[0, :]  # take row 0 (index 0) of src, all columns; row 0 holds the <sos> token
        pred = [0]  # the first predicted token is the start-of-sequence token
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # decoder inputs: the current token `input`, plus the encoder's hidden and cell states
            # decoder outputs: the prediction tensor, plus new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # use the GPU if available
# Instantiate the Seq2Seq model
source_word_count = 34737  # size of the English vocabulary     34737
target_word_count = 4015  # size of the Chinese vocabulary     4015
encode_dim = 256  # embedding dimension of the encoder
decode_dim = 256  # embedding dimension of the decoder
hidden_dim = 512  # hidden dimension of the LSTM
n_layers = 2  # number of LSTM layers
encode_dropout = 0.5  # dropout probability of the encoder
decode_dropout = 0.5  # dropout probability of the decoder
model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                decode_dropout, device).to(device)

# Load the trained model weights
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
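
Note that torch.load("best_model.pth") assumes the checkpoint can be deserialized onto the current default device. If the weights were saved on a GPU but prediction runs on a CPU-only machine, the call fails unless a map_location is given. A minimal variant of the loading step, assuming the same file name and the device variable defined above:

# Remap the saved tensors onto whatever device is actually available
state_dict = torch.load("best_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()  # disable dropout so greedy decoding is deterministic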


Complete Prediction Code

Note: this prediction code was adapted from the training and validation code, so it is not guaranteed to be fully correct; use it as a reference and modify it as needed.

import torch
import torch.nn as nn
import pickle


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, emb_dim)  # word embedding
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = (src len, batch size)
        embedded = self.dropout(self.embedding(src))
        # embedded = (src len, batch size, emb dim)
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = (src len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)
        # the RNN outputs always come from the top hidden layer
        return hidden, cell


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # Input shapes
        # input = (batch size)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # the LSTM is unidirectional  ==> n directions == 1
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        input = input.unsqueeze(0)  # (batch size)  --> (1, batch size)

        embedded = self.dropout(self.embedding(input))  # (1, batch size, emb dim)

        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # Theoretical LSTM output shapes
        # output = (seq len, batch size, hid dim * n directions)
        # hidden = (n layers * n directions, batch size, hid dim)
        # cell = (n layers * n directions, batch size, hid dim)

        # in the decoder the sequence length seq len == 1
        # the decoder's LSTM is unidirectional, n directions == 1, so in practice
        # output = (1, batch size, hid dim)
        # hidden = (n layers, batch size, hid dim)
        # cell = (n layers, batch size, hid dim)

        prediction = self.fc_out(output.squeeze(0))

        # prediction = (batch size, output dim)

        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, input_word_count, output_word_count, encode_dim, decode_dim, hidden_dim, n_layers,
                 encode_dropout, decode_dropout, device):
        """

        :param input_word_count:    英文词表的长度     34737
        :param output_word_count:   中文词表的长度     4015
        :param encode_dim:          编码器的词嵌入维度
        :param decode_dim:          解码器的词嵌入维度
        :param hidden_dim:          LSTM的隐藏层维度
        :param n_layers:            采用n层LSTM
        :param encode_dropout:      编码器的dropout概率
        :param decode_dropout:      编码器的dropout概率
        :param device:              cuda / cpu
        """
        super().__init__()
        self.encoder = Encoder(input_word_count, encode_dim, hidden_dim, n_layers, encode_dropout)
        self.decoder = Decoder(output_word_count, decode_dim, hidden_dim, n_layers, decode_dropout)
        self.device = device

    def forward(self, src):
        # src = (src len, batch size)

        # the encoder's final hidden states serve as the decoder's initial hidden states
        hidden, cell = self.encoder(src)

        # the decoder's first input should be the start-of-sequence token <sos>
        input = src[0, :]  # take row 0 (index 0) of src, all columns; row 0 holds the <sos> token
        pred = [0]  # the first predicted token is the start-of-sequence token
        top1 = 0
        while top1 != 1 and len(pred) < 100:
            # decoder inputs: the current token `input`, plus the encoder's hidden and cell states
            # decoder outputs: the prediction tensor, plus new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(dim=1)  # (batch size, )
            pred.append(top1.item())
            input = top1

        return pred


if __name__ == '__main__':
    sentence = 'I am Dave Gallo.'
    en_words = []

    for w in sentence.split(' '):  # split the English sentence on spaces
        # after splitting on spaces, some words still carry trailing punctuation "." or ","
        w = w.replace('.', '').replace(',', '')  # strip punctuation attached to the word
        w = w.lower()  # normalize case
        if w:
            en_words.append(w)

    print(en_words)

    with open("en2id.pkl", 'rb') as f:
        en2id = pickle.load(f)
    with open("zh2id.pkl", 'rb') as f:
        zh2id = pickle.load(f)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # use the GPU if available
    # Instantiate the Seq2Seq model
    source_word_count = 34737  # size of the English vocabulary     34737
    target_word_count = 4015  # size of the Chinese vocabulary     4015
    encode_dim = 256  # embedding dimension of the encoder
    decode_dim = 256  # embedding dimension of the decoder
    hidden_dim = 512  # hidden dimension of the LSTM
    n_layers = 2  # number of LSTM layers
    encode_dropout = 0.5  # dropout probability of the encoder
    decode_dropout = 0.5  # dropout probability of the decoder
    model = Seq2Seq(source_word_count, target_word_count, encode_dim, decode_dim, hidden_dim, n_layers, encode_dropout,
                    decode_dropout, device).to(device)

    model.load_state_dict(torch.load("best_model.pth"))
    model.eval()

    src = [0]  # 0 --> id of the start-of-sequence token
    for i in range(len(en_words)):
        src.append(en2id[en_words[i]])  # note: raises KeyError for out-of-vocabulary words
    src = src + [1]  # 1 --> id of the end-of-sequence token

    text_input = torch.LongTensor(src)
    text_input = text_input.unsqueeze(-1).to(device)

    text_output = model(text_input)
    print(text_output)
    id2zh = dict()
    for k, v in zh2id.items():
        id2zh[v] = k

    text_output = [id2zh[index] for index in text_output]
    text_output = " ".join(text_output)
    print(text_output)
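
The script above runs inference directly; wrapping the forward pass in torch.no_grad() avoids building the autograd graph and saves memory. Below is a minimal sketch of a reusable helper, assuming the model, en2id, zh2id, device, and the extract_words function from the tokenization section are in scope, and following this post's convention that ids 0 and 1 are the start and end tokens. Out-of-vocabulary words are simply skipped here, which differs from the script above.

def translate(sentence, model, en2id, id2zh, device):
    """Greedy-decode a single English sentence into a Chinese token sequence."""
    words = extract_words(sentence)
    src = [0] + [en2id[w] for w in words if w in en2id] + [1]  # <sos> ... <eos>, skipping OOV words
    text_input = torch.LongTensor(src).unsqueeze(-1).to(device)  # shape: (src len, batch size=1)
    with torch.no_grad():  # inference only: no gradients needed
        pred_ids = model(text_input)
    return " ".join(id2zh[i] for i in pred_ids)

# Example usage (id2zh is the inverse of zh2id, as built above):
# id2zh = {v: k for k, v in zh2id.items()}
# print(translate('I am Dave Gallo.', model, en2id, id2zh, device))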

Closing Remarks

  • Dear reader, thank you for taking the time to read our blog. We greatly value your feedback, so please feel free to leave a comment.
  • Your suggestions and opinions matter a lot to us; they help us understand your needs and provide higher-quality content and services.
  • Whether you enjoy our blog or have questions or suggestions, we look forward to hearing from you. Let's interact and improve together. Thank you for your support and participation!
  • We will keep writing and keep improving the quality of our posts to give you a better reading experience.
  • Thank you for reading!