当前位置:   article > 正文

educoder中Spark GraphX—预测社交圈子

Spark GraphX—预测社交圈子

第1关:计算连通分量

  import org.apache.log4j.{Level, Logger}
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.graphx._

  /**
    * Stage 1: builds a small hard-coded graph and prints its connected components.
    *
    * The graph has the vertices 1 to 8 and six edges; vertices 3 and 8 take part
    * in no edge, so each of them ends up in a component of its own.
    */
  object connectComponents {
    def main(args: Array[String]): Unit = {
      // Set the log levels before the context is created so that they are
      // already in place while Spark starts up.
      Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
      Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

      val conf = new SparkConf().setAppName("connectComponents ").setMaster("local[4]")
      val sc = new SparkContext(conf)
      try {
        // Vertices 1..8, each carrying an empty string as its attribute.
        val myVertices = sc.parallelize((1L to 8L).map((_, "")))

        // Edges, likewise with an empty string attribute.
        val myEdges = sc.parallelize(Array(
          Edge(1L, 2L, ""),
          Edge(4L, 5L, ""),
          Edge(4L, 6L, ""),
          Edge(5L, 6L, ""),
          Edge(5L, 7L, ""),
          Edge(7L, 6L, "")
        ))

        val myGraph = Graph(myVertices, myEdges)

        // connectedComponents() labels every vertex with a component id, giving
        // (vertexId, componentId) pairs. Swapping them and grouping by key
        // collects the member vertices of each component.
        val connectedcomponents = myGraph
          .connectedComponents()
          .vertices
          .map(_.swap)
          .groupByKey()
          .map(_._2)
          .collect()

        println("")
        // One line per component.
        connectedcomponents.foreach(println)
      } finally {
        // Release the context even when the job fails.
        sc.stop()
      }
    }
  }

第2关:社交圈子预测1

  import org.apache.log4j.{Level, Logger}
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.graphx._

  /**
    * Stage 2: predicts social circles for a single ego network.
    *
    * Reads `/root/data/egonets/239.egonet`, turns every line into edges, builds
    * a graph from them and prints each connected component on its own line.
    */
  object predict1_s {
    def main(args: Array[String]): Unit = {
      // Set the log levels before the context is created so that they are
      // already in place while Spark starts up.
      Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
      Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

      val conf = new SparkConf().setAppName("predict1_s").setMaster("local[4]")
      val sc = new SparkContext(conf)

      /**
        * Parses one egonet line of the form `src: dst1 dst2 ...` into edge tuples.
        *
        * @param line a single line of an egonet file
        * @return one `(src, dst)` pair per listed id; a single self-loop
        *         `(src, src)` when no ids are listed, so that the vertex still
        *         appears in the graph; an empty array for a blank line
        * @throws NumberFormatException if an id is not a valid `Long`
        */
      def get_edges_from_line(line: String): Array[(Long, Long)] = {
        if (line.trim.isEmpty) {
          Array.empty[(Long, Long)]
        } else {
          // Limit 2 keeps the (possibly empty) part after the colon; a plain
          // split(":") drops it for "239:" and indexing it would then fail.
          val parts = line.split(":", 2)
          val srcId = parts(0).trim.toLong
          val friends = if (parts.length > 1) parts(1) else ""
          // Splitting on any whitespace also tolerates tabs and a trailing '\r'.
          val dstIds = friends.split("\\s+").filter(_.nonEmpty)
          if (dstIds.isEmpty) Array((srcId, srcId))
          else dstIds.map(dstId => (srcId, dstId.toLong))
        }
      }

      try {
        // Read the content of 239.egonet, one element per line.
        val egonet_example = sc.textFile("/root/data/egonets/239.egonet")

        // Parse every line and gather all edges in one flat array on the driver.
        val edges = egonet_example.flatMap(line => get_edges_from_line(line)).collect()

        // Raw edges as RDD[(VertexId, VertexId)].
        val g_edges = sc.makeRDD(edges)

        // Build the graph from the edge tuples; every vertex gets the attribute 1.
        val g = Graph.fromEdgeTuples(g_edges, 1)

        // connectedComponents() yields (vertexId, componentId) pairs; swap and
        // group them to obtain the member vertices of each component.
        val connectedcomponents = g
          .connectedComponents()
          .vertices
          .map(_.swap)
          .groupByKey()
          .map(_._2)
          .collect()

        println("")
        // One line per component.
        connectedcomponents.foreach(println)
      } finally {
        // Release the context even when the job fails.
        sc.stop()
      }
    }
  }

