当前位置:   article > 正文

Bag of Tricks for Efficient Text Classification(Fasttext)

bag of tricks for efficient text classification

Fasttext历史意义:

1、提出一种新的文本分类方法-Fasttext,能够快速进行文本分类,效果较好

2、提出一种新的使用子词的词向量训练方法,能够在一定程度上解决oov问题

3、将Fasttext开源使得工业界和学术界能够快速的使用Fasttext

 

深度学习文本分类模型:

优点:效果好,能达到非常好的效果,不用做特征工程,模型简洁

缺点:速度比较慢,无法在大规模的文本分类任务上应用

 

机器学习文本分类模型:

优点:速度一般都很快,模型都是线性分类器,比较简单;效果还可以,在某些任务上可以取得比较好的结果

缺点:需要做特征工程,分类效果依赖于有效特征的提取

 

本文主要结构:

一、Abstract

       提出一种简单的高效文本分类模型,效果和其它深度学习模型相当,但是速度快很多倍

二、Inrtroduction

       文本分类是自然语言处理的重要任务,可以用于信息检索,网页搜索、文档分类等;基于深度学习可以达到非常的好的效果,但是速度慢限制文本分类的应用;基于机器学习的线性分类器效果也很好,有用于大规模分类任务的潜力;从现在词向量中得到灵感,提出一种使用新的文本分类方法Fasttext,这种方法能够快速的训练和测试并且达到和最优结果相似的结果。

三、Model architerture

       详细介绍Fasttext的模型结构以及两个技巧,分别是层次softmax和n-gram特征     

模型结构如上图,与CBOW的模型结果相同,与CBOW模型的区别和联系如下所示:

联系:

     1)都是log-line模型,模型简单

     2)都是对输入的词向量做平均,然后进行预测

     3)模型结构完全一样

区别:

     1)  Fasttext提取的是句子特征,CBOW提取的是上下文特征 

      2)Fasttext需要标注语料,是监督学习,CBOW不需要人工标注语料,是无监督学习

Fasttext存在的问题:

      1)  当类别非常多的时候,最后的softmax速度比较慢(因为要构造词表大小的数据)

      2)  使用的是词袋模型,没有词序信息

解决办法:

       1) 层次softmax

             和word2vec中的层次softmax一样,可以减少参数由原来的H*V -> H*log2V (V表示词表大小)

        2) 添加使用n-gram特征

             输入模型的数据中添加了n-gram特征,并且用到了hash

             如果每一个词对应一个向量,那么词表太大;如果多个词对应一个向量,不够准确,所以构建hash方法

假如词表大小限制为10w,1-gram单词个数为3w,2-gram词组个数为10w,3-gram词组个数为40w,1-gram不用做hash,所以说词表10w中前3w是留给1-gram的,剩余7w个位置还有50w个词组没有位置安放,所以50w/7w约等于7,也就是说这50w个词组中大约有7个词对应同一个词向量。

              Fasttext另一篇文章中提到subword,主要是根据n-gram把词拆开进行预测

 

四、Experiments

       在文本分类任务上和tag预测任务上都取得了非常好的结果,效果和其它深度模型相差不多,但是速度上会快很多

五、Discussion and conclusion

       对论文进行一些总结

       关键点:

             基于深度学习的文本分类方法效果好,但速度比较慢;

             基于线性分类器的机器学习方法速度比较快,但是需要做更多的特征工程;

             提出Fasttext模型

        创新点:

             提出一种新的文本分类模型Fasttext;

             提出一些加快和使得文本分类效果更高的技巧-层次softmax和n-gram特征;

             在文本分类任务上和tag预测两个任务上都取得了又快又好的结果。

        启发点:

            虽然深度学习能够取得非常好的结果,但是在训练和测试的时候,非常慢限制了他们在大数据集上的应用(模型不一定在效果上大幅度提升,效果差不多,速度大幅度提升也是一种创新);

            然而线性分类器不同特征和类别之间不共享参数,可能限制了一些只有少量样本类别的泛化能力(共享词向量);

            大部分词向量方法对每个词分配一个独立的词向量,没有共享参数,特别是这些方法忽略之间的联系,而对于形态学丰富的语言更加重要。

 

六、代码实现

  1. # ****** 数据预处理 *****
  2. # 主要包括几个部分-数据集加载、读取标签和数据、创建word2id、将数据转化为id, 本次实验还是使用AG数据集合,数据集下载位置 AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
  3. # encoding = 'utf-8'
  4. from torch.utils import data
  5. import os
  6. import csv
  7. import nltk
  8. import numpy as np
  9. # 数据集加载
  10. f = open("./data/AG/train.csv")
  11. rows = csv.reader(f,delimiter=',',quotechar='"')
  12. rows = list(rows)
  13. rows[1:5]
  14. [['3',
  15. 'Carlyle Looks Toward Commercial Aerospace (Reuters)',
  16. 'Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
  17. ['3',
  18. "Oil and Economy Cloud Stocks' Outlook (Reuters)",
  19. 'Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.'],
  20. ['3',
  21. 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)',
  22. 'Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.'],
  23. ['3',
  24. 'Oil prices soar to all-time record, posing new menace to US economy (AFP)',
  25. 'AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.']]
  26. # 读取标签和数据
  27. n_gram,lowercase,label,datas = 2,True,[],[]
  28. for row in rows:
  29. label.append(int(row[0])-1)
  30. txt = " ".join(row[1:])
  31. if lowercase:
  32. txt = txt.lower()
  33. txt = nltk.word_tokenize(txt) #将句子转化为词
  34. new_txt = []
  35. for i in range(len(txt)):
  36. for j in range(n_gram):
  37. if j<=i:
  38. new_txt.append(" ".join(txt[i-j:i+1]))
  39. datas.append(new_txt)
  40. # word2id
  41. min_count,word_freq = 3,{}
  42. for data in datas:
  43. for word in data:
  44. if word not in word_freq:
  45. word_freq[word] = 1
  46. else:
  47. word_freq[word] += 1
  48. # 首先构建uni-gram,不需要hash
  49. word2id = {"<pad>":0,"<unk>":1}
  50. for word in word_freq:
  51. if word_freq[word] < min_count or " " in word:
  52. continue
  53. word2id[word] = len(word2id)
  54. uniwords_num = len(word2id)
  55. # 构建2-gram以上的词,需要hash
  56. for word in word_freq:
  57. if word_freq[word] < min_count or " " not in word:
  58. continue
  59. word2id[word] = len(word2id)
  60. # 将文本中的词都转化为id,设置句子长度为100,词表最大限制为1w
  61. max_length = 100
  62. for i,data in enumerate(datas):
  63. for j,word in enumerate(data):
  64. if " " not in word:
  65. datas[i][j] = word2id.get(word,1)
  66. else:
  67. datas[i][j] = word2id.get(word,1)%10000 + uniwords_num
  68. datas[i] = datas[i][0:max_length] + [0]*(max_length - len(datas[i]))

