
Spark Applications (Part 3): Text Classification


I. Feature Extraction

1. What is feature extraction?
     Feature extraction transforms a set of measurements of a pattern so that the pattern's representative features stand out (Baidu Baike). For explanations from other angles, see:
http://www.igi-global.com/dictionary/feature-extraction/10960

Put simply, feature extraction takes raw data and turns it into a representation we can work with, for example extracting the pixels of an image as RGB values, or turning a document into a vector-space representation.

2. TF-IDF
     TF-IDF is a feature vectorization method widely used in text mining to reflect how important a term is to a document in the corpus. Term frequency TF(t, d) is the number of times term t appears in document d, and document frequency DF(t, D) is the number of documents in the corpus D that contain t. Relying on TF alone over-emphasizes terms that appear very frequently but say little about the document, such as "a", "the", and "of" in English. IDF (inverse document frequency) measures how common or rare a term is across the corpus, and therefore how much information it carries.
The TF-IDF formula is as follows:

    IDF(t, D) = log( |D| / DF(t, D) )

    TF-IDF(t, d, D) = TF(t, d) × IDF(t, D)

For more details, you can refer to my earlier post:

http://blog.csdn.net/legotime/article/details/51836028
In practice, 1 is added to both the numerator and the denominator of the IDF ratio, i.e. IDF(t, D) = log( (|D| + 1) / (DF(t, D) + 1) ), so that a term that appears in no document (DF = 0) does not cause a division by zero.
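To make the smoothed formula concrete, here is a tiny worked example in Scala; the corpus size, document frequency, and term frequency are made-up numbers, and the natural logarithm is used, as in Spark's IDF implementation:

object TfIdfWorkedExample {
  def main(args: Array[String]): Unit = {
    val totalDocs = 4.0   // |D|: number of documents in the corpus
    val docFreq  = 2.0    // DF(t, D): documents that contain the term
    val termFreq = 3.0    // TF(t, d): occurrences of the term in document d

    val idf   = math.log((totalDocs + 1) / (docFreq + 1)) // log(5/3) ≈ 0.511
    val tfIdf = termFreq * idf                            // ≈ 1.532
    println(f"IDF = $idf%.3f, TF-IDF = $tfIdf%.3f")
  }
}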

3. Computing TF
     For computing TF, Spark provides HashingTF and CountVectorizer.
HashingTF
       The HashingTF method uses the hashing trick (feature hashing), and the hash function Spark uses is MurmurHash3 (https://en.wikipedia.org/wiki/MurmurHash). To build a deeper understanding, let's first walk through the underlying idea with a small example. Consider the following three sentences:
  1. John likes to watch movies.
  2. Mary likes movies too.
  3. John also likes football.
They are converted to the following term-to-index table:

Term      Index
John      1
likes     2
to        3
watch     4
movies    5
Mary      6
too       7
also      8
football  9
(That is, terms are indexed starting from "John", scanning left to right and top to bottom, and repeated terms are skipped.) Using that index order, the term-document matrix of the three sentences is:

Sentence 1: [1, 1, 1, 1, 1, 0, 0, 0, 0]
Sentence 2: [0, 1, 0, 0, 1, 1, 1, 0, 0]
Sentence 3: [1, 1, 0, 0, 0, 0, 0, 1, 1]

Now let's implement this in Java:

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;

/**
 * Build a term-document matrix for the sentences in a text file.
 *
 * @param file path of the text file (one sentence per line)
 * @return one 0/1 vector per sentence, indexed by first-appearance order of the terms
 * @throws IOException
 */
public static ArrayList<int[]> txt2num(String file) throws IOException {
    BufferedReader br = new BufferedReader(new FileReader(file));
    String s;
    StringBuilder sb = new StringBuilder();
    ArrayList<String> strArr = new ArrayList<String>();
    while ((s = br.readLine()) != null) {
        String tmp = s.split("\\.")[0];   // drop the trailing period
        strArr.add(tmp);
        sb.append(tmp + " ");
    }
    // LinkedHashSet keeps first-appearance order, matching the index table above
    String[] split = sb.toString().split(" ");
    LinkedHashSet<String> vocabulary = new LinkedHashSet<>();
    for (String s1 : split) {
        vocabulary.add(s1);
    }
    Object[] vocab = vocabulary.toArray();
    System.out.println(Arrays.toString(vocab));
    ArrayList<int[]> txt2Matrix = new ArrayList<int[]>();
    // fill in the matrix, one row per sentence
    for (String s1 : strArr) {
        int[] txt2IntVec = new int[vocab.length];
        String[] ss = s1.split(" ");
        ArrayList<String> strs = new ArrayList<String>(Arrays.asList(ss));
        System.out.println(Arrays.toString(ss));
        for (int i = 0; i < txt2IntVec.length; i++) {
            txt2IntVec[i] = strs.contains(vocab[i]) ? 1 : 0;
        }
        System.out.println(Arrays.toString(txt2IntVec));
        txt2Matrix.add(txt2IntVec);
    }
    return txt2Matrix;
}
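In Spark itself there is no need to build the vocabulary by hand. The following is a minimal sketch that vectorizes the same three sentences with ml.feature.HashingTF and then re-weights the counts with IDF; it assumes a local Spark 1.5+ setup, and the object name, column names, and the small numFeatures value (32) are illustrative choices rather than anything mandated by the API:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}

object HashingTFSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("HashingTFSketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // The three example sentences from above
    val docs = Seq(
      (1, "John likes to watch movies"),
      (2, "Mary likes movies too"),
      (3, "John also likes football")
    ).toDF("id", "sentence")

    // Tokenize, then hash every token into one of numFeatures buckets (no vocabulary is kept)
    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(32)
    val tf = hashingTF.transform(tokenizer.transform(docs))

    // Optionally re-weight the raw counts with IDF
    val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(tf)
    idfModel.transform(tf).select("id", "features").show(false)

    sc.stop()
  }
}

Note that with hashing, two different terms can land in the same bucket (a collision); a larger numFeatures makes collisions less likely at the cost of longer vectors.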

You can also take a look at the MurmurHash3 implementation that ships with Spark (the Murmur3_x86_32 class in the org.apache.spark.unsafe.hash package):

public final class Murmur3_x86_32 {
  private static final int C1 = 0xcc9e2d51;
  private static final int C2 = 0x1b873593;

  private final int seed;

  public Murmur3_x86_32(int seed) {
    this.seed = seed;
  }

  @Override
  public String toString() {
    return "Murmur3_32(seed=" + seed + ")";
  }

  public int hashInt(int input) {
    return hashInt(input, seed);
  }

  public static int hashInt(int input, int seed) {
    int k1 = mixK1(input);
    int h1 = mixH1(seed, k1);
    return fmix(h1, 4);
  }

  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
    return hashUnsafeWords(base, offset, lengthInBytes, seed);
  }

  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
    // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
    assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
    int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
    return fmix(h1, lengthInBytes);
  }

  public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
    int lengthAligned = lengthInBytes - lengthInBytes % 4;
    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
    for (int i = lengthAligned; i < lengthInBytes; i++) {
      int halfWord = Platform.getByte(base, offset + i);
      int k1 = mixK1(halfWord);
      h1 = mixH1(h1, k1);
    }
    return fmix(h1, lengthInBytes);
  }

  private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
    assert (lengthInBytes % 4 == 0);
    int h1 = seed;
    for (int i = 0; i < lengthInBytes; i += 4) {
      int halfWord = Platform.getInt(base, offset + i);
      int k1 = mixK1(halfWord);
      h1 = mixH1(h1, k1);
    }
    return h1;
  }

  public int hashLong(long input) {
    return hashLong(input, seed);
  }

  public static int hashLong(long input, int seed) {
    int low = (int) input;
    int high = (int) (input >>> 32);
    int k1 = mixK1(low);
    int h1 = mixH1(seed, k1);
    k1 = mixK1(high);
    h1 = mixH1(h1, k1);
    return fmix(h1, 8);
  }

  private static int mixK1(int k1) {
    k1 *= C1;
    k1 = Integer.rotateLeft(k1, 15);
    k1 *= C2;
    return k1;
  }

  private static int mixH1(int h1, int k1) {
    h1 ^= k1;
    h1 = Integer.rotateLeft(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
    return h1;
  }

  // Finalization mix - force all bits of a hash block to avalanche
  private static int fmix(int h1, int length) {
    h1 ^= length;
    h1 ^= h1 >>> 16;
    h1 *= 0x85ebca6b;
    h1 ^= h1 >>> 13;
    h1 *= 0xc2b2ae35;
    h1 ^= h1 >>> 16;
    return h1;
  }
}
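To see how such a 32-bit hash turns into a column index, here is a hedged sketch that hashes a term's UTF-8 bytes with Murmur3_x86_32 and folds the result into a fixed number of buckets. The seed value and the modulo step are illustrative only; the exact seed and wiring inside HashingTF are version-dependent implementation details, and the sketch assumes the spark-unsafe classes (Murmur3_x86_32, UTF8String) are on the classpath:

import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.UTF8String

object Murmur3BucketSketch {
  def main(args: Array[String]): Unit = {
    val seed = 42              // illustrative seed value
    val numFeatures = 1 << 18  // size of the feature vector

    def bucketOf(term: String): Int = {
      val utf8 = UTF8String.fromString(term)
      val raw = Murmur3_x86_32.hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
      // map the (possibly negative) 32-bit hash onto [0, numFeatures)
      ((raw % numFeatures) + numFeatures) % numFeatures
    }

    Seq("John", "likes", "movies", "football").foreach { t =>
      println(s"$t -> bucket ${bucketOf(t)}")
    }
  }
}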
CountVectorizer
  CountVectorizer is simpler. The difference from the hashing trick is that it builds an explicit vocabulary from the corpus and counts the occurrences of each individual term, so every vector index corresponds to a concrete term. A sketch of how it is used follows.
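A minimal sketch of CountVectorizer on the same toy corpus, assuming a local Spark 1.5+ setup; the column names and the vocabSize and minDF settings are illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

object CountVectorizerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CountVectorizerSketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val docs = Seq(
      (1, Array("john", "likes", "to", "watch", "movies")),
      (2, Array("mary", "likes", "movies", "too")),
      (3, Array("john", "also", "likes", "football"))
    ).toDF("id", "words")

    // Unlike HashingTF, CountVectorizer first learns an explicit vocabulary from the corpus
    val cvModel: CountVectorizerModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(20)   // keep at most 20 terms
      .setMinDF(1)        // a term must appear in at least 1 document
      .fit(docs)

    println(cvModel.vocabulary.mkString(", "))  // the learned term -> index mapping
    cvModel.transform(docs).show(false)

    sc.stop()
  }
}

Here cvModel.vocabulary holds the learned term-to-index mapping, which is exactly the state that the hashing trick avoids having to store.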


II. Text Classification

Data

http://qwone.com/~jason/20Newsgroups/

The page above documents the dataset. Judging from the HDFS paths used in the code below, the subset used here is 20news-bydate, which comes pre-split into a training set and a test set.

The data is uploaded to HDFS (for example with hadoop fs -put) under hdfs://master:9000/data/studySet/textMining/20news-bydate/, the base path referenced by the programs below.



Data preprocessing
     Convert the data into the following format:
case class LabeledText(item: String, label: Double, doc: String)
where:
  1. item: the class name, i.e. the newsgroup directory that contains the file
  2. label: the numeric class label (which must be of type Double)
  3. doc: the words or letters extracted from the whole text of the file
The preprocessing program is as follows:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{AccumulatorParam, SparkConf, SparkContext}

object NewClassifier {

  // Recursively list all files under an HDFS directory
  def listSonRoute(path: String): Seq[String] = {
    val conf = new Configuration()
    val fs = new Path(path).getFileSystem(conf)
    val status = fs.listFiles(new Path(path), true)
    var res: List[String] = Nil
    while (status.hasNext) {
      res = res ++ Seq(status.next().getPath.toString)
    }
    res
  }

  /**
   * Extract English words (the trailing '$' anchor keeps only the last
   * alphabetic token of each line; remove it to keep every word)
   * @param content one line of text
   */
  def splitStr(content: String): List[String] = ("[A-Za-z]+$".r findAllIn content).toList

  // Concatenate the extracted, lower-cased words of one file into a single string
  def rdd2Str(sc: SparkContext, path: String) = {
    val rdd = sc.textFile(path)
    val myAccumulator = sc.accumulator[String](" ")(StringAccumulatorParam)
    rdd.foreach { part =>
      splitStr(part).foreach { word =>
        myAccumulator.add(word.toLowerCase)
      }
    }
    myAccumulator.value
  }

  def getDataFromHDFS(sc: SparkContext, path: String): DataFrame = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val files = listSonRoute(path)
    // Map each newsgroup directory name to a numeric label; sorting keeps the
    // mapping identical for the training and the test set
    val labelOf = files.map(_.split("/").apply(8)).distinct.sorted.zipWithIndex.toMap
    files.map { part =>
      val item = part.split("/").apply(8)
      LabeledText(item, labelOf(item).toDouble, rdd2Str(sc, part))
    }.toDF()
  }

  case class LabeledText(item: String, label: Double, doc: String)

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("new Classifier").setMaster("local")
      .set("spark.storage.memoryFraction", "0.1")
    val sc = new SparkContext(conf)
    // rawData to parquet
    val testPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/20news-bydate-test"
    val trainPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/20news-bydate-train/"
    val testDF = getDataFromHDFS(sc, testPath)
    val trainDF = getDataFromHDFS(sc, trainPath)
    testDF.write.save("hdfs://master:9000/data/studySet/textMining/20news-bydate/test")
    trainDF.write.save("hdfs://master:9000/data/studySet/textMining/20news-bydate/train")
  }
}

object StringAccumulatorParam extends AccumulatorParam[String] {
  override def addInPlace(r1: String, r2: String): String = add(r1, r2)

  /**
   * Zero value of the accumulator
   * @param initialValue initial value
   */
  override def zero(initialValue: String): String = ""

  def add(v1: String, v2: String) = {
    assert((!v1.isEmpty) || (!v2.isEmpty))
    v1 + v2 + " "
  }
}

Naive Bayes classification
     The following program combines the ML Pipeline workflow with Naive Bayes to classify the text:
package txtMIning

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassifier, NaiveBayes}
import org.apache.spark.ml.feature.{HashingTF, RegexTokenizer}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{AccumulatorParam, SparkConf, SparkContext}

/**
 * News classification
 */
object NewClassifier {

  // Recursively list all files under an HDFS directory
  def listSonRoute(path: String): Seq[String] = {
    val conf = new Configuration()
    val fs = new Path(path).getFileSystem(conf)
    val status = fs.listFiles(new Path(path), true)
    var res: List[String] = Nil
    while (status.hasNext) {
      res = res ++ Seq(status.next().getPath.toString)
    }
    res
  }

  /**
   * Extract English words (the trailing '$' anchor keeps only the last
   * alphabetic token of each line; remove it to keep every word)
   * @param content one line of text
   */
  def splitStr(content: String): List[String] = ("[A-Za-z]+$".r findAllIn content).toList

  // Concatenate the extracted, lower-cased words of one file into a single string
  def rdd2Str(sc: SparkContext, path: String) = {
    val rdd = sc.textFile(path)
    val myAccumulator = sc.accumulator[String](" ")(StringAccumulatorParam)
    rdd.foreach { part =>
      splitStr(part).foreach { word =>
        myAccumulator.add(word.toLowerCase)
      }
    }
    myAccumulator.value
  }

  def getDataFromHDFS(sc: SparkContext, path: String): DataFrame = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val files = listSonRoute(path)
    // Map each newsgroup directory name to a numeric label; sorting keeps the
    // mapping identical for the training and the test set
    val labelOf = files.map(_.split("/").apply(8)).distinct.sorted.zipWithIndex.toMap
    files.map { part =>
      val item = part.split("/").apply(8)
      LabeledText(item, labelOf(item).toDouble, rdd2Str(sc, part))
    }.toDF()
  }

  def readParquetFile(sc: SparkContext, path: String) = {
    val sqlContext = new SQLContext(sc)
    sqlContext.read.parquet(path).toDF()
  }

  case class LabeledText(item: String, label: Double, doc: String)

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("new Classifier").setMaster("local")
      .set("spark.storage.memoryFraction", "0.1")
    val sc = new SparkContext(conf)

    // rawData to parquet (already done once by the preprocessing program above)
    // val testPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/20news-bydate-test"
    // val trainPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/20news-bydate-train/"
    // val testDF = getDataFromHDFS(sc, testPath)
    // val trainDF = getDataFromHDFS(sc, trainPath)
    // testDF.write.save("hdfs://master:9000/data/studySet/textMining/20news-bydate/test")
    // trainDF.write.save("hdfs://master:9000/data/studySet/textMining/20news-bydate/train")

    // Load the labeled data (the label column must be of type Double)
    val testParquetPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/test/*"
    val trainParquetPath = "hdfs://master:9000/data/studySet/textMining/20news-bydate/train/*"
    val testDF: DataFrame = readParquetFile(sc, testParquetPath)   //.sample(withReplacement = true, 0.002)
    val trainDF: DataFrame = readParquetFile(sc, trainParquetPath) //.sample(withReplacement = true, 0.002)

    // Inspect predictions saved by an earlier run:
    // val pre = readParquetFile(sc, "hdfs://master:9000/data/studySet/textMining/20news-bydate/prediction2/*")
    // pre.show(200)
    // pre.toJavaRDD.saveAsTextFile("hdfs://master:9000/data/studySet/textMining/20news-bydate/prediction3")

    // A sample row of the loaded data looks like:
    // [alt.atheism, answers t na translator had but determine kaflowitz ]

    // Build the pipeline: tokenize -> hash into term-frequency vectors -> classify
    val tokenizer = new RegexTokenizer()
      .setInputCol("doc")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    // .setNumFeatures(100000)
    val naiveBayes = new NaiveBayes()
      .setPredictionCol("prediction")
    // A decision tree can be swapped in as an alternative classifier
    val decisionTree = new DecisionTreeClassifier()
      .setPredictionCol("prediction")
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, naiveBayes))

    // Fit on the training set, then predict on the held-out test set
    val model = pipeline.fit(trainDF)
    // println(model.explainParams())
    val testPredictions = model.transform(testDF)
    testPredictions.show(50)
    // testPredictions.write.save("hdfs://master:9000/data/studySet/textMining/20news-bydate/prediction2")
  }
}

object StringAccumulatorParam extends AccumulatorParam[String] {
  override def addInPlace(r1: String, r2: String): String = add(r1, r2)

  /**
   * Zero value of the accumulator
   * @param initialValue initial value
   */
  override def zero(initialValue: String): String = ""

  def add(v1: String, v2: String) = {
    assert((!v1.isEmpty) || (!v2.isEmpty))
    v1 + v2 + " "
  }
}

The call testPredictions.show(50) prints the first 50 rows of the prediction DataFrame, with the predicted class in the prediction column.
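Beyond eyeballing the prediction column, the quality of the model can be quantified with a MulticlassClassificationEvaluator. This is a minimal sketch meant to be appended at the end of the main method above; it assumes the testPredictions DataFrame produced there and uses the F1 metric:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Compare the "label" column with the "prediction" column of testPredictions
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("f1")

val f1 = evaluator.evaluate(testPredictions)
println(s"F1 on the test set = $f1")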














