当前位置:   article > 正文

Pytorch与深度学习 —— 11. 使用 LSTM 做姓名分类预测之 RNN提高篇_姓氏分类lstm

姓氏分类lstm

在前面的章节里,已经给大家介绍了什么是RNN网络的进阶型——LSTM网络的基本知识,如果不清楚的同学请移步到《Pytorch与深度学习 —— 10. 什么是长短期记忆网络》。在《Pytorch与深度学习 —— 9. 使用 RNNCell 做文字序列的转化之 RNN 入门篇》 这篇文章里,我提前做了一些简单的铺垫,例如独热向量等基础知识后,现在我们就正式开始回答在介绍RNN网络模型一开始便提到的姓名分类问题。

回顾一下问题

我们现在有这样的一组数据集,它是按照拉丁文字进行拼写的来自不同国家的常见姓氏,如果打开这个数据集,可以发现它大概是这样

InputOutput
AbbasEnglish
AddamsEnglish
BrooksEnglish
MuirchertachIrish
NeilIrish
HaKorean

数据集我已经放在了CSDN的下载里,如果有需要的同学也可以自己去下载

在我们这个应用中,我们要考虑的是,当输入一个新的姓名后,比如 ‘Abbas’ 后,我们的程序能否判断出它是一个英语姓氏。

读取数据

回顾问题后,现在我们要来做一个读取数据的简单程序,把在文本里的姓氏,按照所在 { 语言 : [姓氏] } 这种字典-列表的形式导入到程序里。

# ASCII codes
all_letters = string.ascii_letters + " .,;'"

def find_files(path): return glob.glob(path)

def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

def read_lines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicode_to_ascii(line) for line in lines]

def load_data(path):
    for filename in find_files(path):
        category = os.path.splitext(os.path.basename(filename))[0]
        language_list.append(category)
        lines = read_lines(filename)
        names_dictionary[category] = lines
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

我们主要依靠的是这几段代码,它们的作用就是从文本中依次读取每一个名字,然后把名字由UTF-8 转码成ASCII,然后以存储在前面我提到过的 {语言: [姓氏]} 这样的字典-列表结构中。

对文本进行编码

为了让程序能够理解数据集,我们需要对这些字符串数据进行一定程度的编码。One-Hot-Vector 我在前面的文章里已经解释过了,所以在这里不做过多的重复。

这里只做一些必要的补充性介绍。

使用 One-Hot-Vector 编码姓名

我们已经通过如下的代码,创建出了一个新的用于编码的字符序列,这个序列包括一些特定的符号(在西班牙语、葡萄牙语等传统拉丁语族国家才有的重音符号)。

all_letters = string.ascii_letters + " .,;’"

比方说我们要编码一个名叫 ‘abc’ 的姓名,那么每一个字符对应一个长度为57的One-Hot向量。然后按照顺序进行输出后,结果应该是

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

由 ‘abc’ 所转化成的用于网络输入的张量维度就是 (3, 1, 57),也就是我在前面章节介绍过的对于NLP方向来说,pytorch接受的数据默认维度即:

( L , N , H i n ) ≈ ( S e q u e n c e , B a t c h , F e a t u r e s ) (L, N, H_{in}) \approx (Sequence, Batch, Features) (L,N,Hin)(Sequence,Batch,Features)

Batch

一个Batch对应一个单词,比如上面提到的 ‘abc’;

Sequence

即这个单词包含有多少个字符,对于单词 ‘abc’ 来说,它包含3个字符;

Features

即每一个字符所对应的 One-Hot 编码。

需要注意的地方

这里存在一个新手很容易犯的问题,就是对于我们的数据来说,有单词 ‘Muirchertach’ 长度为12, 也有单词像 ‘Ha’ 这样只有长度为2的,对于可执行并行计算的神经网络来说,如果数据长度不统一,那么不仅导致网络处理效率会降低,而且我们实际处理起来也特别麻烦。所以我们需要对这些数据做一个类似 padding 的操作。

对数据进行填充

这个概念在很多方面都有提到或者广泛使用,如果是第一次学习计算机或者数据分析的朋友,填充简而言之就是下面这个列表表示的意思。