第3关:社交圈子预测2

  import org.apache.log4j.{Level, Logger}
  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.graphx._

  /**
    * Stage 3: predicts social circles for every ego network in a directory.
    *
    * Reads all files under `/root/data/egonets`, computes the connected
    * components of each file's graph and prints one `userId,circles` line per
    * file, preceded by a header line.
    */
  object predict2_s {
    def main(args: Array[String]): Unit = {
      // Set the log levels before the context is created so that they are
      // already in place while Spark starts up.
      Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
      Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

      val conf = new SparkConf().setAppName("predict2_s").setMaster("local[4]")
      val sc = new SparkContext(conf)

      /**
        * Extracts the user id from a file path of the form `path/userId.egonet`.
        *
        * @param s the file path
        * @return the digits directly in front of `.egonet`
        * @throws IllegalArgumentException if the path does not end in `<digits>.egonet`
        */
      def extract(s: String): String = {
        val Pattern = """^.*?(\d+)\.egonet$""".r
        s match {
          case Pattern(num) => num
          case _ => throw new IllegalArgumentException(s"Not an egonet file path: $s")
        }
      }

      /**
        * Parses one egonet line of the form `src: dst1 dst2 ...` into edge tuples.
        *
        * @param line a single line of an egonet file
        * @return one `(src, dst)` pair per listed id; a single self-loop
        *         `(src, src)` when no ids are listed, so that the vertex still
        *         appears in the graph; an empty array for a blank line
        * @throws NumberFormatException if an id is not a valid `Long`
        */
      def get_edges_from_line(line: String): Array[(Long, Long)] = {
        if (line.trim.isEmpty) {
          Array.empty[(Long, Long)]
        } else {
          // Limit 2 keeps the (possibly empty) part after the colon; a plain
          // split(":") drops it for "239:" and indexing it would then fail.
          val parts = line.split(":", 2)
          val srcId = parts(0).trim.toLong
          val friends = if (parts.length > 1) parts(1) else ""
          // Splitting on any whitespace also tolerates tabs and a trailing '\r'.
          val dstIds = friends.split("\\s+").filter(_.nonEmpty)
          if (dstIds.isEmpty) Array((srcId, srcId))
          else dstIds.map(dstId => (srcId, dstId.toLong))
        }
      }

      /**
        * Turns the full text of one egonet file into a flat array of edge tuples.
        *
        * @param contents the content of an egonet file
        * @return the edges of all lines, in line order
        */
      def make_edges(contents: String): Array[(Long, Long)] =
        contents.split("\n").flatMap(line => get_edges_from_line(line))

      /**
        * Builds a graph from the edges and renders its connected components.
        *
        * @param flat the edge tuples of one ego network
        * @return the components joined by `;`, the vertex ids inside a
        *         component joined by a single space
        */
      def get_circles(flat: Array[(Long, Long)]): String = {
        val edges = sc.makeRDD(flat)
        val g = Graph.fromEdgeTuples(edges, 1)
        val cc = g.connectedComponents()
        // Key every vertex by its component id and concatenate the members.
        cc.vertices
          .map(x => (x._2, Array(x._1)))
          .reduceByKey((a, b) => a ++ b)
          .values
          .map(_.mkString(" "))
          .collect()
          .mkString(";")
      }

      try {
        // Read the directory as (file path, file content) pairs.
        val egonets = sc.wholeTextFiles("/root/data/egonets")

        // Derive user id and edges in a single pass so that every id stays
        // attached to the edges of its own file and the files are read once.
        val egonet_data = egonets
          .map { case (path, contents) => (extract(path), make_edges(contents)) }
          .collect()

        // get_circles runs Spark jobs itself, so it is applied on the driver.
        val result = egonet_data.map { case (userId, edges) =>
          s"$userId,${get_circles(edges)}"
        }

        println("")
        println("UserId,Predicted social circles(Every social circle is used ';' separated.)")
        // One result per line.
        println(result.mkString("\n"))
      } finally {
        // Release the context even when the job fails.
        sc.stop()
      }
    }
  }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/666735
推荐阅读
相关标签
  

闽ICP备14008679号