
Official Spark Examples: Two Ways to Implement a Random Forest Model (ML/MLlib)

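In Spark 2.0 and later, there are two libraries that implement machine learning algorithms, MLlib and ML. For random forests, for example, they provide: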
org.apache.spark.mllib.tree.RandomForest

org.apache.spark.ml.classification.RandomForestClassificationModel

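The two libraries are also used differently. MLlib (the org.apache.spark.mllib package) is the RDD-based API, while ML (the org.apache.spark.ml package) is built on the ML Pipeline API and the DataFrame data structure.
For details, see http://spark.apache.org/docs/latest/ml-guide.html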
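To make the difference in input types concrete, here is a minimal sketch. The object name ApiInputSketch, the two helper functions, the made-up rows, and the local master setting are assumptions for illustration, not part of the official examples. It shows that the RDD-based API takes an RDD of LabeledPoint built from mllib.linalg vectors, while the DataFrame-based API takes a DataFrame whose features column holds ml.linalg vectors.

```scala
import org.apache.spark.ml.linalg.{Vectors => MLVectors}
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

object ApiInputSketch {
  // RDD-based API (spark.mllib): training data is an RDD[LabeledPoint].
  def rddInput(spark: SparkSession): RDD[LabeledPoint] =
    spark.sparkContext.parallelize(Seq(
      LabeledPoint(1.0, MLlibVectors.dense(0.0, 1.1)),
      LabeledPoint(0.0, MLlibVectors.dense(2.0, 1.0))
    ))

  // DataFrame-based API (spark.ml): training data is a DataFrame with a
  // "label" column and a "features" column of ml.linalg vectors.
  def dataFrameInput(spark: SparkSession): DataFrame =
    spark.createDataFrame(Seq(
      (1.0, MLVectors.dense(0.0, 1.1)),
      (0.0, MLVectors.dense(2.0, 1.0))
    )).toDF("label", "features")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("ApiInputSketch")
      .master("local[*]") // assumption: local run, for illustration only
      .getOrCreate()

    rddInput(spark).collect().foreach(println)
    dataFrameInput(spark).printSchema()

    spark.stop()
  }
}
```

Note that the two vector types live in different packages (mllib.linalg and ml.linalg) and are not interchangeable, which is a common source of type errors when mixing the two APIs.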
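As a result, the official examples for the two libraries differ considerably. The source code and comments for each are given below: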

Model implementation with MLlib

    // scalastyle:off println
    package org.apache.spark.examples.mllib

    import org.apache.spark.{SparkConf, SparkContext}
    // $example on$
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.util.MLUtils
    // $example off$

    object RandomForestClassificationExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("RandomForestClassificationExample")
        val sc = new SparkContext(conf)
        // $example on$
        // Load and parse the data file.
        val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
        // Split the data into training and test sets (30% held out for testing)
        val splits = data.randomSplit(Array(0.7, 0.3))
        val (trainingData, testData) = (splits(0), splits(1))

        // Train a RandomForest model.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        val numClasses = 2
        val categoricalFeaturesInfo = Map[Int, Int]()
        val numTrees = 3 // Use more in practice.
        val featureSubsetStrategy = "auto" // Let the algorithm choose.
        val impurity = "gini"
        val maxDepth = 4
        val maxBins = 32

        val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
          numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

        // Evaluate model on test instances and compute test error
        val labelAndPreds = testData.map { point =>
          val prediction = model.predict(point.features)
          (point.label, prediction)
        }
        val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
        println("Test Error = " + testErr)
        println("Learned classification forest model:\n" + model.toDebugString)

        // Save and load model
        model.save(sc, "target/tmp/myRandomForestClassificationModel")
        val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")
        // $example off$
      }
    }
    // scalastyle:on println
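The test error in the MLlib example is just the fraction of (label, prediction) pairs that disagree. The ML example below reports the same quantity as 1.0 minus accuracy. The following plain-Scala sketch, which needs no Spark, illustrates that the two agree. The object name TestErrorSketch, its helper functions, and the sample pairs are made up for illustration.

```scala
object TestErrorSketch {
  // Fraction of (label, prediction) pairs that disagree,
  // mirroring the filter/count step in the MLlib example.
  def testError(labelAndPreds: Seq[(Double, Double)]): Double =
    labelAndPreds.count { case (label, pred) => label != pred }.toDouble / labelAndPreds.size

  // Fraction of pairs that agree, i.e. the "accuracy" metric.
  def accuracy(labelAndPreds: Seq[(Double, Double)]): Double =
    labelAndPreds.count { case (label, pred) => label == pred }.toDouble / labelAndPreds.size

  def main(args: Array[String]): Unit = {
    // Hypothetical (label, prediction) pairs: one mismatch out of four.
    val pairs = Seq((1.0, 1.0), (0.0, 1.0), (1.0, 1.0), (0.0, 0.0))
    println("Test Error = " + testError(pairs))          // prints "Test Error = 0.25"
    println("1 - accuracy = " + (1.0 - accuracy(pairs))) // prints "1 - accuracy = 0.25"
  }
}
```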

Model implementation with ML

    // scalastyle:off println
    package org.apache.spark.examples.ml

    // $example on$
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    // $example off$
    import org.apache.spark.sql.SparkSession

    object RandomForestClassifierExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder
          .appName("RandomForestClassifierExample")
          .getOrCreate()

        // $example on$
        // Load and parse the data file, converting it to a DataFrame.
        val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

        // Index labels, adding metadata to the label column.
        // Fit on whole dataset to include all labels in index.
        val labelIndexer = new StringIndexer()
          .setInputCol("label")
          .setOutputCol("indexedLabel")
          .fit(data)
        // Automatically identify categorical features, and index them.
        // Set maxCategories so features with > 4 distinct values are treated as continuous.
        val featureIndexer = new VectorIndexer()
          .setInputCol("features")
          .setOutputCol("indexedFeatures")
          .setMaxCategories(4)
          .fit(data)

        // Split the data into training and test sets (30% held out for testing).
        val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

        // Train a RandomForest model.
        val rf = new RandomForestClassifier()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("indexedFeatures")
          .setNumTrees(10)

        // Convert indexed labels back to original labels.
        val labelConverter = new IndexToString()
          .setInputCol("prediction")
          .setOutputCol("predictedLabel")
          .setLabels(labelIndexer.labels)

        // Chain indexers and forest in a Pipeline.
        val pipeline = new Pipeline()
          .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

        // Train model. This also runs the indexers.
        val model = pipeline.fit(trainingData)

        // Make predictions.
        val predictions = model.transform(testData)

        // Select example rows to display.
        predictions.select("predictedLabel", "label", "features").show(5)

        // Select (prediction, true label) and compute test error.
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("indexedLabel")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
        val accuracy = evaluator.evaluate(predictions)
        println("Test Error = " + (1.0 - accuracy))

        val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
        println("Learned classification forest model:\n" + rfModel.toDebugString)
        // $example off$

        spark.stop()
      }
    }
    // scalastyle:on println
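The MLlib example saves and reloads its model, but the ML example does not. A fitted pipeline could be persisted along the following lines. This is a sketch rather than part of the official example: the object name PipelinePersistenceSketch, the helpers fitTinyPipeline and saveAndReload, the made-up training rows, the local master setting, and the output path are all assumptions for illustration.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object PipelinePersistenceSketch {
  // Fit a deliberately tiny pipeline on made-up rows so the sketch is self-contained.
  def fitTinyPipeline(spark: SparkSession): PipelineModel = {
    val training = spark.createDataFrame(Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (0.0, Vectors.dense(0.1, 0.9)),
      (1.0, Vectors.dense(0.9, 0.2))
    )).toDF("label", "features")

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("features")
      .setNumTrees(3)

    new Pipeline().setStages(Array[PipelineStage](labelIndexer, rf)).fit(training)
  }

  // Write the fitted pipeline to `path` (replacing any earlier copy) and read it back.
  def saveAndReload(model: PipelineModel, path: String): PipelineModel = {
    model.write.overwrite().save(path)
    PipelineModel.load(path)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("PipelinePersistenceSketch")
      .master("local[*]") // assumption: local run, for illustration only
      .getOrCreate()

    // Hypothetical output path, chosen to resemble the one in the MLlib example.
    val sameModel = saveAndReload(fitTinyPipeline(spark), "target/tmp/myRandomForestPipelineModel")
    println("Reloaded pipeline with " + sameModel.stages.length + " stages")

    spark.stop()
  }
}
```

Unlike the RDD-based API, the save and load calls here take no SparkContext argument, and the whole pipeline (indexer plus forest) is stored as one unit.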

TIPS:
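Want to see the complete source of the example code shown in http://spark.apache.org/docs? One option is to look it up on GitHub. Another is to go to your Spark installation directory, where all the example sources are under spark/examples/src/main/scala/.
For example, the Scala implementations of the ML algorithms:
spark/examples/src/main/scala/org/apache/spark/examples/ml
The Scala implementations of the MLlib algorithms:
spark/examples/src/main/scala/org/apache/spark/examples/mllib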

Reposted from: https://blog.csdn.net/dahunbi/article/details/72821915?locationNum=3&fps=1