原始列表填充后至8位长
[‘a’, ‘b’, ‘c’][‘a’, ‘b’, ‘c’, 0, 0, 0, 0, 0]
[‘h’, ‘e’, ‘l’, ‘l’, ‘o’][‘h’, ‘e’, ‘l’, ‘l’, ‘o’, 0, 0, 0]

理解这些基本概念后,我们就可以使用代码实现这个过程。

def line_to_one_hot_tensor(line, max_padding=0):
    """
    Turn a line into a one-hot based tensor (character, one-hot-vector)
    """
    if max_padding >= len(line):
        tensor = torch.zeros(max_padding, len(all_letters))
    else:
        tensor = torch.zeros(len(line), len(all_letters))

    for idx, letter in enumerate(line):
        tensor[idx][_letter_to_index(letter)] = 1
    return tensor
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

这个代码片段会把比如 ‘abc’ 这样的单词转换为 (Sequence, Features) 这样结构的独热向量表示的张量。之所以没有直接转换为 (L, N, H)这样的结构,是因为我们在读取数据的时候可能一次性要读取很多个不同的单词,所以得到的单词组,比如 [‘Abbas’, ‘Addams’, …] 这样的数组,就可以通过下面这段代码,再转换成 (L、N、H)的结构了。

from torch.nn.utils.rnn import pad_sequence

def concatenate_tensors(tensor_list):
    return pad_sequence(tensor_list)

def to_one_hot_based_tensor(surnames: list, padding=20):
    tensors = []
    for name in surnames:
        tensor = line_to_one_hot_tensor(name, padding)
        tensors.append(tensor)

    return concatenate_tensors(tensors)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

使用序列编码语言

我们提到过,通过神经元网络输出的结果,其实是个概率。比如通过网络输入的Features是这样的一组 One-Hot 向量