模型细节:

  1. # """ 模型代码 """
  2. # encoding='utf-8'
  3. import torch
  4. import torch.nn as nn
  5. import numpy as np
  6. class Fasttext(nn.Module):
  7. def __init__(self,vocab_size,embedding_size,max_length,label_num):
  8. super(Fasttext,self).__init__()
  9. self.embedding = nn.Embedding(vocab_size,embedding_size)
  10. self.avg_pool = nn.AvgPool1d(kernel_size=max_length,stride=1)
  11. self.fc = nn.Linear(embedding_size,label_num)
  12. def forward(self,x):
  13. x = x.long()
  14. out = self.embedding(x) # batch_size * length * embedding_size
  15. out = out.transpose(1,2).contiguous() # batch_size * embedding_size * length
  16. out = self.avg_pool(out).squeeze() # batch_size * embedding_size
  17. out = self.fc(out) # batch_size * label_num
  18. return out
  19. fasttext = Fasttext(vocab_size=1000,embedding_size=10,max_length=100,label_num=4)
  20. test = torch.zeros([64,100]).long()
  21. out = fasttext(test)
  22. """ 查看网络参数 """
  23. from torchsummary import summary
  24. summary(fasttext,input_size=(100,))
  1. """ 模型训练 """
  2. # encoding = 'utf-8'
  3. import torch
  4. import torch.autograd as autograd
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. from model import Fasttext
  8. from data import AG_Data
  9. import numpy as np
  10. from tqdm import tqdm
  11. import config as argumentparser
  12. config = argumentparser.ArgumentParser()
  13. """ 加载数据集 """
  14. training_set = AG_Data("/AG/train.csv",
  15. min_count = config.min_count,
  16. max_length=config.max_length,
  17. n_gram=config.n_gram)
  18. train_iter = torch.utils.data.DataLoader(dataset=training_set,
  19. batch_size=config.batch_size,
  20. shuffle=True,
  21. num_workers=0)
  22. test_set = AG_Data(data_path="/AG/test.csv",
  23. min_count=config.min_count,
  24. max_length = config.max_length,
  25. n_gram = config.n_gram,
  26. word2id = training_set.word2id,
  27. uniwords_num=training_set.uniwords_num)
  28. test_iter = torch.utils.data.DataLoader(dataset=test_set,
  29. batch_size=config.batch_size,
  30. shuffle = True,
  31. num_workers=0)
  32. """ 构建模型 """
  33. model = Fasttext(vocab_size=training_set.uniwords_num+100000,
  34. embedding_size=config.embed_size,
  35. max_length=config.max_length,
  36. label_num=config.label_num)
  37. if config.cuda and torch.cuda.is_available():
  38. model.cuda()
  39. criterion = nn.CrossEntropyLoss()
  40. optimizer = optim.Adam(model.parameters(),lr=config.learning_rate)
  41. loss = -1
  42. def get_test_result(data_iter,data_set):
  43. mode.eval()
  44. true_sample_num = 0
  45. for data,label in data_iter:
  46. if config.cuda and torch.cuda.is_available():
  47. data = data.cuda()
  48. label = label.cuda()
  49. else:
  50. data = torch.autograd.Variable(data).long()
  51. out = model(data)
  52. true_sample_num += np.sum((torch.argmax(out,1)==label.long()).cpu().numpy())
  53. acc = true_sample_num/data_set.__len__()
  54. return acc
  55. for epoch in range(1):
  56. model.train()
  57. process_bar = tqdm(train_iter)
  58. for data,label in process_bar:
  59. if config.cuda and torch.cuda.is_available():
  60. data = data.cuda()
  61. label = label.cuda()
  62. else:
  63. data = torch.autograd.Variable(data).long()
  64. label = torch.autograd.Variable(label).squeeze()
  65. out = model(data)
  66. loss_now = criterion(out,autograd.Variable(label.long()))
  67. if loss == -1:
  68. loss = loss_now.data.item()
  69. else:
  70. loss = 0.95*loss + 0.05*loss_now.data.item()
  71. process_bar.set_postfix(loss=loss_now.data.item())
  72. process_bar.update()
  73. optimizer.zero_grad()
  74. loss_now.backward()
  75. optimizer.step()
  76. test_acc = get_test_result(test_iter,test_set)
  77. print("The test acc is: %.5f" % test_acc)

完整代码,详见git:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/350908
推荐阅读
相关标签
  

闽ICP备14008679号