
PyTorch RNN for News Text Classification

Overview

An RNN (Recurrent Neural Network) is a neural network designed for sequential data, that is, data in which earlier inputs and later inputs are related to each other.

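To make this concrete, here is a minimal sketch of a recurrent layer in PyTorch (the sizes are hypothetical, chosen only for illustration): the output at each time step depends on the current input and on the hidden state carried over from the previous steps.

import torch
import torch.nn as nn

# A single-layer RNN reading a batch of 2 sequences, each 5 steps of 16-dim vectors
rnn = nn.RNN(input_size=16, hidden_size=32, batch_first=True)
x = torch.randn(2, 5, 16)    # [batch, seq_len, input_size]
output, h_n = rnn(x)
print(output.shape)          # torch.Size([2, 5, 32]) -- the hidden state at every step
print(h_n.shape)             # torch.Size([1, 2, 32]) -- the final hidden state per sequence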

Dataset

We will use a subset of THUCNews that contains news in 10 categories, with 10,000 samples per category.

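The loading code (build_dataset and build_iterator in utils, imported in the main function below) is not reproduced here; it is assumed to read one sample per line, with the text and a 0-based label index separated by a tab, while class.txt lists the 10 category names one per line. A purely hypothetical train.txt excerpt:

<news headline text>	3
<another news headline>	7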

TextRNN Model

The model embeds each character, feeds the sequence through a two-layer bidirectional LSTM, and classifies on the hidden state of the last time step (the exact structure is printed in the output section below).

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset, embedding):
        self.model_name = 'TextRNN'
        self.train_path = dataset + '/data/train.txt'                                # training set
        self.dev_path = dataset + '/data/dev.txt'                                    # validation set
        self.test_path = dataset + '/data/test.txt'                                  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]                                # class names
        self.vocab_path = dataset + '/data/vocab.pkl'                                # vocabulary
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # saved checkpoint
        self.log_path = dataset + '/log/' + self.model_name
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\
            if embedding != 'random' else None                                       # pre-trained embeddings
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # device

        self.dropout = 0.5                                              # dropout rate
        self.require_improvement = 1000                                 # stop early if no dev improvement within 1000 batches
        self.num_classes = len(self.class_list)                         # number of classes
        self.n_vocab = 0                                                # vocabulary size, set at runtime
        self.num_epochs = 10                                            # number of epochs (matches the training log below)
        self.batch_size = 128                                           # mini-batch size
        self.pad_size = 32                                              # sequence length (pad short, truncate long)
        self.learning_rate = 1e-3                                       # learning rate
        self.embed = self.embedding_pretrained.size(1)\
            if self.embedding_pretrained is not None else 300           # embedding dimension
        self.hidden_size = 128                                          # LSTM hidden size
        self.num_layers = 2                                             # number of LSTM layers


'''TextRNN: a bidirectional LSTM text classifier'''


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        # Two-layer bidirectional LSTM, matching the model structure printed in the output section
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        # Bidirectionality doubles the feature dimension: hidden_size * 2 -> num_classes
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x                        # the iterator yields (token ids, sequence lengths)
        out = self.embedding(x)         # [batch_size, seq_len, embed]
        out, _ = self.lstm(out)         # [batch_size, seq_len, hidden_size * 2]
        out = self.fc(out[:, -1, :])    # classify on the hidden state of the last time step
        return out
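
As a quick sanity check, the following sketch (an illustration, not part of the original script) runs a forward pass with a stand-in config so that no data files are needed. The data iterator in the real pipeline yields (token ids, sequence lengths) tuples, which is why the model is called with a pair.

from types import SimpleNamespace

# Stand-in config with only the fields Model reads; sizes match the printed model below
cfg = SimpleNamespace(embedding_pretrained=None, n_vocab=4762, embed=300,
                      hidden_size=128, num_layers=2, dropout=0.5, num_classes=10)
model = Model(cfg)
ids = torch.randint(0, cfg.n_vocab, (8, 32))   # [batch_size, seq_len]
lengths = torch.full((8,), 32)                 # unused by forward, kept for interface parity
logits = model((ids, lengths))
print(logits.shape)                            # torch.Size([8, 10]) -- one score per class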

Training and Evaluation Functions

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter


# Weight initialization, Xavier by default
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass


def train(config, model, train_iter, dev_iter, test_iter, writer):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Exponential learning-rate decay, per epoch: lr = gamma * lr
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last dev-loss improvement
    flag = False  # whether training has gone too long without improvement
    # writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # scheduler.step()  # learning-rate decay
        for i, (trains, labels) in enumerate(train_iter):
            # print (trains[0].shape)
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # periodically report performance on the training and validation sets
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # dev loss has not improved for over 1000 batches; stop training early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    test(config, model, test_iter)


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
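
Note that evaluate averages the cross-entropy over batches and can also be called on its own, for example to re-score a saved checkpoint. A minimal sketch, assuming config, model and dev_iter are built as in the main function below:

model.load_state_dict(torch.load(config.save_path))
dev_acc, dev_loss = evaluate(config, model, dev_iter)
print('Val Loss: {0:>5.2},  Val Acc: {1:>6.2%}'.format(dev_loss, dev_acc))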

Main Function


import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, default="TextRNN",
                    help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')  # caution: argparse's type=bool treats any non-empty string as True
args = parser.parse_args()

if __name__ == '__main__':
    dataset = 'THUCNews'  # dataset directory

    # Sogou News: embedding_SougouNews.npz, Tencent: embedding_Tencent.npz, random initialization: random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  # e.g. TextCNN, TextRNN, ...
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif

        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible across runs

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)  # note: prints the bound method together with the model structure (see the output below)
    train(config, model, train_iter, dev_iter, test_iter, writer)
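
Assuming the script above is saved as run.py next to the THUCNews directory (the file name is hypothetical; the original does not state it), a typical invocation would be:

python run.py --model TextRNN --embedding pre_trained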

Output

Loading data...
Vocab size: 4762
180000it [00:03, 56090.03it/s]
10000it [00:00, 32232.86it/s]
10000it [00:00, 61166.60it/s]
Time usage: 0:00:04
<bound method Module.parameters of Model(
  (embedding): Embedding(4762, 300)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=256, out_features=10, bias=True)
)>
Epoch [1/10]
Iter:      0,  Train Loss:   2.3,  Train Acc: 11.52%,  Val Loss:   2.3,  Val Acc: 10.00%,  Time: 0:00:00 *
Iter:    100,  Train Loss:   1.3,  Train Acc: 50.39%,  Val Loss:   1.3,  Val Acc: 49.63%,  Time: 0:00:02 *
Iter:    200,  Train Loss:  0.72,  Train Acc: 77.54%,  Val Loss:  0.74,  Val Acc: 75.92%,  Time: 0:00:04 *
Iter:    300,  Train Loss:  0.47,  Train Acc: 84.18%,  Val Loss:  0.55,  Val Acc: 82.34%,  Time: 0:00:06 *
Epoch [2/10]
Iter:    400,  Train Loss:   0.5,  Train Acc: 83.59%,  Val Loss:  0.48,  Val Acc: 85.13%,  Time: 0:00:07 *
Iter:    500,  Train Loss:  0.41,  Train Acc: 88.48%,  Val Loss:  0.43,  Val Acc: 86.42%,  Time: 0:00:09 *
Iter:    600,  Train Loss:  0.37,  Train Acc: 88.48%,  Val Loss:  0.41,  Val Acc: 86.93%,  Time: 0:00:11 *
Iter:    700,  Train Loss:  0.42,  Train Acc: 86.33%,  Val Loss:  0.37,  Val Acc: 87.90%,  Time: 0:00:12 *
Epoch [3/10]
Iter:    800,  Train Loss:  0.35,  Train Acc: 89.06%,  Val Loss:  0.39,  Val Acc: 87.81%,  Time: 0:00:14 
Iter:    900,  Train Loss:   0.3,  Train Acc: 89.06%,  Val Loss:  0.36,  Val Acc: 88.51%,  Time: 0:00:16 *
Iter:   1000,  Train Loss:   0.3,  Train Acc: 90.43%,  Val Loss:  0.36,  Val Acc: 88.81%,  Time: 0:00:17 
Epoch [4/10]
Iter:   1100,  Train Loss:  0.29,  Train Acc: 90.82%,  Val Loss:  0.34,  Val Acc: 89.07%,  Time: 0:00:19 *
Iter:   1200,  Train Loss:  0.28,  Train Acc: 90.82%,  Val Loss:  0.33,  Val Acc: 89.43%,  Time: 0:00:21 *
Iter:   1300,  Train Loss:  0.28,  Train Acc: 90.62%,  Val Loss:  0.33,  Val Acc: 89.41%,  Time: 0:00:22 
Iter:   1400,  Train Loss:  0.25,  Train Acc: 91.60%,  Val Loss:  0.33,  Val Acc: 89.37%,  Time: 0:00:24 
Epoch [5/10]
Iter:   1500,  Train Loss:  0.26,  Train Acc: 91.80%,  Val Loss:  0.34,  Val Acc: 89.56%,  Time: 0:00:26 
Iter:   1600,  Train Loss:  0.18,  Train Acc: 94.34%,  Val Loss:  0.35,  Val Acc: 89.14%,  Time: 0:00:27 
Iter:   1700,  Train Loss:  0.23,  Train Acc: 92.58%,  Val Loss:  0.33,  Val Acc: 89.80%,  Time: 0:00:29 *
Epoch [6/10]
Iter:   1800,  Train Loss:  0.23,  Train Acc: 92.97%,  Val Loss:  0.34,  Val Acc: 89.46%,  Time: 0:00:31 
Iter:   1900,  Train Loss:  0.18,  Train Acc: 94.34%,  Val Loss:  0.32,  Val Acc: 89.76%,  Time: 0:00:33 *
Iter:   2000,  Train Loss:  0.16,  Train Acc: 93.75%,  Val Loss:  0.34,  Val Acc: 89.28%,  Time: 0:00:34 
Iter:   2100,  Train Loss:  0.22,  Train Acc: 92.19%,  Val Loss:  0.32,  Val Acc: 90.12%,  Time: 0:00:36 *
Epoch [7/10]
Iter:   2200,  Train Loss:  0.21,  Train Acc: 92.77%,  Val Loss:  0.34,  Val Acc: 89.67%,  Time: 0:00:38 
Iter:   2300,  Train Loss:  0.18,  Train Acc: 94.73%,  Val Loss:  0.35,  Val Acc: 89.81%,  Time: 0:00:39 
Iter:   2400,  Train Loss:  0.21,  Train Acc: 92.38%,  Val Loss:  0.36,  Val Acc: 89.21%,  Time: 0:00:41 
Epoch [8/10]
Iter:   2500,  Train Loss:  0.19,  Train Acc: 93.75%,  Val Loss:  0.35,  Val Acc: 89.56%,  Time: 0:00:43 
Iter:   2600,  Train Loss:  0.19,  Train Acc: 94.53%,  Val Loss:  0.31,  Val Acc: 90.38%,  Time: 0:00:45 *
Iter:   2700,  Train Loss:   0.2,  Train Acc: 93.75%,  Val Loss:  0.33,  Val Acc: 89.95%,  Time: 0:00:46 
Iter:   2800,  Train Loss:  0.15,  Train Acc: 94.92%,  Val Loss:  0.33,  Val Acc: 90.05%,  Time: 0:00:48 
Epoch [9/10]
Iter:   2900,  Train Loss:  0.22,  Train Acc: 93.16%,  Val Loss:  0.35,  Val Acc: 89.47%,  Time: 0:00:49 
Iter:   3000,  Train Loss:  0.16,  Train Acc: 94.53%,  Val Loss:  0.36,  Val Acc: 89.72%,  Time: 0:00:51 
Iter:   3100,  Train Loss:  0.19,  Train Acc: 93.95%,  Val Loss:  0.37,  Val Acc: 89.51%,  Time: 0:00:53 
Epoch [10/10]
Iter:   3200,  Train Loss:  0.13,  Train Acc: 95.70%,  Val Loss:  0.35,  Val Acc: 89.67%,  Time: 0:00:54 
Iter:   3300,  Train Loss:   0.2,  Train Acc: 93.36%,  Val Loss:  0.35,  Val Acc: 90.27%,  Time: 0:00:56 
Iter:   3400,  Train Loss:  0.12,  Train Acc: 96.48%,  Val Loss:  0.34,  Val Acc: 89.92%,  Time: 0:00:57 
Iter:   3500,  Train Loss:  0.12,  Train Acc: 95.70%,  Val Loss:  0.35,  Val Acc: 89.98%,  Time: 0:00:59 
Test Loss:   0.3,  Test Acc: 90.66%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.8777    0.9040    0.8906      1000
       realty     0.9353    0.9110    0.9230      1000
       stocks     0.8843    0.7950    0.8373      1000
    education     0.9319    0.9440    0.9379      1000
      science     0.8297    0.8770    0.8527      1000
      society     0.9012    0.9210    0.9110      1000
     politics     0.9001    0.8740    0.8869      1000
       sports     0.9788    0.9680    0.9734      1000
         game     0.9299    0.9290    0.9295      1000
entertainment     0.9015    0.9430    0.9218      1000

     accuracy                         0.9066     10000
    macro avg     0.9070    0.9066    0.9064     10000
 weighted avg     0.9070    0.9066    0.9064     10000

Confusion Matrix...
[[904  11  38   5  16  10   9   1   1   5]
 [ 14 911  14   6   9  12  10   4   6  14]
 [ 72  25 795   5  57   1  33   0   9   3]
 [  2   1   2 944  10  18   7   0   5  11]
 [ 11   6  18   8 877  17  15   0  32  16]
 [  4  12   1  18   7 921  14   1   7  15]
 [ 16   3  21  14  26  29 874   4   2  11]
 [  1   1   3   1   3   2   4 968   0  17]
 [  2   1   5   5  39   4   3   1 929  11]
 [  4   3   2   7  13   8   2  10   8 943]]
Time usage: 0:00:00

