Text Classification with Java Spark ML


Text classification means assigning a document to one or more predefined categories. A typical application on a data platform is to crawl the pages a user has browsed, identify the user's reading preferences, and use them to enrich that user's profile.
This article walks through classifying Chinese text with the Naive Bayes algorithm provided by Spark MLlib. The main steps are Chinese word segmentation, text representation (TF-IDF), model training, and prediction.

Feature Engineering

Text Processing

For Chinese text classification, the content must first be segmented into words. I use the ansj Chinese segmentation tool, which lets you configure an extended dictionary to make the segmentation more reasonable; adding a stop-word list also improves accuracy. The samples then need to be split into two sets: one for training the model and one for testing how well it performs.
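
The segmentation wrapper AnsjUtils.participle used throughout the code below lives in the repository linked at the end of this post. As a rough sketch only (assuming a recent ansj_seg version where ToAnalysis.parse returns a Result of Terms), such a wrapper might look like this:

import org.ansj.domain.Term;
import org.ansj.splitWord.analysis.ToAnalysis;

import java.util.ArrayList;
import java.util.List;

//Hypothetical sketch of the AnsjUtils.participle wrapper; the real implementation is in the linked repository.
public class AnsjUtils {

    public static List<String> participle(String text) {
        List<String> words = new ArrayList<>();
        //ToAnalysis is ansj's standard analyzer; each Term carries one segmented word
        for (Term term : ToAnalysis.parse(text).getTerms()) {
            words.add(term.getName());
        }
        return words;
    }
}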

Code

Directory structure (the original screenshot is omitted): the classifier classes DataFactory, NaiveBayesTrain, and NaiveBayesTest live under com.maweiming.spark.mllib.classifier, with helper classes AnsjUtils and FileUtils under com.maweiming.spark.mllib.utils and the Result DTO under com.maweiming.spark.mllib.dto; data samples and saved models go under src/main/resources.

DataFactory.java

package com.maweiming.spark.mllib.classifier;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.commons.lang3.StringUtils;

import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * 1. The first step: format the data
 * Created by Coder-Ma on 2017/6/12.
 */
public class DataFactory {

    public static final String CLASS_PATH = "/Users/coderma/coders/github/SparkTextClassifier/src/main/resources";

    public static final String STOP_WORD_PATH = CLASS_PATH + "/data/stopWord.txt";

    public static final String NEWS_DATA_PATH = CLASS_PATH + "/data/NewsData";

    public static final String DATA_TRAIN_PATH = CLASS_PATH + "/data/data-train.txt";

    public static final String DATA_TEST_PATH = CLASS_PATH + "/data/data-test.txt";

    public static final String MODELS = CLASS_PATH + "/models";

    public static final String MODEL_PATH = CLASS_PATH + "/models/category-4";

    public static final String LABEL_PATH = CLASS_PATH + "/models/labels.txt";

    public static final String TF_PATH = CLASS_PATH + "/models/tf";

    public static final String IDF_PATH = CLASS_PATH + "/models/idf";

    public static void main(String[] args) {

        /**
         * Collect the data and do the feature engineering:
         * 1. Walk the sample data directory
         * 2. Clean the text and strip stop words
         */

        //sample split ratio: 80% of the data for training, 20% for testing model accuracy
        Double splitRate = 0.8;

        //stop words
        List<String> stopWords = FileUtils.readLine(line -> line, STOP_WORD_PATH);

        //category labels (label id -> category name)
        Map<Integer, String> labels = new HashMap<>();

        Integer dirIndex = 0;
        String[] dirNames = new File(NEWS_DATA_PATH).list();
        for (String dirName : dirNames) {
            dirIndex++;
            labels.put(dirIndex, dirName);

            String fileDirPath = String.format("%s/%s", NEWS_DATA_PATH, dirName);
            String[] fileNames = new File(fileDirPath).list();
            //total number of samples in the current category directory * split ratio
            int split = Double.valueOf(fileNames.length * splitRate).intValue();
            for (int i = 0; i < fileNames.length; i++) {
                String fileName = fileNames[i];
                String filePath = String.format("%s/%s", fileDirPath, fileName);
                System.out.println(filePath);
                String text = FileUtils.readFile(filePath);
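                //note: replaceAll treats stopWord as a regex, so stop words containing regex metacharacters should be escaped (e.g. with Pattern.quote)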
                for (String stopWord : stopWords) {
                    text = text.replaceAll(stopWord, "");
                }
                if (StringUtils.isBlank(text)) {
                    continue;
                }

                //segment the text into words
                List<String> wordList = AnsjUtils.participle(text);
                JSONObject data = new JSONObject();
                data.put("text", wordList);
                data.put("category", Double.valueOf(dirIndex));

                if (i > split) {
                    //test data
                    FileUtils.appendText(DATA_TEST_PATH, data.toJSONString() + "\n");
                } else {
                    //training data
                    FileUtils.appendText(DATA_TRAIN_PATH, data.toJSONString() + "\n");
                }
            }

        }

        FileUtils.writer(LABEL_PATH, JSON.toJSONString(labels));//data labels

        System.out.println("Data processing successfully !");
        System.out.println("=======================================================");
        System.out.println("trainData:" + DATA_TRAIN_PATH);
        System.out.println("testData:" + DATA_TEST_PATH);
        System.out.println("labes:" + LABEL_PATH);
        System.out.println("=======================================================");

    }

}

Training the Model

Word Feature Values (TF-IDF)

After segmentation, every word becomes a feature, and each Chinese word has to be mapped to a Double; the word's TF-IDF value is commonly used as the feature value. Spark provides a comprehensive set of feature extraction and transformation APIs that make this convenient; see http://spark.apache.org/docs/latest/ml-features.html for details.
The raw data is labeled according to the index of its folder under resources -> NewsData:

  1. car
  2. game
  3. it
  4. military

Here the Chinese words are converted to integers with a hashing algorithm, similar in spirit to a Bloom filter. setNumFeatures(500000) below sets the number of hash buckets to 500,000; the default is 2^20 = 1,048,576. You can tune this according to your vocabulary size: the larger the value, the smaller the probability that two different words hash into the same bucket and the more accurate the features, at the cost of more memory (the same trade-off as a Bloom filter).
With that in place, the model can be trained with the code below.

Code

package com.maweiming.spark.mllib.classifier;

import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.File;
import java.io.IOException;

/**
 * 2. The second step
 * Created by Coder-Ma on 2017/6/26.
 */
public class NaiveBayesTrain {

    public static void main(String[] args) throws IOException {

        //1. Create a SparkSession
        SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
                .getOrCreate();

        //2. Load the training data samples
        Dataset<Row> train = spark.read().json(DataFactory.DATA_TRAIN_PATH);

        //3. Compute TF-IDF features for the training samples
        //word frequency count
        HashingTF hashingTF = new HashingTF().setNumFeatures(500000).setInputCol("text").setOutputCol("rawFeatures");
        Dataset<Row> featurizedData  = hashingTF.transform(train);
        //count tf-idf
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        Dataset<Row> rescaledData = idfModel.transform(featurizedData);

        //4. Convert each sample into a labeled vector
        JavaRDD<LabeledPoint> trainDataRdd = rescaledData.select("category", "features").javaRDD().map(v1 -> {
            Double category = v1.getAs("category");
            SparseVector features = v1.getAs("features");
            Vector featuresVector = Vectors.dense(features.toArray());
            return new LabeledPoint(Double.valueOf(category),featuresVector);
        });

        System.out.println("Start training...");
        //train the model by feeding the labeled vectors to the Naive Bayes algorithm
        NaiveBayesModel model  = NaiveBayes.train(trainDataRdd.rdd());
        //save model
        model.save(spark.sparkContext(), DataFactory.MODEL_PATH);
        //save tf
        hashingTF.save(DataFactory.TF_PATH);
        //save idf
        idfModel.save(DataFactory.IDF_PATH);

        System.out.println("train successfully !");
        System.out.println("=======================================================");
        System.out.println("modelPath:"+DataFactory.MODEL_PATH);
        System.out.println("tfPath:"+DataFactory.TF_PATH);
        System.out.println("idfPath:"+DataFactory.IDF_PATH);
        System.out.println("=======================================================");
    }

}
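
One aside on memory, which is my own suggestion rather than part of the original code: with numFeatures set to 500,000, the call to Vectors.dense(features.toArray()) above materializes a 500,000-element double array for every document. If that becomes a problem, the ml SparseVector can be converted to an mllib sparse vector instead, since mllib's NaiveBayes accepts sparse input. A sketch of the alternative mapping, using the same imports as the class above:

        //Sketch (assumption): keep the TF-IDF vector sparse instead of densifying it
        JavaRDD<LabeledPoint> trainDataRdd = rescaledData.select("category", "features").javaRDD().map(v1 -> {
            Double category = v1.getAs("category");
            SparseVector features = v1.getAs("features");
            //reuse the existing indices/values arrays; no dense 500,000-element array is allocated
            Vector featuresVector = Vectors.sparse(features.size(), features.indices(), features.values());
            return new LabeledPoint(category, featuresVector);
        });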

Model training completes with the following output:

train successfully !
=======================================================
modelPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/category-4
tfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/tf
idfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/idf
=======================================================

Testing the Model

package com.maweiming.spark.mllib.classifier;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import com.maweiming.spark.mllib.dto.Result;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;

import java.io.File;
import java.text.DecimalFormat;
import java.util.*;

/**
 * 3. The third step
 * Created by Coder-Ma on 2017/6/26.
 */
public class NaiveBayesTest {

    private static HashingTF hashingTF;

    private static IDFModel idfModel;

    private static NaiveBayesModel model;

    private static Map<Integer,String> labels = new HashMap<>();

    public static void main(String[] args) {

        SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
                .getOrCreate();

        //load tf file
        hashingTF = HashingTF.load(DataFactory.TF_PATH);
        //load idf file
        idfModel = IDFModel.load(DataFactory.IDF_PATH);
        //load model
        model = NaiveBayesModel.load(spark.sparkContext(), DataFactory.MODEL_PATH);
        //load the label-id -> category-name mapping saved by DataFactory (without this, testModel would print null)
        labels = JSON.parseObject(FileUtils.readFile(DataFactory.LABEL_PATH),
                new TypeReference<Map<Integer, String>>() {});

        //batch test
        batchTestModel(spark, DataFactory.DATA_TEST_PATH);

        //test a single
        testModel(spark,"最近这段时间,由于印度三哥可能有些膨胀,在边境地区总想“搞事情”,这也让不少人的目光集中到此。事实上,我国在与印度的交界处有一军事要地,只要解放军一抬高水位,那么印军或就“不战而退”。它就是地处我国西藏与印度控制克什米尔交界的班公湖。\n" +
                "\n" +
                "\n" +
                "众所周知,从古至今那些地处与军事险要易守难攻的形胜之地,都具有非常重要的军事意义。经常能左右一场战争的胜负。据悉,班公湖位于西藏自治区阿里地区日土县城西北。全长有600多公里,其中地处中国的有400多公里,地处与印度约有200公里。整体成东西走向,海拔在4000多米以上。湖水整体为淡水湖,但由于湖水在西段的淡水补给量的大方面建少,东西方向上交替不通畅,使西部的区域变成了咸水湖。于是便出现了一个有趣的现象,在东部的中国境内班公湖为淡水湖,在西部的印度境内,班公湖为咸水湖。\n" +
                "\n" +
                "\n" +
                "而我军在于印度交界的班公湖区域有一个阀门,这个区域有着非常大的军事作用,而如果印军将部队部署在班公湖地区,我军只需打开阀门,抬高班公湖的东部水位。将他们的军事设施和军用要道给全部淹没。而印军的军事物资和后勤保障都将全部瘫痪,到时印度的军事部署都将全部不攻自破。\n" +
                "\n" +
                "\n" +
                "而印度应该知道现代战争最为重要的便是后勤制度的保障,军事行动能否取得胜利,很大程度取决于后勤能否及时的跟上。而我军在班公湖地区地势上就有了绝对的军事优势,军用物资也可源源不断的运输上来,而印度却优势全无。而我国自古以来就是爱好和平的国家,人不犯我我不犯人。只希望印军能认清与我国军事力量的差距,不要盲目自信。\n" +
                "\n");

    }

    public static void batchTestModel(SparkSession sparkSession, String testPath) {

        Dataset<Row> test = sparkSession.read().json(testPath);
        //word frequency count
        Dataset<Row> featurizedData = hashingTF.transform(test);
        //count tf-idf
        Dataset<Row> rescaledData = idfModel.transform(featurizedData);

        List<Row> rowList = rescaledData.select("category", "features").javaRDD().collect();

        List<Result> dataResults = new ArrayList<>();
        for (Row row : rowList) {
            Double category = row.getAs("category");
            SparseVector sparseVector = row.getAs("features");
            Vector features = Vectors.dense(sparseVector.toArray());
            double predict = model.predict(features);
            dataResults.add(new Result(category, predict));
        }

        Integer successNum = 0;
        Integer errorNum = 0;

        for (Result result : dataResults) {
            if (result.isCorrect()) {
                successNum++;
            } else {
                errorNum++;
            }
        }

        DecimalFormat df = new DecimalFormat("######0.0000");
        Double result = (Double.valueOf(successNum) / Double.valueOf(dataResults.size())) * 100;

        System.out.println("batch test");
        System.out.println("=======================================================");
        System.out.println("Summary");
        System.out.println("-------------------------------------------------------");
        System.out.println(String.format("Correctly Classified Instances          :      %s\t   %s%%",successNum,df.format(result)));
        System.out.println(String.format("Incorrectly Classified Instances        :       %s\t    %s%%",errorNum,df.format(100-result)));
        System.out.println(String.format("Total Classified Instances              :      %s",dataResults.size()));
        System.out.println("===================================");

    }

    public static void testModel(SparkSession sparkSession, String content){
        List<Row> data = Arrays.asList(
                RowFactory.create(AnsjUtils.participle(content))
        );
        StructType schema = new StructType(new StructField[]{
                new StructField("text", new ArrayType(DataTypes.StringType, false), false, Metadata.empty())
        });

        Dataset<Row> testData = sparkSession.createDataFrame(data, schema);
        //word frequency count
        Dataset<Row> transform = hashingTF.transform(testData);
        //count tf-idf
        Dataset<Row> rescaledData = idfModel.transform(transform);

        Row row =rescaledData.select("features").first();
        SparseVector sparseVector = row.getAs("features");
        Vector features = Vectors.dense(sparseVector.toArray());
        Double predict = model.predict(features);

        System.out.println("test a single");
        System.out.println("=======================================================");
        System.out.println("Result");
        System.out.println("-------------------------------------------------------");
        System.out.println(labels.get(predict.intValue()));
        System.out.println("===================================");
    }
}
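
A final side note, again my own suggestion rather than the author's code: instead of counting correct predictions by hand, Spark's MulticlassMetrics can compute the accuracy together with a full confusion matrix from (prediction, label) pairs. A sketch, assuming the same rescaledData and model variables as in batchTestModel above:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import scala.Tuple2;

        //Sketch (assumption): evaluate with MulticlassMetrics instead of manual counting
        JavaRDD<Tuple2<Object, Object>> predictionAndLabels =
                rescaledData.select("category", "features").javaRDD().map(row -> {
                    Double category = row.getAs("category");
                    SparseVector sv = row.getAs("features");
                    Double prediction = model.predict(Vectors.dense(sv.toArray()));
                    return new Tuple2<Object, Object>(prediction, category);
                });
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        System.out.println("accuracy = " + metrics.accuracy());
        System.out.println(metrics.confusionMatrix());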

Test Results

batch test
=======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances          :      785	   98.6181%
Incorrectly Classified Instances        :       11	    1.3819%
Total Classified Instances              :      796
===================================

The accuracy is about 98%, which is decent. That completes the text classifier. By swapping the data samples for two categories such as normal email vs. spam, the same pipeline can be turned into a spam filter.

Source Code

https://github.com/Maweiming/SparkTextClassifier
