We will implement a simple text classification model on the AG_NEWS dataset. First, here are the libraries we need.
import torch
import torch.nn as nn
import torchtext
from torchtext.datasets import AG_NEWS
import os
from collections import Counter, OrderedDict
Next we load the training and test sets.
os.makedirs('./data',exist_ok=True)
train_dataset, test_dataset = AG_NEWS(root='./data', split=('train', 'test'))
classes = ['World', 'Sports', 'Business', 'Sci/Tech']
In case the download above fails, here are direct links to the training and test sets:
https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
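If you download the CSV files manually, a minimal sketch for building the same (label, text) tuples yourself might look like the following. The local paths './data/train.csv' and './data/test.csv' are assumptions; each CSV row holds a class index (1-4), a title, and a description, which torchtext joins into one text field.

import csv

def load_ag_news_csv(path):
    samples = []
    with open(path, encoding='utf-8') as f:
        for row in csv.reader(f):
            # row[0] is the class index 1-4; join title and description,
            # matching the (label, text) tuples that torchtext's AG_NEWS yields
            samples.append((int(row[0]), ' '.join(row[1:])))
    return samples

# train_dataset = load_ag_news_csv('./data/train.csv')
# test_dataset = load_ag_news_csv('./data/test.csv')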
Let's print the first five training samples:
for i, x in zip(range(5), train_dataset):
    print(f"**{classes[x[0]-1]}** -> {x[1]}\n")
The output is:
**Business** -> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
**Business** -> Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.
**Business** -> Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.
**Business** -> Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.
**Business** -> Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.
Convert the training and test sets into list objects:
train_dataset = list(train_dataset)
test_dataset = list(test_dataset)
Choose a tokenizer:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
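Before building the vocabulary, it is worth a quick check of what basic_english does: it lowercases the text and splits punctuation into separate tokens. A small sanity check (the sample sentence is taken from the first training example):

print(tokenizer("Wall St. Bears Claw Back Into the Black"))
# expected roughly: ['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black']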
Build the vocabulary:
counter = Counter()
for (label, line) in train_dataset:
    counter.update(tokenizer(line))
# build an OrderedDict of words sorted by frequency, descending
order_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
# build the word -> index vocabulary from it
vocab = torchtext.vocab.vocab(order_dict, min_freq=1)
vocab_size = len(vocab)
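To confirm the vocabulary works as expected, we can map a tokenized sentence to integer indices (the sentence below is an arbitrary example of mine):

print(f"Vocab size: {vocab_size}")
print(vocab.lookup_indices(tokenizer("wall street stocks fell")))
# note: this vocab has no default index set, so a token never seen
# in the training set would raise an error here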
When working with text, sentence lengths vary from sample to sample; we can append padding at the end of shorter sentences so that every sequence in a minibatch has the same length.
# sentence lengths vary, so pad every sequence to the max length in the batch
def padify(b):
    # b is a list of tuples of length batch_size
    # - first element of a tuple = label
    # - second = feature (text sequence)
    # build the vectorized sequences
    v = [vocab.lookup_indices(tokenizer(x[1])) for x in b]
    # first, compute the max length of a sequence in this minibatch
    l = max(map(len, v))
    return ( # tuple of two tensors - labels and features
        torch.LongTensor([t[0]-1 for t in b]),
        torch.stack([torch.nn.functional.pad(torch.tensor(t), (0, l-len(t)), mode='constant', value=0) for t in v])
    )
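A quick way to check padify is to feed it two raw samples as a tiny batch (assumed usage, not from the original post):

labels, features = padify([train_dataset[0], train_dataset[1]])
print(labels)          # class indices shifted down to 0-3
print(features.shape)  # torch.Size([2, L]), L = length of the longer sequence

Note that value=0 reuses vocabulary index 0 (the most frequent token) as padding, since this vocabulary reserves no dedicated pad token; for a simple averaging model this is a known simplification.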
Here is the DataLoader for the training set:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=padify, shuffle=True)
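Pulling one batch out of the loader confirms the collate function is wired in correctly:

labels, features = next(iter(train_loader))
print(labels.shape, features.shape)  # e.g. torch.Size([16]) and torch.Size([16, L])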
Since this is the first post in this text-classification series, we use an extremely simple model: a word embedding followed by a single fully-connected layer. The word vectors of a sentence are averaged, and the mean vector is passed through the fully-connected layer; the softmax is applied implicitly by the cross-entropy loss during training.
class EmbedClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=1)  # average the word vectors across the sentence
        return self.fc(x)
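As a sketch of the expected shapes (the dummy batch below is made up for illustration): a batch of 16 padded sequences of length 50 should produce 16 rows of 4 class logits.

m = EmbedClassifier(vocab_size, 32, len(classes))
dummy = torch.randint(0, vocab_size, (16, 50))  # fake batch of token indices
print(m(dummy).shape)  # torch.Size([16, 4])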
Next we build the training function, using cross-entropy as the loss function and Adam as the optimizer.
def train_epoch(net, dataloader, lr, optimizer=None, loss_fn=None, epoch_size=None, report_freq=200):
    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = loss_fn or nn.CrossEntropyLoss()
    net.train()  # training mode
    total_loss, acc, count, i = 0, 0, 0, 0
    for labels, features in dataloader:
        optimizer.zero_grad()
        out = net(features)
        loss = loss_fn(out, labels)  # CrossEntropyLoss applies the softmax itself
        loss.backward()
        optimizer.step()
        total_loss += loss
        predicted = torch.argmax(out, 1)  # argmax over dim 1, the class dimension
        acc += (predicted == labels).sum()
        count += len(labels)
        i += 1
        if i % report_freq == 0:
            print(f"{count}: acc={acc.item()/count}")
        if epoch_size and count > epoch_size:
            break
    print(f'loss is {total_loss.item()/count}')
    print(f'accuracy is {acc.item()/count*100}%')
Now let's train an instance of the model:
net = EmbedClassifier(vocab_size, 32, len(classes))
train_epoch(net,train_loader, lr=1, epoch_size=25000)
The output is:
3200: acc=0.641875
6400: acc=0.67984375
9600: acc=0.7048958333333334
12800: acc=0.71765625
16000: acc=0.728375
19200: acc=0.7410416666666667
22400: acc=0.7509375
loss is 0.9129715990882917
accuracy is 75.58381317978247%
After this single (partial) epoch of training, capped at 25,000 samples by the epoch_size argument, the model reaches roughly 75% accuracy on the training data, which is respectable for such a simple model.
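The numbers above are training accuracy only. A minimal sketch for evaluating on the held-out test set might look like the following; the evaluate helper is my own addition, and note that tokens unseen during training would crash lookup_indices in padify unless a default index is set first.

vocab.set_default_index(0)  # map out-of-vocabulary test tokens to index 0 (a simplification)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=padify)

def evaluate(net, dataloader):
    net.eval()  # switch off training-specific behaviour
    acc, count = 0, 0
    with torch.no_grad():
        for labels, features in dataloader:
            out = net(features)
            acc += (torch.argmax(out, 1) == labels).sum().item()
            count += len(labels)
    return acc / count

print(f'test accuracy: {evaluate(net, test_loader)*100:.2f}%')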