Building a GBDT Model with Spark MLlib

The GBDT Model

My introduction to GBDT mainly follows this post: http://blog.csdn.net/w28971023/article/details/8240756
Here I summarize the key points:
1. All trees in GBDT are regression trees, not classification trees.
2. The best split point of a regression tree node is the one that minimizes the squared error (variance) of the resulting partitions.
3. The core of GBDT is that each tree fits the residual of the sum of all previous trees' predictions; the residual is the amount that, when added to the current prediction, yields the true value.
4. GB stands for Gradient Boosting. The main benefit of boosting is that the residual computed at each step implicitly increases the weight of wrongly predicted instances, while the residual of correctly predicted instances tends toward zero.
5. GBDT uses a Shrinkage strategy: in essence, Shrinkage assigns each tree a weight (a learning rate) that its contribution is multiplied by when the trees are summed; it is independent of the gradient itself.
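The residual-fitting update in points 3 and 5 can be written compactly. With shrinkage (learning rate) $\nu$, the model after $m$ trees is:

```latex
F_m(x) = F_{m-1}(x) + \nu \, h_m(x), \qquad
h_m \approx \arg\min_h \sum_{i=1}^{N} \big( y_i - F_{m-1}(x_i) - h(x_i) \big)^2
```

For squared-error loss the negative gradient is exactly the residual $y_i - F_{m-1}(x_i)$, which is why "fitting the residual" and "gradient boosting" coincide in this setting.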

Building the GBDT model with Spark

Training the model

public void trainModel() {

        // initialize Spark
        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
        conf.set("spark.testing.memory", "2147480000");
        SparkContext sc = new SparkContext(conf);

        // load the training file (LibSVM format) via MLUtils
        JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.trainsetFile).toJavaRDD();

        // train the model; squared error (variance) is the default impurity for regression
        int numIteration = 10;  // number of boosting iterations (trees)
        int maxDepth = 3;       // maximum depth of each regression tree
        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
        boostingStrategy.setNumIterations(numIteration);
        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);
        // an empty map means all features are treated as continuous
        Map<Integer, Integer> categoricalFeaturesInfoMap = new HashMap<Integer, Integer>();
        boostingStrategy.getTreeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfoMap);
        // train and save the GBDT model
        final GradientBoostedTreesModel model = GradientBoostedTrees.train(lpdata, boostingStrategy);
        model.save(sc, this.modelpath);
        sc.stop();
    }
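MLUtils.loadLibSVMFile expects the LibSVM text format: one example per line, a numeric label followed by 1-based index:value pairs for the non-zero features. A regression training file might look like this (made-up values for illustration):

```
1.5 1:0.2 3:0.7 10:1.0
0.0 2:0.5 4:0.1
2.3 1:0.9 2:0.4 7:0.3
```

Feature indices must be in ascending order on each line; omitted indices are treated as zeros.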

Making predictions

public void predict() {
        // initialize Spark
        SparkConf conf = new SparkConf().setAppName("GBDT").setMaster("local");
        conf.set("spark.testing.memory", "2147480000");
        SparkContext sc = new SparkContext(conf);

        // load the saved GBDT model
        final GradientBoostedTreesModel model = GradientBoostedTreesModel.load(sc, this.modelpath);

        // load the test file
        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictFile).toJavaRDD();
        testData.cache();

        // predict: map each point to a (prediction, label) pair
        JavaRDD<Tuple2<Double, Double>> predictionAndLabel = testData.map(new Prediction(model));

        // compute the mean squared error over the whole test set
        Double testMSE = predictionAndLabel.map(new CountSquareError()).reduce(new ReduceSquareError()) / testData.count();
        System.out.println("testData's MSE is: " + testMSE);
        sc.stop();
    }

    static class Prediction implements Function<LabeledPoint, Tuple2<Double, Double>> {
        GradientBoostedTreesModel model;

        public Prediction(GradientBoostedTreesModel model) {
            this.model = model;
        }

        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
            Double score = model.predict(p.features());
            return new Tuple2<Double, Double>(score, p.label());
        }
    }

    // squared error of a single (prediction, label) pair
    static class CountSquareError implements Function<Tuple2<Double, Double>, Double> {
        public Double call(Tuple2<Double, Double> pl) {
            double diff = pl._1() - pl._2();
            return diff * diff;
        }
    }

    // sum the squared errors across the RDD
    static class ReduceSquareError implements Function2<Double, Double, Double> {
        public Double call(Double a, Double b) {
            return a + b;
        }
    }
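The map/reduce pipeline above is just the mean squared error. A plain-Java sketch of the same computation (hypothetical class name, no Spark dependency) makes the arithmetic explicit:

```java
public class MseExample {
    // same two steps as the Spark pipeline: square each (prediction - label),
    // sum the squares, then divide by the number of examples
    static double mse(double[] predictions, double[] labels) {
        double sum = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            double diff = predictions[i] - labels[i];  // CountSquareError
            sum += diff * diff;                        // ReduceSquareError
        }
        return sum / predictions.length;               // divide by testData.count()
    }

    public static void main(String[] args) {
        double[] preds  = {1.0, 2.0, 3.0};
        double[] labels = {1.0, 2.0, 5.0};
        System.out.println("MSE = " + mse(preds, labels)); // (0 + 0 + 4) / 3
    }
}
```

Spark distributes the map and reduce steps across partitions, but because addition is associative and commutative the result is identical to this sequential loop.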

The full code is available on my GitHub: https://github.com/Quincy1994/MachineLearning
