赞
踩
fastText是Facebook Research在2016年开源的一个词向量及文本分类工具,今天这篇文章主要使用fasttext在来做文本分类,测试fasttext用于分类的实际效果。
本文所使用的数据及代码均已上传至GitHub
传送门: fasttext_classify
由于数据集太大了,无法上传至GitHub,数据集链接:fasttext分类数据集
百度云:链接
提取码:96fu
在windows上安装fasttext得去https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext下载对应python版本的whl,然后在命令行使用pip install xxx.whl
安装。
fasttext要求的数据格式有点奇怪,需要将label处理成__label__
的格式,假如你原本的label为0
,则需要将其处理成__label__0
。
格式如下:
“并且 在 世界 范围 内 广为流传” + \t + __label__0
这样也可以:
__label__0 + \t + “并且 在 世界 范围 内 广为流传”
不知道 \t
换成 ,
行不行,感兴趣的同学可以试下
下面进入正题:
首先我们观察下数据:
import pandas as pd
import numpy as np
data = pd.read_csv('./data/train.csv', sep='\t')
list(np.unique(data['label']))
总共35个类别
[' 文化', '中小学教辅','传记', '健身与保健', '农业/林业', '动漫', '励志与成功', '医学', '历史', '哲学/宗教', '国学/古籍','外语学习', '大中专教材教辅',
'婚恋与两性', '孕产/胎教', '小说', '工业技术', '建筑', '政治/军事', '文学', '旅游/地图', '法律', '烹饪/美食', '社会科学',
'科学与自然', '科普读物', '童书', '管理', '经济', '考试','育儿/家教', '艺术', '计算机与互联网', '金融与投资', '青春文学']
统计下各标签的数据量:
from collections import Counter
Counter(data['label'])
Counter(data['label']).most_common(5)
Counter({'文学': 13469, '童书': 5996, '大中专教材教辅': 5396, '工业技术': 3292, '中小学教辅': 2603, '
艺术': 2397, '社会科学': 2317, '小说': 2191, '计算机与互联网': 2054, '管理': 1852, '建筑': 1788, '外语
学习': 1494, '历史': 1455, '科学与自然': 1421, '法律': 1256, '政治/军事': 1210, '哲学/宗教': 1012, '医
学': 998, '经济': 938, '励志与成功': 921, '考试': 869, '传记': 761, '青春文学': 746, ' 文化': 707, '农
业/林业': 567, '动漫': 442, '育儿/家教': 390, '烹饪/美食': 375, '国学/古籍': 357, '旅游/地图': 354, '
健身与保健': 348, '科普读物': 329, '孕产/胎教': 301, '金融与投资': 186, '婚恋与两性': 63})
[('文学', 49868),
('童书', 18926),
('工业技术', 15714),
('大中专教材教辅', 12229),
('艺术', 10104)]
看来标签不平衡问题挺严重的。
由于时间问题,我们就不去做数据不平衡的相关处理了,感兴趣的同学可以去了解一下过采样、欠采样
这里我们直接挑选几个类别出来尝试下,可以看出标签数量前五的类别中,童书、工业技术、大中专教材教辅
三个类别的数据量相差不是很大,于是我们挑选出这三个类别来训练我们的三分类模型。
数据处理代码:
def extract_three_cls_data(data_path,save_path, txt_save_path): map_path = './base_fasttext/data/three_class/map.json' data = pd.read_csv(data_path, sep='\t') cls_data = data[(data['label'] == '童书') | (data['label'] == '工业技术') | (data['label'] == '大中专教材教辅')] cls_data.index = range(len(cls_data)) print(Counter(cls_data['label'])) print('总共 {} 个类别'.format(len(np.unique(cls_data['label'])))) label_map = {key:index for index, key in enumerate(np.unique(cls_data['label']))} label_map_json = json.dumps(label_map, ensure_ascii=False, indent=3) if not os.path.exists(label_map_json): with open(map_path, 'w', encoding='utf-8') as f: f.write(label_map_json) cls_data['fasttext_label'] = cls_data['label'].map(label_map) for i in range(len(cls_data['fasttext_label'])): cls_data['fasttext_label'][i] = '__label__{}'.format(cls_data['fasttext_label'][i]) print(len(cls_data)) with open('./data/stopwords.txt', 'r', encoding='utf-8') as f: stopwords = f.readlines() stopwords = [i.strip() for i in stopwords] cls_data.to_csv(save_path, index=False) with open(txt_save_path, 'a+', encoding='utf-8') as f: for idx,row in tqdm(cls_data.iterrows(), desc='去除停用词:', total=len(cls_data)): words = row['text'].split(' ') out_str = '' for word in words: if word not in stopwords: out_str += word out_str += ' ' row['text'] = out_str.encode('utf-8') line = str(row['text']) + '\t' + row['fasttext_label'] + '\n' f.write(line)
记得要做下停用词过滤,实验发现过滤停用词可以将准确率提高1%
左右
注意一下这一行row['text'] = out_str.encode('utf-8')
,在调试代码的过程中我发现,不加encode('utf-8')
,生成的txt
和len(data)
不一致,但训练出来的结果是一样的,暂时没找到啥原因,加入之后就一样了。记得做predict
的时候也需要对输入的string
做下encode('utf-8')
转换。
生成的txt格式:
b'\xe5\xa6\x88\xe5\xa6\x88 \xe6\xb2\xa1 \xe6\x83\xb3 \xe8\xbd\xaf\xe5\xbc\xb1 \xe5\x81\x9a \xe6\x9c\x80 \xe4\xbc\x98\xe7\xa7\x80 1 \xe4\xb8\xbb\xe9\xa2\x98\xe9\xb2\x9c\xe6\x98\x8e \xe7\xa7\xaf\xe6\x9e\x81\xe5\x90\x91\xe4\xb8\x8a \xe5\x85\x85\xe6\xbb\xa1 \xe6\xad\xa3 \xe8\x83\xbd\xe9\x87\x8f 2 \xe5\x85\xa8\xe5\xbd\xa9 \xe6\x8f\x92\xe5\x9b\xbe \xe7\xb2\xbe\xe7\xbe\x8e \xe6\x89\x8b\xe7\xbb\x98 \xe7\x8e\xaf\xe4\xbf\x9d \xe6\xb2\xb9\xe5\xa2\xa8 \xe5\x8d\xb0\xe5\x88\xb7 3 \xe5\x9f\xb9\xe5\x85\xbb \xe5\xad\xa9\xe5\xad\x90 \xe5\x9d\x9a\xe5\xbc\xba \xe6\x80\xa7\xe6\xa0\xbc \xe9\x94\xbb\xe7\x82\xbc \xe5\xad\xa9\xe5\xad\x90 \xe7\x8b\xac\xe7\xab\x8b \xe5\x93\x81\xe6\xa0\xbc 4 \xe6\x95\x99\xe4\xbc\x9a \xe5\xad\xa9\xe5\xad\x90 \xe8\xae\xa4\xe8\xaf\x86 \xe6\xbd\x9c\xe8\x83\xbd \xe6\xa0\x91\xe7\xab\x8b \xe5\xbc\xba\xe5\xa4\xa7 \xe8\x87\xaa\xe4\xbf\xa1\xe5\xbf\x83 ' __label__2
b'\xe6\x9c\xba\xe6\xa2\xb0\xe5\x88\xb6\xe9\x80\xa0 \xe5\xb7\xa5\xe8\x89\xba\xe5\xad\xa6 \xe6\x95\x99\xe6\x9d\x90 \xe7\xbc\x96\xe5\x86\x99 \xe8\xbf\x87\xe7\xa8\x8b \xe4\xb8\xad \xe5\x85\xb7\xe6\x9c\x89 \xe4\xbb\xa5\xe4\xb8\x8b \xe7\x89\xb9\xe8\x89\xb2 1 \xe8\xaf\xb7 \xe7\x90\x86\xe8\xae\xba \xe9\x87\x8d \xe5\xae\x9e\xe8\xb7\xb5 2 \xe4\xbc\x81\xe4\xb8\x9a \xe7\xae\xa1\xe7\x90\x86\xe4\xba\xba\xe5\x91\x98 \xe5\x90\x88\xe4\xbd\x9c \xe7\xbc\x96\xe5\x86\x99\xe6\x95\x99\xe6\x9d\x90 \xe7\xaa\x81\xe5\x87\xba \xe5\xb7\xa5\xe7\xa8\x8b \xe5\xae\x9e\xe4\xbe\x8b \xe5\x88\x86\xe6\x9e\x90 \xe8\xae\xb2\xe8\xa7\xa3 3 \xe8\xb4\xaf\xe5\xbd\xbb \xe5\x90\x8d\xe7\xa7\xb0 \xe6\x9c\xaf\xe8\xaf\xad \xe4\xbb\xa3\xe5\x8f\xb7 \xe9\x87\x8f \xe5\x8d\x95\xe4\xbd\x8d \xe7\x8e\xb0\xe8\xa1\x8c \xe5\x9b\xbd\xe5\xae\xb6\xe6\xa0\x87\xe5\x87\x86 ' __label__1
训练代码还是比较简单的,直接将处理好的数据作为输入,再设置下参数,就可以了。
def train_three_class(): train_data_path = './data/train.csv' train_csv_path = './base_fasttext/data/three_class/train.csv' train_txt_path = './base_fasttext/data/three_class/train.txt' if not os.path.exists(train_txt_path): extract_three_cls_data(train_data_path, train_csv_path, train_txt_path) test_data_path = './data/test.csv' test_csv_path = './base_fasttext/data/three_class/test.csv' test_txt_path = './base_fasttext/data/three_class/test.txt' if not os.path.exists(test_txt_path): extract_three_cls_data(test_data_path, test_csv_path, test_txt_path) dev_data_path = './data/dev.csv' dev_csv_path = './base_fasttext/data/three_class/dev.csv' dev_txt_path = './base_fasttext/data/three_class/dev.txt' if not os.path.exists(dev_txt_path): extract_three_cls_data(dev_data_path, dev_csv_path, dev_txt_path) # classifier = fasttext.train_supervised(input= train_txt_path, autotuneValidationFile = dev_txt_path) model_path = './base_fasttext/model/fasttext_three_class.pkl' if not os.path.exists(model_path): classifier = fasttext.train_supervised(train_txt_path, label="__label__", dim=100, epoch=10, lr=0.1, wordNgrams=3, loss='softmax', thread=8, verbose=True, minCount = 5) classifier.save_model(model_path) result = classifier.test(test_txt_path) print('F1 Score: {}'.format(result[1] * result[2] * 2 / (result[2] + result[1]))) else: classifier = fasttext.load_model(model_path) # result = classifier.test(test_txt_path) # print('F1 Score: {}'.format(result[1] * result[2] * 2 / (result[2] + result[1]))) return classifier
得益于分层Softmax
,训练过程非常快。
F1 Score: 0.9315296251511487
还是相当不错的,拿个例子来测试一下:
three_classifier = train_three_class()
three_classifier_map_path = './base_fasttext/data/three_class/map.json'
with open(three_classifier_map_path, 'r', encoding='utf-8') as f:
three_classifier_map = json.load(f)
true_class = '工业技术'
test_data = '通信 原理 - ( 第 3 版 ) 本书 系统地 介绍 通信 的 基本概念 、 基本 理论 和 基本 分析方法 。 在 保持 一定 理论 深度 的 基础 上 , 本书 尽可能 简化 数学分析 过程 , 突出 对 概念 、 新 技术 的 介绍 ; 叙述 上 力求 概念 清楚 、 重点 突出 、 深入浅出 、 通俗易懂 ; 内容 上 力求 科学性 、 先进性 、 系统性 与 实用性 的 统一 。 本书 共 10 章 , 内容 包括 : 绪论 、 信号 与 噪声 分析 、 模拟 调制 系统 、 模拟信号 的 数字传输 、 数字信号 的 基带 传输 、 数字信号 的 载波 传输 、 现代 数字 调制 技术 、 信道 、 信道编码 和 扩频通信 。 内容 涵盖 国内 通信 原理 教学 的'.encode('utf-8')
result = three_classifier.predict(str(test_data))[0][0]
predicT_class = list(three_classifier_map.keys())[list(three_classifier_map.values()).index(int(result[-1]))]
print('预测的类别为:{}'.format(predicT_class))
print('真实的类别为:{}'.format(true_class))
记得测试用例要用encode('utf-8')
转换一下
预测的类别为:工业技术
真实的类别为:工业技术
预测正确。
关于调参: 也可以使用fasttext的自动寻参来训练,但是太慢了,五六分钟还没搞定,于是我放弃了。
classifier = fasttext.train_supervised(input= train_txt_path, autotuneValidationFile = dev_txt_path)
三分类的准确率达到了93%
,效果相当不错,那么在35个类别上的效果怎么样呢?
于是,我用所有数据测试了一下:
F1 Score: 0.755784146181149
预测的类别为:励志与成功
真实的类别为:工业技术
35
个类别有75%
的准确率,整体效果还不错。
由于存在严重的数据不平衡问题,在单一类别的准确率应该翻车了,这里就不再测试了。感兴趣的同学可以自己测试后在评论区留言。
1、fasttext在文本分类任务上效果确实很不错
2、fasttext采用层次化 softmax
使其训练速度非常快
本文所有代码及数据Github链接:fasttext_classify
相关文章:TextCNN文本分类Pytorch
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。