当前位置:   article > 正文

自然语言处理(七):AG_NEWS新闻分类任务(TORCHTEXT)_ag news

ag news

自然语言处理笔记总目录


关于新闻主题分类任务: 以一段新闻报道中的文本描述内容为输入,使用模型帮助我们判断它最有可能属于哪一种类型的新闻,这是典型的文本分类问题,,我们这里假定每种类型是互斥的,即文本描述有且只有一种类型

本案例取自Pytorch官网的:TEXT CLASSIFICATION WITH THE TORCHTEXT LIBRARY,在此基础上增加了完整的注释以及通俗的讲解

本案例分为以下九个步骤

Step 1:Access to the raw dataset iterators

AG_NEWS数据集介绍:

AG_NEWS:新闻语料库,包含4个大类新闻:World、Sports、Business、Sci/Tec。

AG_NEWS共包含120000条训练样本集(train.csv), 7600测试样本数据集(test.csv)。每个类别分别拥有 30000 个训练样本及 1900 个测试样本。

import torch
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')
  • 1
  • 2
  • 3

返回的是一个训练集的迭代器,通过以下方法可以查看训练集的内容:

next(train_iter)
>>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters -
Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green
again.")

next(train_iter)
>>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private
investment firm Carlyle Group,\\which has a reputation for making well-timed
and occasionally\\controversial plays in the defense industry, has quietly
placed\\its bets on another part of the market.')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

Step 2:Prepare data processing pipelines

在训练之前,首先我们要处理新闻数据,对文本进行分词,构建词汇表vocab

使用get_tokenizer进行分词,同时build_vocab_from_iterator提供了使用迭代器构建词汇表的方法

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')	# 基本的英文分词器
train_iter = AG_NEWS(split='train')	# 训练数据迭代器

# 分词生成器
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

# 构建词汇表
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# 设置默认索引,当某个单词不在词汇表中,则返回0
vocab.set_default_index(vocab["<unk>"])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
vocab(['here', 'is', 'an', 'example'])
>>> [475, 21, 30, 5286]
print(vocab(["haha", "hehe", "xixi"]))
>>> [0, 0, 0]
  • 1
  • 2
  • 3
  • 4

接下来使用分词器以及词汇表构建Pipeline

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
  • 1
  • 2
text_pipeline('here is an example')
>>> [475, 21, 30, 5286]
label_pipeline('10')
>>> 9
  • 1
  • 2
  • 3
  • 4

Step 3:Generate data batch and iterator

from torch.utils.data import DataLoader
# 使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义collate_batch函数,在DataLoader中会使用,对传入的样本数据进行批量处理
def collate_batch(batch):
	# 存放label以及text的列表,offses存放每条text的偏移量
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
         label_list.append(label_pipeline(_label))
         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
         text_list.append(processed_text)
         # 将每一条数据的长度放入offsets列表当中
         offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # 计算出每一条text的偏移量
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

cumsum()用于计算一个数组各行的累加值,示例如下:

>>>a = [1, 2, 3, 4, 5, 6, 7]
>>>cumsum(a)
array([1, 3, 6, 10, 15, 21, 28])
  • 1
  • 2
  • 3

Step 4:Define the model

定义神经网络模型: 由EmbeddingBag、隐藏层和全连接层组成
在这里插入图片描述

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

Step 5:Initiate an instance

AG_NEWS 数据集有四个标签,因此类的数量是四个

1 : World
2 : Sports
3 : Business
4 : Sci/Tec
  • 1
  • 2
  • 3
  • 4

实例一个模型

train_iter = AG_NEWS(split='train')
num_class = len(set([label for (label, text) in train_iter]))	# 获取分类数量
vocab_size = len(vocab)	# 词汇表大小
emsize = 64	# 词嵌入维度
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
  • 1
  • 2
  • 3
  • 4
  • 5

Step 6:Define functions to train the model and evaluate results

import time

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

