当前位置:   article > 正文

fasttext 文本分类_深度学习系列––fasttext模型&帮助文档

fasttext classifier.test

e2d574496691978a40b6665e06590972.png

摘要

接着NLP/CV/领域的常见数据集介绍[1] ,本文开始介绍一个轻便建议而效果有很不错的模型——fasttext模型[2]。fastText是一种Facebook AI Research在16年开源的模型特点就是轻巧快捷其效果基本与textCNN相当甚至略好,而其训练速度会比其他模型的时耗节约数倍甚至数十倍,是一个非常值得推荐在工业简单场景应用的模型。

1 fasttext 模型原理

Joulin, A., Grave, E., Bojanowski, P., & Mikolov, T. (2016)[2]提出fasttext模型。fastText 模型架构和 Word2Vec 中的 CBOW 模型很类似,不同之处在于,fastText 预测标签,而 CBOW 模型预测中间词。一般情况下,使用fastText进行文本分类的同时也会产生词的embedding,即embedding是fastText分类的产物。

如图1所示,fastText的模型也是三层架构:输入层、 隐藏层、输出层(Hierarchical Softmax)。fastText的输入是多个单词及其n-gram特征,这些特征用来表示单个文档,将整个文本作为特征去预测文本对应的类别。

个人认为速度很快的原因有以下几个方面:

(1)模型总体只有三层,结构简单;

(2)文本表示的向量简单相加平均;

(3)在输出时,fastText采用了分层Softmax,大大降低了模型训练时间;

(4)我们直接调用的fastTextFacebook 2016年开源的一个词向量计算以及文本分类的工具,该项目是 C++ 写的。

1ca7d9be69ee475db65b0dc05aa4d412.png
图 1

2 fasttext 应用

fasttext是 Facebook fastText的python接口,fasttext包主要用途有两个:词向量表示、文本分类,对应使用数据结构 如:

财经:资管、私募、信托,傻傻分不清楚!__label__证券

实际上根据论文中给出的数据,fasttext模型在速度(图2) 和效果(图3)上都非常不错[2],尤其速度和模型大小是线上平响要求高的任务之所爱,在简易任务和获取embedding用于文本embedding初始化的话是值得推荐的。

c074a7789b0a2614edea532e539e7187.png

6c9762b1a75fab9a95640f74999d0aa4.png

(1)词向量表示

  1. import fasttext
  2. # skipgram model
  3. model = fasttext.skipgram('data.txt','model')
  4. print model.words # 输出为一个词向量字典
  5. # cbow model
  6. model = fasttext.cbow('data.txt', 'model')
  7. print model.words #
  8. 其中data.txt是一个utf-8编码文件,默认的n-grams范围:3-6
  9. 程序输出为两个文件:model.bin and model.vec
  10. model.vec 是一个每行为一个词向量的文本文件,model.bin是一个包含词典模型和所有超参数的二进制文件
  11. 获取OOV词的词向量表示
  12. print model['king'] # 获得单词king的词向量表示
  13. model.bin可以使用如下方式重建模型:
  14. model = fasttext.load_model('model.bin')
  15. print model.words # list of words in dictionary
  16. print model['king'] # get the vector of the word 'king'

(2) 文本分类

  1. 使用方式为:
  2. classifier = fasttext.supervised('data.train.txt', 'model', label_prefix='__label__') # 原作者使用模型方法
  3. 其中data.train.txt是每行包含标签的文本文件,默认标签前缀为__label__
  4. 模型建立好以后,可以用来检测测试集上的准确度:
  5. result = classifier.test('test.txt')
  6. print 'P@1:', result.precision # 准确率
  7. print 'R@1:', result.recall # 召回率
  8. print 'Number of examples:', result.nexamples # 测试样本数量
  9. 也可以使用训练好的模型进行预测:
  10. texts = ['example very long text 1', 'example very longtext 2']
  11. labels = classifier.predict(texts)
  12. print labels # 返回为一个二元数组[[labels]]
  13. # 或者同时包含概率值
  14. labels = classifier.predict_proba(texts)
  15. print labels
  16. 也可以返回最可能的k个标签值:
  17. labels = classifier.predict(texts, k=3)
  18. print labels
  19. # 同时包含概率值
  20. labels = classifier.predict_proba(texts, k=3)
  21. print labels
  22. ################################################################

3 fasttext python使用帮助文档

fasttext 有基于skipgram和cbow两种方式训练方式,以下给出一些常用的各个模型参数帮助文档:

(1)skipgram(params)

  1. input_file training file path (required) # 训练文件路径
  2. output output file path (required) # 输出文件路径
  3. lr learning rate [0.05] # 学习率
  4. lr_update_rate change the rate of updates for the learning rate [100] # 学习率的更新速度
  5. dim size of word vectors [100] # 词向量维度
  6. ws size of the context window [5] # 窗口宽度大小
  7. epoch number of epochs [5] # 迭代次数
  8. min_count minimal number of word occurences [5] # 最小词频数
  9. neg number of negatives sampled [5] # 负样本个数
  10. word_ngrams max length of word ngram [1] # 词ngram的最大长度
  11. loss loss function {ns, hs, softmax} [ns] # 损失函数
  12. bucket number of buckets [2000000] #
  13. minn min length of char ngram [3] # 字符ngram的最小长度
  14. maxn max length of char ngram [6] # 字符ngram的最大长度
  15. thread number of threads [12] # 线程数
  16. t sampling threshold [0.0001] #
  17. silent disable the log output from the C++ extension [1]
  18. encoding specify input_file encoding [utf-8] # 输入文件格式
  19. 示例说明:model = fasttext.skipgram('train.txt', 'model', lr=0.1, dim=300)

