AG_NEWS Text Classification in Practice (Part 1)

1. Loading the Dataset

We will use the AG_NEWS dataset to build a simple text classification model. First, here are the libraries we need.

import torch
import torch.nn as nn
import torchtext
from torchtext.datasets import AG_NEWS
import os
from collections import Counter, OrderedDict

Then we load the training and test sets.

os.makedirs('./data',exist_ok=True)
train_dataset, test_dataset = AG_NEWS(root='./data', split=('train', 'test'))
classes = ['World', 'Sports', 'Business', 'Sci/Tech']

In case the download above fails, here are direct links to the training and test CSVs:

https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
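
If you end up downloading the CSVs by hand, a minimal loading sketch might look like the following. This assumes the usual AG_NEWS CSV layout of three unlabeled columns (class index 1-4, title, description); depending on how quotes are escaped in the file you may need to adjust the csv settings.

import csv

def load_ag_csv(path):
    # assumed columns: label (1-4), title, description
    samples = []
    with open(path, encoding='utf-8') as f:
        for label, title, description in csv.reader(f):
            # join title and description into one text field, matching
            # the (label, text) tuples yielded by torchtext's AG_NEWS
            samples.append((int(label), title + ' ' + description))
    return samples

train_dataset = load_ag_csv('./data/train.csv')
test_dataset = load_ag_csv('./data/test.csv')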

Let's print the first five training samples:

for i,x in zip(range(5),train_dataset):
    print(f"**{classes[x[0]-1]}** -> {x[1]}\n")

The output is:

**Business** -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

**Business** -> Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.

**Business** -> Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

**Business** -> Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.

**Business** -> Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.

Convert the training and test sets into list objects (the raw datasets are iterators, which can only be traversed once; lists let us iterate over them repeatedly and shuffle them):

train_dataset = list(train_dataset)
test_dataset = list(test_dataset)

2. Building the Vocabulary and DataLoader

Choose a tokenizer:

tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
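
The basic_english tokenizer lowercases the text and splits punctuation into separate tokens; a quick check on a headline fragment:

print(tokenizer("Wall St. Bears Claw Back Into the Black"))
# ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black']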

Build the vocabulary:

counter = Counter()
for (label, line) in train_dataset:
    counter.update(tokenizer(line))
# sort tokens by descending frequency
order_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
# build the word-to-index vocabulary
vocab = torchtext.vocab.vocab(order_dict, min_freq=1)
vocab_size = len(vocab)
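
As a sanity check, we can round-trip a sentence through the vocabulary. Note that a vocab built this way has no default index for unknown words, so looking up an out-of-vocabulary token raises a RuntimeError; calling vocab.set_default_index(...) would avoid that.

tokens = tokenizer("Wall St. Bears Claw Back")
indices = vocab.lookup_indices(tokens)
print(indices)                       # index values depend on word frequencies
print(vocab.lookup_tokens(indices))  # round-trips back to the tokens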

When working with text, sentence lengths vary from sample to sample, so within a minibatch we pad shorter sequences at the end to match the longest one.

# sentence lengths vary, so we pad each sequence to the max length in the batch
def padify(b):
    # b is a list of batch_size tuples:
    #   - first element of each tuple = label
    #   - second element = feature (the text string)
    # convert each text into a sequence of vocabulary indices
    v = [vocab.lookup_indices(tokenizer(x[1])) for x in b]
    # compute the max sequence length in this minibatch
    l = max(map(len, v))
    return ( # tuple of two tensors: labels and padded features
        torch.LongTensor([t[0] - 1 for t in b]),  # shift labels from 1-4 to 0-3
        torch.stack([torch.nn.functional.pad(torch.tensor(t), (0, l - len(t)), mode='constant', value=0) for t in v]))
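
A quick look at what padify produces on a tiny hand-made batch of two samples with different lengths (the words here all occur in the AG_NEWS training set, so the vocabulary lookups succeed):

labels, features = padify([(3, "Stocks rally on Wall Street"),
                           (2, "Team wins")])
print(labels)          # tensor([2, 1]) -- labels shifted from 1-4 to 0-3
print(features)        # the shorter sequence is padded with zeros on the right
print(features.shape)  # torch.Size([2, 5]): batch_size x longest sequence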

Here is the DataLoader for the training set:

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=padify, shuffle=True)
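
Each batch then arrives as a label tensor plus a padded feature matrix:

labels, features = next(iter(train_loader))
print(labels.shape)    # torch.Size([16])
print(features.shape)  # torch.Size([16, L]), L = longest sequence in this batch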

3. Building the Model

Since this is the first post in this text-classification series, we use an extremely simple model: each word is mapped to an embedding, the word vectors of a sentence are averaged, and the result passes through a single fully-connected layer. The model outputs raw logits; softmax is applied implicitly by the cross-entropy loss during training.

class EmbedClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)
    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=1)  # average the word vectors across the sentence
        return self.fc(x)
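
Before training, we can sanity-check the forward pass with a throwaway instance (the outputs at this point are just logits from random weights):

tmp_net = EmbedClassifier(vocab_size, 32, len(classes))
labels, features = next(iter(train_loader))
print(tmp_net(features).shape)  # torch.Size([16, 4]): one logit per class per sample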

Next we build the training function, using cross-entropy as the loss function and Adam as the optimizer.

def train_epoch(net, dataloader, lr, optimizer=None, loss_fn=None, epoch_size=None, report_freq=200):
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    net.train()  # switch to training mode
    total_loss, acc, count, i = 0, 0, 0, 0
    for labels, features in dataloader:
        optimizer.zero_grad()
        out = net(features)
        loss = loss_fn(out, labels)  # CrossEntropyLoss applies softmax internally
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # .item() detaches the loss so the graph is freed
        predicted = torch.argmax(out, 1)  # index of the largest logit along the class dimension
        acc += (predicted == labels).sum().item()
        count += len(labels)
        i += 1
        if i % report_freq == 0:
            print(f"{count}: acc={acc/count}")
        if epoch_size and count > epoch_size:
            break
    print(f'loss is {total_loss/count}')
    print(f'accuracy is {acc/count*100}%')

Now we instantiate and train the model:

net = EmbedClassifier(vocab_size, 32, len(classes))
train_epoch(net,train_loader, lr=1, epoch_size=25000)

The output is:

3200: acc=0.641875
6400: acc=0.67984375
9600: acc=0.7048958333333334
12800: acc=0.71765625
16000: acc=0.728375
19200: acc=0.7410416666666667
22400: acc=0.7509375
loss is 0.9129715990882917
accuracy is 75.58381317978247%


After a single epoch of training, accuracy on the training set reaches about 75%, which is quite decent for such a simple model.
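
We loaded the test set but haven't used it yet; a minimal evaluation sketch, reusing the same padify collate function, could look like this (accuracy on unseen data will typically be somewhat lower than the training figure above):

def evaluate(net, dataset, batch_size=16):
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=padify)
    net.eval()  # switch to evaluation mode
    correct, count = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for labels, features in loader:
            predicted = torch.argmax(net(features), dim=1)
            correct += (predicted == labels).sum().item()
            count += len(labels)
    return correct / count

print(f'test accuracy: {evaluate(net, test_dataset)*100:.2f}%')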
