赞
踩
文本分类是指将一篇文章归到事先定义好的某一类或者某几类,在数据平台的一个典型的应用场景是,通过爬取用户浏览过的页面内容,识别出用户的浏览偏好,从而丰富该用户的画像。
本文介绍使用Spark MLlib提供的朴素贝叶斯(Naive Bayes)算法,完成对中文文本的分类过程。主要包括中文分词、文本表示(TF-IDF)、模型训练、分类预测等。
对于中文文本分类,需要先对内容进行分词,我使用的是ansj中文分析工具,其中自己可以配置扩展词库来使分词结果更合理,同时可以加一些停用词可以提高准确率,需要把数据样本分割成两批数据,一份用于训练模型,一份用于测试模型效果。
目录结构
DataFactory.java
package com.maweiming.spark.mllib.classifier; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.maweiming.spark.mllib.utils.AnsjUtils; import com.maweiming.spark.mllib.utils.FileUtils; import org.apache.commons.lang3.StringUtils; import java.io.File; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 1、first step * data format * Created by Coder-Ma on 2017/6/12. */ public class DataFactory { public static final String CLASS_PATH = "/Users/coderma/coders/github/SparkTextClassifier/src/main/resources"; public static final String STOP_WORD_PATH = CLASS_PATH + "/data/stopWord.txt"; public static final String NEWS_DATA_PATH = CLASS_PATH + "/data/NewsData"; public static final String DATA_TRAIN_PATH = CLASS_PATH + "/data/data-train.txt"; public static final String DATA_TEST_PATH = CLASS_PATH + "/data/data-test.txt"; public static final String MODELS = CLASS_PATH + "/models"; public static final String MODEL_PATH = CLASS_PATH + "/models/category-4"; public static final String LABEL_PATH = CLASS_PATH + "/models/labels.txt"; public static final String TF_PATH = CLASS_PATH + "/models/tf"; public static final String IDF_PATH = CLASS_PATH + "/models/idf"; public static void main(String[] args) { /** * 收集数据、特征工程 * 1、遍历数据样本目录 * 2、对数据进行清洗,剔除掉停用词 */ //数据样本切割比例 80%用于训练样本,20%数据用于测试模型准确率 Double spiltRate = 0.8; //停用词 List<String> stopWords = FileUtils.readLine(line -> line, STOP_WORD_PATH); //分类标签(标签id,分类名) Map<Integer, String> labels = new HashMap<>(); Integer dirIndex = 0; String[] dirNames = new File(NEWS_DATA_PATH).list(); for (String dirName : dirNames) { dirIndex++; labels.put(dirIndex, dirName); String fileDirPath = String.format("%s/%s", NEWS_DATA_PATH, dirName); String[] fileNames = new File(fileDirPath).list(); //当前分类目录的样本总数 * 切割比率 int spilt = Double.valueOf(fileNames.length * spiltRate).intValue(); for (int i = 0; i < fileNames.length; i++) { String fileName = fileNames[i]; String filePath = String.format("%s/%s", fileDirPath, fileName); System.out.println(filePath); String text = FileUtils.readFile(filePath); for (String stopWord : stopWords) { text = text.replaceAll(stopWord, ""); } if (StringUtils.isBlank(text)) { continue; } //把文本内容进行分词 List<String> wordList = AnsjUtils.participle(text); JSONObject data = new JSONObject(); data.put("text", wordList); data.put("category", Double.valueOf(dirIndex)); if (i > spilt) { //测试数据 FileUtils.appendText(DATA_TEST_PATH, data.toJSONString() + "\n"); } else { //训练数据 FileUtils.appendText(DATA_TRAIN_PATH, data.toJSONString() + "\n"); } } } FileUtils.writer(LABEL_PATH, JSON.toJSONString(labels));//data labels System.out.println("Data processing successfully !"); System.out.println("======================================================="); System.out.println("trainData:" + DATA_TRAIN_PATH); System.out.println("testData:" + DATA_TEST_PATH); System.out.println("labes:" + LABEL_PATH); System.out.println("======================================================="); } }
分好词后,每一个词都作为一个特征,需要将中文词语转换成Double型来表示,通常使用该词语的TF-IDF值作为特征值,Spark提供了全面的特征抽取及转换的API,非常方便,详见http://spark.apache.org/docs/latest/ml-features.html
为原始属于设置标签,按照resource->NewsData目录下面文件夹索引区分。
这里将中文词语转换成INT型的Hashing算法,类似于Bloomfilter,下面的setNumFeatures(500000)表示将Hash分桶的数量设置为500000个,这个值默认为2的20次方,即1048576,可以根据你的词语数量来调整,一般来说,这个值越大,不同的词被计算为一个Hash值的概率就越小,数据也更准确,但需要消耗更大的内存,和Bloomfilter是一个道理。
然后就可以训练模型,下面代码
package com.maweiming.spark.mllib.classifier; import com.maweiming.spark.mllib.utils.FileUtils; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.IDF; import org.apache.spark.ml.feature.IDFModel; import org.apache.spark.ml.linalg.SparseVector; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import java.io.File; import java.io.IOException; /** * 2、The second step * Created by Coder-Ma on 2017/6/26. */ public class NaiveBayesTrain { public static void main(String[] args) throws IOException { //1、创建一个SparkSession SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local") .getOrCreate(); //2、加载训练数据样本 Dataset<Row> train = spark.read().json(DataFactory.DATA_TRAIN_PATH); //3、通过tf-idf计算数据样本中的词频 //word frequency count HashingTF hashingTF = new HashingTF().setNumFeatures(500000).setInputCol("text").setOutputCol("rawFeatures"); Dataset<Row> featurizedData = hashingTF.transform(train); //count tf-idf IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); Dataset<Row> rescaledData = idfModel.transform(featurizedData); //4、把数据样本转换成向量 JavaRDD<LabeledPoint> trainDataRdd = rescaledData.select("category", "features").javaRDD().map(v1 -> { Double category = v1.getAs("category"); SparseVector features = v1.getAs("features"); Vector featuresVector = Vectors.dense(features.toArray()); return new LabeledPoint(Double.valueOf(category),featuresVector); }); System.out.println("Start training..."); //调用朴素贝叶斯算法,传入向量数据训练模型 NaiveBayesModel model = NaiveBayes.train(trainDataRdd.rdd()); //save model model.save(spark.sparkContext(), DataFactory.MODEL_PATH); //save tf hashingTF.save(DataFactory.TF_PATH); //save idf idfModel.save(DataFactory.IDF_PATH); System.out.println("train successfully !"); System.out.println("======================================================="); System.out.println("modelPath:"+DataFactory.MODEL_PATH); System.out.println("tfPath:"+DataFactory.TF_PATH); System.out.println("idfPath:"+DataFactory.IDF_PATH); System.outprintln("======================================================="); } }
训练模型完成
train successfully !
=======================================================
modelPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/category-4
tfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/tf
idfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/idf
=======================================================
package com.maweiming.spark.mllib.classifier; import com.alibaba.fastjson.JSON; import com.maweiming.spark.mllib.dto.Result; import com.maweiming.spark.mllib.utils.AnsjUtils; import com.maweiming.spark.mllib.utils.FileUtils; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.IDFModel; import org.apache.spark.ml.linalg.SparseVector; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; import java.io.File; import java.text.DecimalFormat; import java.util.*; /** * 3、the third step * Created by Coder-Ma on 2017/6/26. */ public class NaiveBayesTest { private static HashingTF hashingTF; private static IDFModel idfModel; private static NaiveBayesModel model; private static Map<Integer,String> labels = new HashMap<>(); public static void main(String[] args) { SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local") .getOrCreate(); //load tf file hashingTF = HashingTF.load(DataFactory.TF_PATH); //load idf file idfModel = IDFModel.load(DataFactory.IDF_PATH); //load model model = NaiveBayesModel.load(spark.sparkContext(), DataFactory.MODEL_PATH); //batch test batchTestModel(spark, DataFactory.DATA_TEST_PATH); //test a single testModel(spark,"最近这段时间,由于印度三哥可能有些膨胀,在边境地区总想“搞事情”,这也让不少人的目光集中到此。事实上,我国在与印度的交界处有一军事要地,只要解放军一抬高水位,那么印军或就“不战而退”。它就是地处我国西藏与印度控制克什米尔交界的班公湖。\n" + "\n" + "\n" + "众所周知,从古至今那些地处与军事险要易守难攻的形胜之地,都具有非常重要的军事意义。经常能左右一场战争的胜负。据悉,班公湖位于西藏自治区阿里地区日土县城西北。全长有600多公里,其中地处中国的有400多公里,地处与印度约有200公里。整体成东西走向,海拔在4000多米以上。湖水整体为淡水湖,但由于湖水在西段的淡水补给量的大方面建少,东西方向上交替不通畅,使西部的区域变成了咸水湖。于是便出现了一个有趣的现象,在东部的中国境内班公湖为淡水湖,在西部的印度境内,班公湖为咸水湖。\n" + "\n" + "\n" + "而我军在于印度交界的班公湖区域有一个阀门,这个区域有着非常大的军事作用,而如果印军将部队部署在班公湖地区,我军只需打开阀门,抬高班公湖的东部水位。将他们的军事设施和军用要道给全部淹没。而印军的军事物资和后勤保障都将全部瘫痪,到时印度的军事部署都将全部不攻自破。\n" + "\n" + "\n" + "而印度应该知道现代战争最为重要的便是后勤制度的保障,军事行动能否取得胜利,很大程度取决于后勤能否及时的跟上。而我军在班公湖地区地势上就有了绝对的军事优势,军用物资也可源源不断的运输上来,而印度却优势全无。而我国自古以来就是爱好和平的国家,人不犯我我不犯人。只希望印军能认清与我国军事力量的差距,不要盲目自信。\n" + "\n"); } public static void batchTestModel(SparkSession sparkSession, String testPath) { Dataset<Row> test = sparkSession.read().json(testPath); //word frequency count Dataset<Row> featurizedData = hashingTF.transform(test); //count tf-idf Dataset<Row> rescaledData = idfModel.transform(featurizedData); List<Row> rowList = rescaledData.select("category", "features").javaRDD().collect(); List<Result> dataResults = new ArrayList<>(); for (Row row : rowList) { Double category = row.getAs("category"); SparseVector sparseVector = row.getAs("features"); Vector features = Vectors.dense(sparseVector.toArray()); double predict = model.predict(features); dataResults.add(new Result(category, predict)); } Integer successNum = 0; Integer errorNum = 0; for (Result result : dataResults) { if (result.isCorrect()) { successNum++; } else { errorNum++; } } DecimalFormat df = new DecimalFormat("######0.0000"); Double result = (Double.valueOf(successNum) / Double.valueOf(dataResults.size())) * 100; System.out.println("batch test"); System.out.println("======================================================="); System.out.println("Summary"); System.out.println("-------------------------------------------------------"); System.out.println(String.format("Correctly Classified Instances : %s\t %s%%",successNum,df.format(result))); System.out.println(String.format("Incorrectly Classified Instances : %s\t %s%%",errorNum,df.format(100-result))); System.out.println(String.format("Total Classified Instances : %s",dataResults.size())); System.out.println("==================================="); } public static void testModel(SparkSession sparkSession, String content){ List<Row> data = Arrays.asList( RowFactory.create(AnsjUtils.participle(content)) ); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, false), false, Metadata.empty()) }); Dataset<Row> testData = sparkSession.createDataFrame(data, schema); //word frequency count Dataset<Row> transform = hashingTF.transform(testData); //count tf-idf Dataset<Row> rescaledData = idfModel.transform(transform); Row row =rescaledData.select("features").first(); SparseVector sparseVector = row.getAs("features"); Vector features = Vectors.dense(sparseVector.toArray()); Double predict = model.predict(features); System.out.println("test a single"); System.out.println("======================================================="); System.out.println("Result"); System.out.println("-------------------------------------------------------"); System.out.println(labels.get(predict.intValue())); System.out.println("==================================="); } }
测试结果
batch test
=======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances : 785 98.6181%
Incorrectly Classified Instances : 11 1.3819%
Total Classified Instances : 796
===================================
准确率98%,还可以。以上就是文本分类器的实现,我们还可以直接把数据样本换成 正常邮件|垃圾邮件 这两类的数据,就可以实现一个垃圾邮箱分类器了
https://github.com/Maweiming/SparkTextClassifier
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。