(2)cbow(params)

  1. input_file training file path (required)
  2. output output file path (required)
  3. lr learning rate [0.05]
  4. lr_update_rate change the rate of updates for the learning rate [100]
  5. dim size of word vectors [100]
  6. ws size of the context window [5]
  7. epoch number of epochs [5]
  8. min_count minimal number of word occurences [5]
  9. neg number of negatives sampled [5]
  10. word_ngrams max length of word ngram [1]
  11. loss loss function {ns, hs, softmax} [ns]
  12. bucket number of buckets [2000000]
  13. minn min length of char ngram [3]
  14. maxn max length of char ngram [6]
  15. thread number of threads [12]
  16. t sampling threshold [0.0001]
  17. silent disable the log output from the C++ extension [1]
  18. encoding specify input_file encoding [utf-8]
  19. 示例说明:model = fasttext.cbow('train.txt', 'model', lr=0.1, dim=300)

(3)skipgram和cbow模型的返回值字段参数

  1. model.model_name # Model name 模型名称
  2. model.words # List of words in the dictionary 词典单词向量列表
  3. model.dim # Size of word vector 词向量维度
  4. model.ws # Size of context window 内容窗口大小
  5. model.epoch # Number of epochs 迭代训练次数
  6. model.min_count # Minimal number of word occurences
  7. model.neg # Number of negative sampled 负样本个数
  8. model.word_ngrams # Max length of word ngram 词ngram的最大长度
  9. model.loss_name # Loss function name 损失函数名称
  10. model.bucket # Number of buckets
  11. model.minn # Min length of char ngram 字符ngram的最小长度
  12. model.maxn # Max length of char ngram 字符ngram的最大长度
  13. model.lr_update_rate # Rate of updates for the learning rate 学习率更新速度
  14. model.t # Value of sampling threshold 样本门限值
  15. model.encoding # Encoding of the model 模型编码
  16. model[word] # Get the vector of specified word 返回给定词的预测词向量

(4)supervised(params)

  1. input_file training file path (required) # 训练文件路径
  2. output output file path (required) # 输出文件路径
  3. label_prefix label prefix ['__label__'] # 标签前缀
  4. lr learning rate [0.1] # 学习率
  5. lr_update_rate change the rate of updates for the learning rate [100] # 学习率的更新速度
  6. dim size of word vectors [100] # 词向量维度
  7. ws size of the context window [5] # 内容窗口大小
  8. epoch number of epochs [5] # 迭代次数
  9. min_count minimal number of word occurences [1] 最小词频数
  10. neg number of negatives sampled [5] # 负样本个数
  11. word_ngrams max length of word ngram [1] # 词ngram的最大长度
  12. loss loss function {ns, hs, softmax} [softmax] # 损失函数
  13. bucket number of buckets [0]
  14. minn min length of char ngram [0] # 字符ngram的最小长度
  15. maxn max length of char ngram [0] # 字符ngram的最大长度
  16. thread number of threads [12]
  17. t sampling threshold [0.0001] #
  18. silent disable the log output from the C++ extension [1]
  19. encoding specify input_file encoding [utf-8] # 默认编码
  20. pretrained_vectors pretrained word vectors (.vec file) for supervised learning [] # 是否保持词向量输出文件model.vec,默认不保持
  21. 示例说明:classifier = fasttext.supervised('train.txt', 'model', label_prefix='__myprefix__', thread=4)

(5)supervised模型返回值字段参数

  1. classifier.labels # List of labels 标签列表
  2. classifier.label_prefix # Prefix of the label 标签前缀
  3. classifier.dim # Size of word vector 词向量维度
  4. classifier.ws # Size of context window 内容窗口大小
  5. classifier.epoch # Number of epochs 迭代次数
  6. classifier.min_count # Minimal number of word occurences
  7. classifier.neg # Number of negative sampled 负样本个数
  8. classifier.word_ngrams # Max length of word ngram 词ngram的最大长度
  9. classifier.loss_name # Loss function name 损失函数名称
  10. classifier.bucket # Number of buckets
  11. classifier.minn # Min length of char ngram 字符ngram的最小长度
  12. classifier.maxn # Max length of char ngram 字符ngram的最大长度
  13. classifier.lr_update_rate # Rate of updates for the learning rate 学习率的更新速度
  14. classifier.t # Value of sampling threshold
  15. classifier.encoding # Encoding that used by classifier 分类器使用编码
  16. classifier.test(filename, k) # Test the classifier 用分类器进行测试
  17. classifier.predict(texts, k) # Predict the most likely label 使用分类器进行文本预测
  18. classifier.predict_proba(texts, k) # Predict the most likely label include their probability 使用分类器进行文本预测类别并且返回他们的概率值

参考文献

[1] https://zhuanlan.zhihu.com/p/89520802:

debuluoyi:深度学习系列––NLP/CV常见数据集整理​zhuanlan.zhihu.com
af560b505a05c080ba131a8c0ab8284d.png

[2] Joulin, A., Grave, E., Bojanowski, P., & Mikolov, T. (2016). Bag of tricks for efficient text classification.arXiv preprint arXiv:1607.01759.

[3] https://github.com/debuluoyi:

debuluoyi - Overview​github.com
4cd52d4ec8143bdfc6a78ee704d2d374.png
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/448109
推荐阅读
相关标签
  

闽ICP备14008679号