赞
踩
这个程序中,我们使用pipeline来完成整个预测流程,加入了10-fold cross validation。
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}
/**
* Created by simon on 2017/5/8.
*/
object genderClassificationWithRandomForest {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
conf.setAppName("genderClassification").setMaster("local[2]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val trainData = sc.textFile("file:\\E:\\test.csv")
// 第一步,预处理数据,构建为DataFrame格式
val data = trainData.map { line =>
val parts= line.split("\\|")
val label = toInt(parts(1)) //第二列是标签
val features = Vectors.dense(parts.slice(6,parts.length-1).map(_.toDouble)) //第7到最后一列是属性,需要转换为Doube类型
LabeledPoint(label, features) //构建LabelPoint格式
}.toDF()
// 第二步,将数据随机分为训练集和测试集
val Array(training, testing) = data.randomSplit(Array(0.7, 0.3),131L)
// 第三步,准备一些基本参数和标签列indexer
// 设置K折交叉验证的K的数量,以及随机森林树的数量,树的数量增加会大幅度增加训练时间
val nFolds: Int = 10
val NumTrees: Int = 500 //800,2000
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("label_idx")
// 第四步,创建随机森林分类器
val rf = new RandomForestClassifier()
.setNumTrees(NumTrees)
.setFeaturesCol("features")
.setLabelCol("label_idx")
.setFeatureSubsetStrategy("auto")
.setImpurity("gini")
.setMaxDepth(10) //2,5,7
.setMaxBins(100)
// 第五步,创建pipeline
val pipeline = new Pipeline().setStages(Array(indexer,rf))
// 第六步,创建参数
val paramGrid = new ParamGridBuilder().build()
// 第七步,设置预测效果测量器
val evaluator = new BinaryClassificationEvaluator()
.setLabelCol("label")
.setRawPredictionCol("rawPrediction")
.setMetricName("areaUnderROC")
// 第八步,创建交叉验证对象,设置好pipeline、测量器、参数、K的数量
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(nFolds)
// 第九步,使用训练集训练模型
val model = cv.fit(training)
// 第十步,拿训练好的模型预测测试集
val predictions = model.transform(testing)
predictions.show()
// 第十一步,测量预测效果
val metrics = evaluator.evaluate(predictions)
println(metrics)
}
// 将标签转换为0和1
def toInt(s: String): Int = {
if (s == "m") 1 else 0
}
}
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。