赞
踩
Fasttext历史意义:
1、提出一种新的文本分类方法-Fasttext,能够快速进行文本分类,效果较好
2、提出一种新的使用子词的词向量训练方法,能够在一定程度上解决oov问题
3、将Fasttext开源使得工业界和学术界能够快速的使用Fasttext
深度学习文本分类模型:
优点:效果好,能达到非常好的效果,不用做特征工程,模型简洁
缺点:速度比较慢,无法在大规模的文本分类任务上应用
机器学习文本分类模型:
优点:速度一般都很快,模型都是线性分类器,比较简单;效果还可以,在某些任务上可以取得比较好的结果
缺点:需要做特征工程,分类效果依赖于有效特征的提取
本文主要结构:
提出一种简单的高效文本分类模型,效果和其它深度学习模型相当,但是速度快很多倍
文本分类是自然语言处理的重要任务,可以用于信息检索,网页搜索、文档分类等;基于深度学习可以达到非常的好的效果,但是速度慢限制文本分类的应用;基于机器学习的线性分类器效果也很好,有用于大规模分类任务的潜力;从现在词向量中得到灵感,提出一种使用新的文本分类方法Fasttext,这种方法能够快速的训练和测试并且达到和最优结果相似的结果。
详细介绍Fasttext的模型结构以及两个技巧,分别是层次softmax和n-gram特征
模型结构如上图,与CBOW的模型结果相同,与CBOW模型的区别和联系如下所示:
联系:
1)都是log-line模型,模型简单
2)都是对输入的词向量做平均,然后进行预测
3)模型结构完全一样
区别:
1) Fasttext提取的是句子特征,CBOW提取的是上下文特征
2)Fasttext需要标注语料,是监督学习,CBOW不需要人工标注语料,是无监督学习
Fasttext存在的问题:
1) 当类别非常多的时候,最后的softmax速度比较慢(因为要构造词表大小的数据)
2) 使用的是词袋模型,没有词序信息
解决办法:
1) 层次softmax
和word2vec中的层次softmax一样,可以减少参数由原来的H*V -> H*log2V (V表示词表大小)
2) 添加使用n-gram特征
输入模型的数据中添加了n-gram特征,并且用到了hash
如果每一个词对应一个向量,那么词表太大;如果多个词对应一个向量,不够准确,所以构建hash方法
假如词表大小限制为10w,1-gram单词个数为3w,2-gram词组个数为10w,3-gram词组个数为40w,1-gram不用做hash,所以说词表10w中前3w是留给1-gram的,剩余7w个位置还有50w个词组没有位置安放,所以50w/7w约等于7,也就是说这50w个词组中大约有7个词对应同一个词向量。
Fasttext另一篇文章中提到subword,主要是根据n-gram把词拆开进行预测
在文本分类任务上和tag预测任务上都取得了非常好的结果,效果和其它深度模型相差不多,但是速度上会快很多
对论文进行一些总结
关键点:
基于深度学习的文本分类方法效果好,但速度比较慢;
基于线性分类器的机器学习方法速度比较快,但是需要做更多的特征工程;
提出Fasttext模型
创新点:
提出一种新的文本分类模型Fasttext;
提出一些加快和使得文本分类效果更高的技巧-层次softmax和n-gram特征;
在文本分类任务上和tag预测两个任务上都取得了又快又好的结果。
启发点:
虽然深度学习能够取得非常好的结果,但是在训练和测试的时候,非常慢限制了他们在大数据集上的应用(模型不一定在效果上大幅度提升,效果差不多,速度大幅度提升也是一种创新);
然而线性分类器不同特征和类别之间不共享参数,可能限制了一些只有少量样本类别的泛化能力(共享词向量);
大部分词向量方法对每个词分配一个独立的词向量,没有共享参数,特别是这些方法忽略之间的联系,而对于形态学丰富的语言更加重要。
- # ****** 数据预处理 *****
-
- # 主要包括几个部分-数据集加载、读取标签和数据、创建word2id、将数据转化为id, 本次实验还是使用AG数据集合,数据集下载位置 AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
-
- # encoding = 'utf-8'
-
- from torch.utils import data
- import os
- import csv
- import nltk
- import numpy as np
-
- # 数据集加载
-
- f = open("./data/AG/train.csv")
- rows = csv.reader(f,delimiter=',',quotechar='"')
- rows = list(rows)
- rows[1:5]
-
- [['3',
- 'Carlyle Looks Toward Commercial Aerospace (Reuters)',
- 'Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
- ['3',
- "Oil and Economy Cloud Stocks' Outlook (Reuters)",
- 'Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.'],
- ['3',
- 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)',
- 'Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.'],
- ['3',
- 'Oil prices soar to all-time record, posing new menace to US economy (AFP)',
- 'AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.']]
-
- # 读取标签和数据
-
- n_gram,lowercase,label,datas = 2,True,[],[]
-
- for row in rows:
- label.append(int(row[0])-1)
- txt = " ".join(row[1:])
- if lowercase:
- txt = txt.lower()
- txt = nltk.word_tokenize(txt) #将句子转化为词
- new_txt = []
- for i in range(len(txt)):
- for j in range(n_gram):
- if j<=i:
- new_txt.append(" ".join(txt[i-j:i+1]))
- datas.append(new_txt)
-
-
- # word2id
-
- min_count,word_freq = 3,{}
- for data in datas:
- for word in data:
- if word not in word_freq:
- word_freq[word] = 1
- else:
- word_freq[word] += 1
-
-
- # 首先构建uni-gram,不需要hash
-
- word2id = {"<pad>":0,"<unk>":1}
- for word in word_freq:
- if word_freq[word] < min_count or " " in word:
- continue
- word2id[word] = len(word2id)
-
- uniwords_num = len(word2id)
-
-
- # 构建2-gram以上的词,需要hash
-
- for word in word_freq:
- if word_freq[word] < min_count or " " not in word:
- continue
- word2id[word] = len(word2id)
-
-
- # 将文本中的词都转化为id,设置句子长度为100,词表最大限制为1w
- max_length = 100
-
- for i,data in enumerate(datas):
- for j,word in enumerate(data):
- if " " not in word:
- datas[i][j] = word2id.get(word,1)
- else:
- datas[i][j] = word2id.get(word,1)%10000 + uniwords_num
-
- datas[i] = datas[i][0:max_length] + [0]*(max_length - len(datas[i]))
-
-
-

