当前位置:   article > 正文

聚类LDA_lda聚类

lda聚类

1. 聚类LDA

1.1 概念

LDALatent Dirichlet Allocation)是一种文档主题生成模型,也称为一个三层贝叶斯概率模型,包含词、主题和文档三层结构。所谓生成模型,就是说,我们认为一篇文章的每个词都是通过以一定概率选择了某个主题,并从这个主题中以一定概率选择某个词语这样一个过程得到。文档到主题服从多项式分布,主题到词服从多项式分布。[1] 

LDA是一种非监督机器学习技术,可以用来识别大规模文档集(documentcollection)或语料库(corpus)中潜藏的主题信息。它采用了词袋(bag of words)的方法,这种方法将每一篇文档视为一个词频向量,从而将文本信息转化为了易于建模的数字信息。但是词袋方法没有考虑词与词之间的顺序,这简化了问题的复杂性,同时也为模型的改进提供了契机。每一篇文档代表了一些主题所构成的一个概率分布,而每一个主题又代表了很多单词所构成的一个概率分布。

 

1.2 用处

聚类,显示出高权重的主题。词

1.3 细节

有em和online两种方式,不同方式设置的参数和结果不同。

Model有两个参数likelihood(越大越好)和Perplexity(越小越好)

1.4 Demo

  1. package spark.mllib
  2. import org.apache.spark.ml.Pipeline
  3. import org.apache.spark.ml.feature.{Normalizer, PCA}
  4. import org.apache.spark.ml.linalg.{Vector, Vectors}
  5. import org.apache.spark.mllib.linalg.{Vector, Vectors}
  6. import org.apache.spark.sql.functions.{col, udf}
  7. import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
  8. import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
  9. import org.apache.spark.{SparkConf, SparkContext}
  10. import scala.collection.mutable
  11. import scala.collection.mutable.ArrayBuffer
  12. /**
  13. * Created by liuwei on 2017/7/24.
  14. */
  15. object LDATest {
  16. def main(args: Array[String]): Unit = {
  17. import org.apache.spark.ml.clustering.LDA
  18. import org.apache.spark.ml.linalg.Vector
  19. import org.apache.spark.ml.linalg.Vectors
  20. val sparkConf = new SparkConf().setAppName("LDATest").setMaster("local[8]")
  21. val sc = new SparkContext(sparkConf)
  22. val spark = SparkSession.builder.getOrCreate()
  23. // Loads data.
  24. val dataset:DataFrame = spark.read.format("libsvm")
  25. .load("data/mllib/sample_lda_libsvm_data.txt")
  26. dataset.show(false)
  27. // Trains a LDA model.
  28. val lda = new LDA()
  29. .setK(10)//k: 主题数,或者聚类中心数 >1
  30. .setMaxIter(10)// MaxIterations:最大迭代次数 >= 0
  31. // .setCheckpointInterval(1) //迭代计算时检查点的间隔 set checkpoint interval (>= 1) or disable checkpoint (-1)
  32. .setDocConcentration(0.1) //文章分布的超参数(Dirichlet分布的参数),必需>1.0
  33. .setTopicConcentration(0.1)//主题分布的超参数(Dirichlet分布的参数),必需>1.0
  34. .setOptimizer("online") //默认 online 优化计算方法,目前支持"em", "online"
  35. val model = lda.fit(dataset.select("features"))
  36. val ll = model.logLikelihood(dataset)
  37. val lp = model.logPerplexity(dataset)
  38. println(s"The lower bound on the log likelihood of the entire corpus: $ll")
  39. println(s"The upper bound on perplexity: $lp")
  40. val hm2 = new mutable.HashMap[Int,String]
  41. // val a = sc.textFile("data/mllib/C0_segfeatures.txt").map( x => x.split(",")).map( x =>
  42. // hm2.put(x(0).replaceAll("\"","").toInt,x(1).replaceAll("\"",""))
  43. hm2.put()
  44. // )
  45. // println(a.count())
  46. // hm2.put("ok","ok")
  47. // var data = sc.textFile("data/mllib/C0_segfeatures.txt").map( x => x.split(",")).collect()
  48. // data.foreach{pair => hm2.put(pair(0).replaceAll("\"","").toInt,pair(1).replaceAll("\"",""))}
  49. // println(hm2+"============")
  50. // val rdd = sc.textFile("data/mllib/C0_segfeatures.txt").map( x => x.split(",")).map( x =>
  51. // Row(x(0).replaceAll("\"",""),x(1).replaceAll("\"",""))
  52. // )
  53. // var data = rdd.collect()
  54. // data.foreach{pair => hm2.put(pair._1,pair._2)}
  55. // val schema = StructType(
  56. // Seq(
  57. // StructField("index",StringType,true)
  58. // ,StructField("word",StringType,true)
  59. // )
  60. // )
  61. // val wordDataset = spark.createDataFrame(rdd,schema)
  62. val hm = mutable.HashMap(1 -> "b", 2 -> "c",3-> "d", 6 -> "a",9-> "e", 10 -> "f")
  63. // model.l
  64. val resultUDF = udf((termIndices: mutable.WrappedArray[Integer]) => {//处理第二列输出
  65. termIndices.map(index=>
  66. // hm2.get(index)
  67. index
  68. )
  69. })
  70. // Describe topics.
  71. val topics = model.describeTopics(10)//.withColumn("termIndices", resultUDF(col("termIndices")))
  72. println(topics.schema)
  73. // .withColumn("termIndices", resultUDF(col("termIndices"))).withColumn("termWeights", resultUDF(col("termWeights")))
  74. println("The topics described by their top-weighted terms:")
  75. // topics.join(topics, wordDataset("index") === topics("termIndices")).show()
  76. topics.show(false)
  77. val cosUDF = udf {
  78. (vector: Vector) =>
  79. vector.argmax
  80. }
  81. // Shows the result.
  82. var transformed = model.transform(dataset)
  83. transformed = transformed.withColumn("prediction",cosUDF(col("topicDistribution")))
  84. println(transformed.schema)
  85. transformed.show(false)
  86. println(" transform start. ").setK(5).fit(df)
  87. val result = pca.transform(df).select("pcaFeatures")
  88. result.show(false)
  89. }
  90. }


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/361563
推荐阅读
相关标签
  

闽ICP备14008679号