I n p u t → = { ( 1 , 0 , 0 , 0 ) ( 1 , 0 , 0 , 0 ) ( 0 , 1 , 0 , 0 ) ( 0 , 0 , 1 , 0 ) \overrightarrow{Input} = \left\{

(1,0,0,0)(1,0,0,0)(0,1,0,0)(0,0,1,0)
\right. Input =(1,0,0,0)(1,0,0,0)(0,1,0,0)(0,0,1,0)

经过我们的网络处理后,输出的 Output 是对应的每一个标签的可能性:

O u t p u t → = { l a b e l . 1 0.15 l a b e l . 2 0.35 l a b e l . 3 0.1 l a b e l . 4 0.4 \overrightarrow{Output} = \left\{

label.10.15label.20.35label.30.1label.40.4
\right. Output =label.1label.2label.3label.40.150.350.10.4

然后经过比如交叉熵进行比对的时候,我们告诉这个网络输出的值其实应该是 label 2, 你给出的 label 4 是错误的,所以网络会根据我们告诉它的情况,执行反馈计算的时候调整 label 2 和 label 4的权重。

因此,对于我们这个例子来说,我们就需要把 [‘English’, ‘Irish’, ‘French’, …] 这一类字符标签,转化成 [0, 1, 2, 3, 4, 5, …] 这样的形式,所以其实还是挺简单的。


def line_to_index(line: str, data_list: list):
    """
    Turn a line into an index based from dataset
    """
    return data_list.index(line)

def to_simple_tensor(languages):
    indices = []
    for lang in languages:
        index = line_to_index(lang, languages)
        indices.append(index)

    return torch.tensor(indices)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

所以我们用到了这样两段极为简单的代码,帮助我们转化标签。

数据加载

数据加载这里,由于我们使用的是自己的数据,所以没法直接用 Pytorch 提供的 DataLoader,但是我们可以重载名为Dataset的类。

from torch.utils.data import Dataset

class MyNameDataset(Dataset):

    def __init__(self, dict_data: dict):
        self.x_data = []
        self.y_data = []
        self.languages = []

        for lang, names in dict_data.items():
            for name in names:
                self.x_data.append(name)
                self.y_data.append(lang)

            self.languages.append(lang)

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, item):
        return self.x_data[item], self.y_data[item]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

我们把 {‘语言’: [姓名]} 这个字典-列表类型输入这个重载类后,再加载到DataLoader里,就可以根据需要输出我们想要的 {标签:姓名} 姓名对了。

比如我们通过DataLoader,让它一次性抓取10条数据,输出的结果就是这样的

[(‘Durdin’, ‘Guliev’, ‘Palmer’, ‘Gerhard’, ‘Timpe’, ‘Jelvakov’, ‘Seighin’, ‘Neverov’, ‘Babayants’, ‘Robishaw’), (‘Russian’, ‘Russian’, ‘English’, ‘German’, ‘Czech’, ‘Russian’, ‘Irish’, ‘Russian’, ‘Russian’, ‘English’)]

编写LSTM网络模型

我们的这个模型比较简单,用到了一层LSTM作为主要的数据处理,以及一层线性层做最终的输出。

class LSTMModel(torch.nn.Module):

    def __init__(self, input_size, hidden_size, output_size, batch_size, sequence_size, num_layers=1):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.output_size = output_size

        self.batch_size = batch_size

        # lstm layer
        self.cell = torch.nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=num_layers)

        # linear layer for output
        self.linear = torch.nn.Linear(sequence_size * hidden_size, self.output_size)

    def forward(self, input_x):
        """
        forward computation

        @param input_x, tensor of shape (L, N, H_in)
        @return tensor of shape (N, H_out)
        """

        # get dimension from input_x
        _, batch, features = input_x.size()

        # hidden, tensor of shape (D * num_layers, N, H_hidden)
        hidden = self.init_zeros(batch)

        # cell, tensor of shape (D * num_layers, N, H_hidden)
        cell = self.init_zeros(batch)

        # output tensor (L, N, D * H_hidden)
        output, _ = self.cell(input_x, (hidden, cell))

        # convert the shape of output to (N, L * H_hidden)
        hidden = convert_hidden_shape(output, batch)

        # (N, L * H_hidden) to (N, H_out)
        output = self.linear(hidden)

        return output

    def init_zeros(self, batch_size=0, hidden_size=0):
        if batch_size == 0:
            batch_size = self.batch_size

        if hidden_size == 0:
            hidden_size = self.hidden_size

        return torch.zeros(self.num_layers, batch_size, hidden_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

这里需要注意的是经过LSTM计算后的网络,输出的Output维度是

( L , N , D ∗ H o u t ) (L, N, D * H_{out}) (L,N,DHout)

由于我们使用的是单向,所以D=1,最终输出的维度是

( L , N , H o u t ) (L, N, H_{out}) (L,N,Hout)

但是线性层能接受的输入维度是

( N , H i n ) (N, H_{in}) (N,Hin)

这意味着我们要把LSTM网络输出的结构转化成线性层可接受的维度, 即

( L , N , H o u t ) → ( N , H i n ) = ( N , L ∗ H o u t ) (L, N, H_{out}) \rightarrow (N, H_{in}) = ( N, L * H_{out}) (L,N,Hout)(N,Hin)=(N,LHout)

这里我提供一个比较笨的转化方法,你可以在学会LSTM之后对这部分进行修改。

def convert_hidden_shape(hidden, batch_size):
    tensor_list = []

    for i in range(batch_size):
        ts = hidden[:, i, :].reshape(1, -1)
        tensor_list.append(ts)

    ts = torch.cat(tensor_list)
    return ts
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

另外就是需要注意下,维度转换的时候,一定要注意保证数据的正确性和完整性,否则会影响最终的输出。

把上面的内容拼接起来

这部分就是例行公事了,创建网络对象、选择合适的损失函数、合适的优化函数,然后制定训练和测试过程。

主要过程

    # define a model
    model = LSTMModel(
        input_size=INPUT_SIZE,
        hidden_size=HIDDEN_SIZE,
        output_size=OUTPUT_SIZE,
        sequence_size=SEQUENCE_SIZE,
        batch_size=BATCH_SIZE)

    # loss function
    criterion = torch.nn.CrossEntropyLoss()

    # majorized function
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training and testing process
    for epoch in range(10):

        # load dataset
        languages, language_idx, train_loader, test_loader = load_dataset()

        # training
        train(epoch, model, optimizer, criterion)

        # testing        
        test(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

这里稍微玩了点小花招,我们在每一次训练网络的过程中,都会重新随机的加载一次数据,这样保证每次训练集和测试集都有所不同,能够更好的评估和修正模型的准确度。

然后分别就是训练过程和测试过程的代码了

训练过程

def train(epoch, model, optimizer, criterion):
    running_loss = 0

    for idx, data in enumerate(train_loader, 0):

        # convert data
        input_x, label_y = convert_data(data)

        # clear the gradients
        optimizer.zero_grad()

        # forward computation
        predicate_y = model(input_x)

        # loss computation
        loss = criterion(predicate_y, label_y)

        # backward propagation
        loss.backward()

        # update network parameters
        optimizer.step()

        # print loss
        running_loss += loss.item()
        if idx % 100 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch, idx, running_loss / 100))
            running_loss = 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

测试过程

def test(model):
    correct = 0
    total = 0

    with torch.no_grad():
        for idx, data in enumerate(test_loader, 0):
            # convert data
            input_x, label_y = convert_data(data)

            # predicate
            predicate_y = model(input_x)

            # check output
            _, predicated = torch.max(predicate_y.data, dim=1)
            total += label_y.cpu().size(0)
            correct += (predicated == label_y).sum().item()

    print("Accuracy on test set: %d %%" % (100 * correct / total))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

实验输出结果

如果程序一切正常,输出的结果就是这样的


[0,     0] loss: 0.029
[0,   100] loss: 2.041
[0,   200] loss: 1.928
[0,   300] loss: 1.969
[0,   400] loss: 1.946
[0,   500] loss: 1.885
[0,   600] loss: 1.895
[0,   700] loss: 1.858
[0,   800] loss: 1.898
[0,   900] loss: 1.960
[0,  1000] loss: 1.897
[0,  1100] loss: 1.880
[0,  1200] loss: 1.911
[0,  1300] loss: 1.816
[0,  1400] loss: 1.925
[0,  1500] loss: 1.913
Accuracy on test set: 97 %
[1,     0] loss: 0.015
[1,   100] loss: 1.933
[1,   200] loss: 1.880
[1,   300] loss: 1.795
[1,   400] loss: 1.908
[1,   500] loss: 1.881
[1,   600] loss: 1.821
[1,   700] loss: 1.874
[1,   800] loss: 1.845
[1,   900] loss: 1.880
[1,  1000] loss: 1.806
[1,  1100] loss: 1.794
[1,  1200] loss: 1.820
[1,  1300] loss: 1.951
[1,  1400] loss: 1.852
[1,  1500] loss: 1.865
Accuracy on test set: 97 %
[2,     0] loss: 0.012
[2,   100] loss: 1.798
[2,   200] loss: 1.799
[2,   300] loss: 1.864
[2,   400] loss: 1.848
[2,   500] loss: 1.878
[2,   600] loss: 1.766
[2,   700] loss: 1.816
[2,   800] loss: 1.857
[2,   900] loss: 1.821
[2,  1000] loss: 1.929
[2,  1100] loss: 1.905
[2,  1200] loss: 1.830
[2,  1300] loss: 1.870
[2,  1400] loss: 1.867
[2,  1500] loss: 1.897
Accuracy on test set: 97 %
[3,     0] loss: 0.013
[3,   100] loss: 1.824
[3,   200] loss: 1.843
[3,   300] loss: 1.884
[3,   400] loss: 1.870
[3,   500] loss: 1.831
[3,   600] loss: 1.907
[3,   700] loss: 1.807
[3,   800] loss: 1.858
[3,   900] loss: 1.837
[3,  1000] loss: 1.809
[3,  1100] loss: 1.873
[3,  1200] loss: 1.904
[3,  1300] loss: 1.848
[3,  1400] loss: 1.886
[3,  1500] loss: 1.860
Accuracy on test set: 97 %
[4,     0] loss: 0.017
[4,   100] loss: 1.834
[4,   200] loss: 1.847
[4,   300] loss: 1.870
[4,   400] loss: 1.779
[4,   500] loss: 1.773
[4,   600] loss: 1.900
[4,   700] loss: 1.862
[4,   800] loss: 1.828
[4,   900] loss: 1.831
[4,  1000] loss: 1.804
[4,  1100] loss: 1.846
[4,  1200] loss: 1.898
[4,  1300] loss: 1.883
[4,  1400] loss: 1.888
[4,  1500] loss: 1.820
Accuracy on test set: 97 %
[5,     0] loss: 0.018
[5,   100] loss: 1.815
[5,   200] loss: 1.886
[5,   300] loss: 1.853
[5,   400] loss: 1.897
[5,   500] loss: 1.862
[5,   600] loss: 1.894
[5,   700] loss: 1.865
[5,   800] loss: 1.818
[5,   900] loss: 1.868
[5,  1000] loss: 1.790
[5,  1100] loss: 1.815
[5,  1200] loss: 1.813
[5,  1300] loss: 1.890
[5,  1400] loss: 1.784
[5,  1500] loss: 1.848
Accuracy on test set: 94 %
[6,     0] loss: 0.023
[6,   100] loss: 1.870
[6,   200] loss: 1.861
[6,   300] loss: 1.850
[6,   400] loss: 1.868
[6,   500] loss: 1.892
[6,   600] loss: 1.866
[6,   700] loss: 1.853
[6,   800] loss: 1.803
[6,   900] loss: 1.805
[6,  1000] loss: 1.801
[6,  1100] loss: 1.895
[6,  1200] loss: 1.821
[6,  1300] loss: 1.808
[6,  1400] loss: 1.906
[6,  1500] loss: 1.864
Accuracy on test set: 97 %
[7,     0] loss: 0.021
[7,   100] loss: 1.807
[7,   200] loss: 1.785
[7,   300] loss: 1.900
[7,   400] loss: 1.863
[7,   500] loss: 1.830
[7,   600] loss: 1.809
[7,   700] loss: 1.844
[7,   800] loss: 1.794
[7,   900] loss: 1.901
[7,  1000] loss: 1.892
[7,  1100] loss: 1.829
[7,  1200] loss: 1.875
[7,  1300] loss: 1.873
[7,  1400] loss: 1.825
[7,  1500] loss: 1.788
Accuracy on test set: 97 %
[8,     0] loss: 0.015
[8,   100] loss: 1.850
[8,   200] loss: 1.792
[8,   300] loss: 1.860
[8,   400] loss: 1.863
[8,   500] loss: 1.835
[8,   600] loss: 1.776
[8,   700] loss: 1.865
[8,   800] loss: 1.780
[8,   900] loss: 1.851
[8,  1000] loss: 1.873
[8,  1100] loss: 1.819
[8,  1200] loss: 1.804
[8,  1300] loss: 1.855
[8,  1400] loss: 1.892
[8,  1500] loss: 1.869
Accuracy on test set: 97 %
[9,     0] loss: 0.015
[9,   100] loss: 1.818
[9,   200] loss: 1.795
[9,   300] loss: 1.839
[9,   400] loss: 1.911
[9,   500] loss: 1.854
[9,   600] loss: 1.859
[9,   700] loss: 1.810
[9,   800] loss: 1.852
[9,   900] loss: 1.842
[9,  1000] loss: 1.797
[9,  1100] loss: 1.842
[9,  1200] loss: 1.777
[9,  1300] loss: 1.822
[9,  1400] loss: 1.783
[9,  1500] loss: 1.872
Accuracy on test set: 95 %
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171

可以看到整体的实验结果还算满意,大概有96%左右的准确度。如果你希望把训练好的模型保存起来,并且期待能否把模型应用到一般的应用程序里,比如说C程序里,那么就可以把模型和参数都保存起来。

保存模型和数据

文件的后缀名没啥强制要求,我比较喜欢叫ptm,因为是 pytorch model 的简写,你也可以自己定义个喜欢的后缀名。

    # finally save the model
    torch.save(model, "LSTM_Surname_Classfication.ptm")
  • 1
  • 2

在下一章节里,我将给大家演示如何使用C程序加载训练好的模型。

欢迎关注我的博客~

Adios~~

参考资料

  • 《NLP FROM SCRATCH: CLASSIFYING NAMES WITH A CHARACTER-LEVEL RNN》, Sean Robertson,https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号