
Spark MLlib Source Code Analysis: Random Forest (Part 4)

Spark Source Code Analysis: Random Forest (Part 1)
Spark Source Code Analysis: Random Forest (Part 2)
Spark Source Code Analysis: Random Forest (Part 3)
Spark Source Code Analysis: Random Forest (Part 5)

6.4. Node splitting

The logic lives mainly in the DecisionTree.findBestSplits function, which is the core of RF training:

DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
6.4.1. Collecting statistics

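Statistics are collected in two stages: each partition first computes its own statistics, and the per-partition results are then merged into global statistics.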
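The two stages can be illustrated without Spark. The following is a minimal sketch, assuming a hypothetical `NodeStats` class that only counts points (the real DTStatsAggregator keeps sufficient statistics per node, feature, and bin) and plain Scala collections standing in for RDD partitions. `aggregatePartition` and `mergePartitions` are illustrative names, not Spark APIs.

```scala
object TwoStageAggregation {
  // Hypothetical stand-in for DTStatsAggregator: it only counts points per node.
  final case class NodeStats(count: Long) {
    def add(n: Long): NodeStats = NodeStats(count + n)
    def merge(other: NodeStats): NodeStats = NodeStats(count + other.count)
  }

  // Stage 1: a single pass over one partition. Each element is the index of
  // the node that a data point falls into.
  def aggregatePartition(nodeIndices: Seq[Int], numNodes: Int): Seq[(Int, NodeStats)] = {
    val stats = Array.fill(numNodes)(NodeStats(0L))
    nodeIndices.foreach { i => stats(i) = stats(i).add(1L) }
    // Emit (nodeIndex, stats) pairs so they can be merged by key later.
    stats.zipWithIndex.map(_.swap).toSeq
  }

  // Stage 2: merge the per-partition results by node index. This plays the
  // role that reduceByKey plays on an RDD.
  def mergePartitions(partials: Seq[Seq[(Int, NodeStats)]]): Map[Int, NodeStats] =
    partials.flatten.groupBy(_._1).map { case (nodeIndex, pairs) =>
      nodeIndex -> pairs.map(_._2).reduce(_ merge _)
    }
}
```

Keeping one aggregator per node inside each partition means that only one small object per node leaves each partition, rather than one record per data point.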

6.4.1.1. Extracting each node's feature subset
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

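This extracts the feature subset for each node. If feature subsampling is not needed, the result is None; otherwise it is a Map[Int, Array[Int]]. In effect, the NodeIndexInfo entries held in treeToNodeToIndexInfo are converted into a map, which is then broadcast as nodeToFeaturesBc.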
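The conversion described above can be sketched as follows. This is a simplified illustration, not the Spark source: the `NodeIndexInfo` case class is a stand-in with only two fields, the field name `nodeIndexInGroup` and the `subsamplingFeatures` flag are assumptions made for this example, and nodes without a feature subset are silently skipped here.

```scala
object NodeToFeaturesSketch {
  // Simplified stand-in for NodeIndexInfo (assumed fields, for illustration only).
  final case class NodeIndexInfo(nodeIndexInGroup: Int, featureSubset: Option[Array[Int]])

  // Flattens tree -> node -> info into nodeIndexInGroup -> feature subset.
  // Returns None when features are not subsampled.
  def getNodeToFeatures(
      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
      subsamplingFeatures: Boolean): Option[Map[Int, Array[Int]]] = {
    if (!subsamplingFeatures) {
      None
    } else {
      val pairs = treeToNodeToIndexInfo.values.flatMap(_.values).flatMap { info =>
        info.featureSubset.map(features => info.nodeIndexInGroup -> features).toList
      }
      Some(pairs.toMap)
    }
  }
}
```

Because the result is keyed by the node's index within the group, a partition can look up the feature subset with the same index it uses for the aggregator array.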

6.4.1.2. Per-partition statistics

This step is a chain of function calls, which we analyze layer by layer.

val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {
      input.mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOp(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }

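First, an array of DTStatsAggregator is built for each partition, with length equal to the number of nodes. Note that a plain array is used here, so how does a node find its own aggregator? As mentioned earlier, the first member of NodeIndexInfo is groupIndex, whose value is the node's position within the group, and that position is also the index into this aggregator array. In other words, groupIndex can be read from NodeIndexInfo and used as the array index to fetch the node's aggregator. The DTStatsAggregator constructor takes metadata and the node's feature subset. Each point is then accumulated into its DTStatsAggregator by calling the inner function binSeqOp: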

 /**
     * Performs a sequential aggregation over a partition.
     *
     * Each data point contributes to one node. For each feature,
     * the aggregate sufficient statistics are updated for the relevant bins.
     *
     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
     *             each (node, feature, bin).
     * @param baggedPoint   Data point being aggregated.
     * @return  agg
     */
    def binSeqOp(
        agg: Array[DTStatsAggregator],
        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
      // For each tree in the group, find the node that this point falls into
      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
        val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
          bins, metadata.unorderedFeatures)
        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
      }

      agg
    }

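binSeqOp first calls predictNodeIndex to compute nodeIndex. In the first round, or when a leaf node is reached, it returns node.id directly. Otherwise, because the node passed in is each tree's root, it starts from the root and descends step by step to determine which node the point belongs to. Since the nodes have already been split, this descent effectively partitions the samples. For example, if the node currently being processed is the root's left child but the point is predicted to fall into the right child, nodeBinSeqOp returns immediately and the point is not counted. This spends a modest amount of time on prediction to avoid the space cost of materializing the partitioned sample sets, which is rather neat.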
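To make the descent concrete, here is a minimal sketch under simplifying assumptions. The `Node` case class, the "go left when the binned value is at most the threshold" rule, and the `shouldAggregate` helper are all hypothetical. The actual predictNodeIndex also takes `bins` and `metadata.unorderedFeatures` (as the call in binSeqOp shows), which this sketch ignores.

```scala
object PredictNodeIndexSketch {
  // Hypothetical, simplified tree node. A leaf has no split; an internal node
  // sends a point left when its binned feature value is <= the threshold bin.
  final case class Node(
      id: Int,
      split: Option[(Int, Int)] = None, // (featureIndex, threshold bin)
      left: Option[Node] = None,
      right: Option[Node] = None)

  // Walk down from the given node until reaching a node that has no split
  // or no children yet, and return that node's id.
  @annotation.tailrec
  def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int =
    (node.split, node.left, node.right) match {
      case (Some((featureIndex, threshold)), Some(l), Some(r)) =>
        if (binnedFeatures(featureIndex) <= threshold) predictNodeIndex(l, binnedFeatures)
        else predictNodeIndex(r, binnedFeatures)
      case _ => node.id
    }

  // Only nodes being split in the current group have an entry; a point that
  // lands on any other node is skipped instead of being aggregated.
  def shouldAggregate(nodeId: Int, nodesInGroup: Set[Int]): Boolean =
    nodesInGroup.contains(nodeId)
}
```

In this sketch, a point predicted to reach node 3 while only node 2 is in the current group is simply ignored, which mirrors the early return described above.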

/**
   * Get the node index corresponding to this data point.
   * Thi