梯度裁剪 torch.nn.utils.clip_grad_norm_() 的使用应该在loss.backward()之后,optimizer.step()之前.

注意这个方法只在训练的时候使用,在测试的时候验证和测试的时候不用。

Step 7:Split the dataset and run the model

拆分训练集:拆分比率为训练集95%,验证集5%,使用torch.utils.data.dataset.random_split函数

to_map_style_dataset函数是将数据集从iterator变为map的形式,可以直接索引

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

total_accu = None

train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
      scheduler.step()
    else:
       total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} '
          .format(epoch, time.time() - epoch_start_time, accu_val))
    print('-' * 59)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

输出:

| epoch   1 |   500/ 1782 batches | accuracy    0.689
| epoch   1 |  1000/ 1782 batches | accuracy    0.856
| epoch   1 |  1500/ 1782 batches | accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time:  8.17s | valid accuracy    0.882
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.897
| epoch   2 |  1000/ 1782 batches | accuracy    0.904
| epoch   2 |  1500/ 1782 batches | accuracy    0.900
-----------------------------------------------------------
| end of epoch   2 | time:  8.39s | valid accuracy    0.893
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.914
| epoch   3 |  1000/ 1782 batches | accuracy    0.916
| epoch   3 |  1500/ 1782 batches | accuracy    0.913
-----------------------------------------------------------
| end of epoch   3 | time:  8.44s | valid accuracy    0.903
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.924
| epoch   4 |  1000/ 1782 batches | accuracy    0.923
| epoch   4 |  1500/ 1782 batches | accuracy    0.924
-----------------------------------------------------------
| end of epoch   4 | time:  8.43s | valid accuracy    0.908
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.932
| epoch   5 |  1000/ 1782 batches | accuracy    0.930
| epoch   5 |  1500/ 1782 batches | accuracy    0.926
-----------------------------------------------------------
| end of epoch   5 | time:  8.37s | valid accuracy    0.903
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.941
| epoch   6 |  1000/ 1782 batches | accuracy    0.943
| epoch   6 |  1500/ 1782 batches | accuracy    0.941
-----------------------------------------------------------
| end of epoch   6 | time:  8.14s | valid accuracy    0.908
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.944
| epoch   7 |  1000/ 1782 batches | accuracy    0.942
| epoch   7 |  1500/ 1782 batches | accuracy    0.944
-----------------------------------------------------------
| end of epoch   7 | time:  8.15s | valid accuracy    0.907
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.943
| epoch   8 |  1000/ 1782 batches | accuracy    0.943
| epoch   8 |  1500/ 1782 batches | accuracy    0.945
-----------------------------------------------------------
| end of epoch   8 | time:  8.15s | valid accuracy    0.907
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.943
| epoch   9 |  1000/ 1782 batches | accuracy    0.944
| epoch   9 |  1500/ 1782 batches | accuracy    0.945
-----------------------------------------------------------
| end of epoch   9 | time:  8.15s | valid accuracy    0.907
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.943
| epoch  10 |  1000/ 1782 batches | accuracy    0.944
| epoch  10 |  1500/ 1782 batches | accuracy    0.945
-----------------------------------------------------------
| end of epoch  10 | time:  8.15s | valid accuracy    0.907
-----------------------------------------------------------
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60

Step 8:Evaluate the model with test dataset

检验模型在测试集上的效能

print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))
  • 1
  • 2
  • 3

输出:

Checking the results of test dataset.
test accuracy    0.909
  • 1
  • 2

Step 9:Test on a random news

随机输入一段新闻,测试模型效果:

ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

def predict(text, pipeline):
    with torch.no_grad():
        text = torch.tensor(pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

model = model.to('cpu')
res = predict(ex_text_str, text_pipeline)
print("This is a %s news" % ag_news_label[res])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

结果:

This is a Sports news
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/344587
推荐阅读
相关标签
  

闽ICP备14008679号