模型细节:
- # """ 模型代码 """
-
- # encoding='utf-8'
-
- import torch
- import torch.nn as nn
- import numpy as np
-
-
- class Fasttext(nn.Module):
- def __init__(self,vocab_size,embedding_size,max_length,label_num):
- super(Fasttext,self).__init__()
- self.embedding = nn.Embedding(vocab_size,embedding_size)
- self.avg_pool = nn.AvgPool1d(kernel_size=max_length,stride=1)
- self.fc = nn.Linear(embedding_size,label_num)
-
- def forward(self,x):
- x = x.long()
- out = self.embedding(x) # batch_size * length * embedding_size
- out = out.transpose(1,2).contiguous() # batch_size * embedding_size * length
- out = self.avg_pool(out).squeeze() # batch_size * embedding_size
- out = self.fc(out) # batch_size * label_num
-
- return out
-
- fasttext = Fasttext(vocab_size=1000,embedding_size=10,max_length=100,label_num=4)
- test = torch.zeros([64,100]).long()
- out = fasttext(test)
-
-
-
- """ 查看网络参数 """
- from torchsummary import summary
-
- summary(fasttext,input_size=(100,))
-

- """ 模型训练 """
-
- # encoding = 'utf-8'
-
-
- import torch
- import torch.autograd as autograd
- import torch.nn as nn
- import torch.optim as optim
- from model import Fasttext
- from data import AG_Data
- import numpy as np
- from tqdm import tqdm
-
- import config as argumentparser
- config = argumentparser.ArgumentParser()
-
-
- """ 加载数据集 """
-
- training_set = AG_Data("/AG/train.csv",
- min_count = config.min_count,
- max_length=config.max_length,
- n_gram=config.n_gram)
-
- train_iter = torch.utils.data.DataLoader(dataset=training_set,
- batch_size=config.batch_size,
- shuffle=True,
- num_workers=0)
-
- test_set = AG_Data(data_path="/AG/test.csv",
- min_count=config.min_count,
- max_length = config.max_length,
- n_gram = config.n_gram,
- word2id = training_set.word2id,
- uniwords_num=training_set.uniwords_num)
-
- test_iter = torch.utils.data.DataLoader(dataset=test_set,
- batch_size=config.batch_size,
- shuffle = True,
- num_workers=0)
-
- """ 构建模型 """
- model = Fasttext(vocab_size=training_set.uniwords_num+100000,
- embedding_size=config.embed_size,
- max_length=config.max_length,
- label_num=config.label_num)
-
-
- if config.cuda and torch.cuda.is_available():
- model.cuda()
-
- criterion = nn.CrossEntropyLoss()
- optimizer = optim.Adam(model.parameters(),lr=config.learning_rate)
- loss = -1
-
- def get_test_result(data_iter,data_set):
-
- mode.eval()
- true_sample_num = 0
- for data,label in data_iter:
- if config.cuda and torch.cuda.is_available():
- data = data.cuda()
- label = label.cuda()
- else:
- data = torch.autograd.Variable(data).long()
-
- out = model(data)
- true_sample_num += np.sum((torch.argmax(out,1)==label.long()).cpu().numpy())
- acc = true_sample_num/data_set.__len__()
-
- return acc
-
- for epoch in range(1):
- model.train()
- process_bar = tqdm(train_iter)
- for data,label in process_bar:
- if config.cuda and torch.cuda.is_available():
- data = data.cuda()
- label = label.cuda()
- else:
- data = torch.autograd.Variable(data).long()
- label = torch.autograd.Variable(label).squeeze()
- out = model(data)
- loss_now = criterion(out,autograd.Variable(label.long()))
-
- if loss == -1:
- loss = loss_now.data.item()
- else:
- loss = 0.95*loss + 0.05*loss_now.data.item()
-
- process_bar.set_postfix(loss=loss_now.data.item())
- process_bar.update()
- optimizer.zero_grad()
- loss_now.backward()
- optimizer.step()
-
- test_acc = get_test_result(test_iter,test_set)
-
- print("The test acc is: %.5f" % test_acc)

完整代码,详见git:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。