赞
踩
**package** mllib.tree **import** org.apache.log4j.{Level, Logger} **import** org.apache.spark.mllib.evaluation.MulticlassMetrics **import** org.apache.spark.mllib.linalg.Vectors **import** org.apache.spark.mllib.regression.LabeledPoint **import** org.apache.spark.mllib.tree.{RandomForest, DecisionTree} **import** org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} **import** org.apache.spark.rdd.RDD **import** org.apache.spark.{SparkContext, SparkConf} _/**_ _* Created by_ _汪本成_ _on 2016/7/18._ _*/_ **object** randomForest { //屏蔽不必要的日志显示在终端上 // Logger.getLogger("org.apache.spark").setLevel(Level.WARN) // Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF) **var** _beg_ = System.currentTimeMillis() //创建入口对象 **val** _conf_ = **new** SparkConf().setAppName("rndomForest").setMaster("local") **val** _sc_ = **new** SparkContext( _conf_ ) **val** _HDFS_COVDATA_PATH_ = "hdfs://192.168.43.150:9000/user/spark/sparkLearning/mllib/covtype.data" **val** _rawData_ = _sc_.textFile( _HDFS_COVDATA_PATH_ ) //设置LabeledPoint格式 **val** _data_ = _rawData_.map{ line => **val** values = line.split(",").map(_.toDouble) // init返回除最后一个值之外的所有值,最后一列是目标 **val** FeatureVector = Vectors.dense(values.init) //决策树要求(目标变量)label从0开始,所以要减一 **val** label = values.last - 1 LabeledPoint(label, FeatureVector) } //分成训练集(80%),交叉验证集(10%),测试集(10%) **val** Array( _trainData_ , _cvData_ , _testData_ ) = _data_.randomSplit(Array(0.8, 0.1, 0.1)) _trainData_.cache() _cvData_.cache() _testData_.cache() //新建随机森林 **val** _numClass_ = 7 //分类数量 **val** _categoricalFeaturesInfo_ = _Map_ [Int, Int](10 -> 4, 11-> 40) //用map存储类别(离散)特征及每个类特征对应值(类别)的数量 **val** _impurity_ = "entropy" //纯度计算方法,用于信息增益的计算 **val** _number_ = 20 //构建树的数量 **val** _maxDepth_ = 4 //树的最大高度 **val** _maxBins_ = 100 // 用于分裂特征的最大划分数量 //训练分类决策树模型 **val** _model_ = RandomForest.trainClassifier( _trainData_ , _numClass_ , _categoricalFeaturesInfo_ , _number_ , "auto", _impurity_ , _maxDepth_ , _maxBins_ ) **val** _metrics_ = getMetrics( _model_ , _cvData_ ) //计算精确度(样本比例) **val** _precision_ = _metrics_. _precision_ __ //计算每个样本的准确度(召回率) **val** _recall_ = (0 until 7).map( //DecisionTreeModel模型的类别号从0开始 cat => ( _metrics_.precision(cat), _metrics_.recall(cat)) ) **val** _end_ = System.currentTimeMillis() //耗时时间 **var** _castTime_ = _end_ - _beg_ ____**def** main(args: Array[String]) { println("========================================================================================") //精确度(样本比例) println("精确度: " + _precision_ ) println("=========================================&#
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。