当前位置:   article > 正文

使用Spark MLlib随机森林RandomForest+pipeline进行预测_随机森林模型的pipeline search

随机森林模型的pipeline search

这个程序中,我们使用pipeline来完成整个预测流程,加入了10-fold cross validation。

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}

/**
  * Created by simon on 2017/5/8.
  */
object genderClassificationWithRandomForest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("genderClassification").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val trainData = sc.textFile("file:\\E:\\test.csv")

// 第一步,预处理数据,构建为DataFrame格式
    val data = trainData.map { line =>
      val parts= line.split("\\|")
      val label = toInt(parts(1)) //第二列是标签
      val features = Vectors.dense(parts.slice(6,parts.length-1).map(_.toDouble)) //第7到最后一列是属性,需要转换为Doube类型
      LabeledPoint(label, features) //构建LabelPoint格式
    }.toDF()

// 第二步,将数据随机分为训练集和测试集
    val Array(training, testing) = data.randomSplit(Array(0.7, 0.3),131L)

// 第三步,准备一些基本参数和标签列indexer
// 设置K折交叉验证的K的数量,以及随机森林树的数量,树的数量增加会大幅度增加训练时间
    val nFolds: Int = 10
    val NumTrees: Int = 500 //800,2000

    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("label_idx")

// 第四步,创建随机森林分类器
    val rf = new RandomForestClassifier()
      .setNumTrees(NumTrees)
      .setFeaturesCol("features")
      .setLabelCol("label_idx")
      .setFeatureSubsetStrategy("auto")
      .setImpurity("gini")
      .setMaxDepth(10) //2,5,7
      .setMaxBins(100)

// 第五步,创建pipeline
    val pipeline = new Pipeline().setStages(Array(indexer,rf))

// 第六步,创建参数
    val paramGrid = new ParamGridBuilder().build()


// 第七步,设置预测效果测量器
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")

// 第八步,创建交叉验证对象,设置好pipeline、测量器、参数、K的数量
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(nFolds)

// 第九步,使用训练集训练模型
    val model = cv.fit(training)

// 第十步,拿训练好的模型预测测试集
    val predictions = model.transform(testing)

    predictions.show()

// 第十一步,测量预测效果
    val metrics  = evaluator.evaluate(predictions)
    println(metrics)
  }

  // 将标签转换为01
  def toInt(s: String): Int = {
    if (s == "m") 1 else  0
  }

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/693268
推荐阅读
相关标签
  

闽ICP备14008679号