位于ml/tree/impl/目录下。mllib目录下的随机森林算法也是调用的ml下的RandomForest。ml是mllib的最新实现,将来是要替换掉mllib库的。
-
- RandomForest核心代码
- train方法
- RandomForest核心代码
每次迭代将要计算的node推入堆栈,选择参与计算的抽样数据,计算该节点,循环该过程。
while (nodeStack.nonEmpty) {
// Collect some nodes to split, and choose features for each node (if subsampling).
// Each group of nodes may come from one or multiple trees, and at multiple levels.
val (nodesForGroup, treeToNodeToIndexInfo) =
RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
// Sanity check (should never occur):
assert(nodesForGroup.nonEmpty,
s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
// Only send trees to worker if they contain nodes being split this iteration.
val topNodesForGroup: Map[Int, LearningNode] =
nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
RandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup,
treeToNodeToIndexInfo, splits, nodeStack, timer, nodeIdCache)
timer.stop("findBestSplits")
}
-
- RandomForest算法
- training
- RandomForest算法
nodesForGroup:本次等待处理的节点集合。
topNodesForGroup:nodesForGroup所对应的每颗树的根节点。
def run(
input: RDD[LabeledPoint],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation[_]],
parentUID: Option[String] = None): Array[DecisionTreeModel]
run方法返回DecisionTreemodel数组,每个成员是一个决策树,森林对每个决策树预测值加权得到最终预测结果。
循环处理节点:
(1)RandomForest.selectNodesToSplit
(2)RandomForest.findBestSplits
直到所有nodes都处理完毕,则循环结束,开始构造决策树模型,创建DecisionTreeClassificationModel。
所以这里最关键的是下面两个方法:
(1)RandomForest.selectNodesToSplit
(2)RandomForest.findBestSplits
-
-
- selectNodesToSplit
-
选择进行切分的节点。根据内存等状态选择本次切分的节点集合。返回(NodesForGroup,TreeToNodeToIndexInfo)。该方法的作用就是检查内存是否够用,在内存足够的情况下其实可以忽略该函数。
森林的每个树顶点保存在stack中,该方法从此stack中找出可以进行切分的节点,然后调用findBestSplits方法构造决策树。stack中的元素是动态变化的。
数据结构:
NodesForGroup:HashMap[Int, mutable.ArrayBuffer[LearningNode]]
key是treeIndex,value是node列表,表示属于该tree的node列表。
TreeToNodeToIndexInfo:HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]
key是treeIndex。
value是HashMap,其中key是nodeId,value是nodeIndexInfo(有featureSubset属性和本次group内的node数目)。由selectNodesToSplit方法创建该对象。featureSubset就是本节点需要处理的特征集合(是所有特征的子集)。
-
-
- findBestSplits
-
随机森林的【主函数】,找到最好切分。
重点分析:
/**
* Given a group of nodes, this finds the best split for each node.
*
* @param input Training data: RDD of [[TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodesForGroup For each tree in group, tree index -> root node.
* Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
* @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
* where nodeIndexInfo stores the index in the group and the
* feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param nodeStack Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
* @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
* each value in the array is the data point's node Id
* for a corresponding tree. This is used to prevent the need
* to pass the entire tree to the executors during
* the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodesForGroup: Map[Int, LearningNode],
nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
nodeStack: mutable.Stack[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
nodeIdCache: Option[NodeIdCache] = None): Unit = {
。。。
}
寻找最优切分的函数。
为简化代码分析,忽略代码中优化部分(入cache机制等)。
-
-
- findSplits
-
找出splits,供选择最优分解特征值算法使用。
findSplitsBySorting:实际完成findSplits功能。
-
-
- binsToBestSplit
-
也是重点方法。
寻找当前node的最优特征和特征值,findBestSplits会调用到。
包含两层循环,一是特征循环,内部再嵌套该特征的特征值增益循环计算。最后找出最优解。
步骤:
首先获取要spit的节点的level。获取node增益状态。
过滤合法的split,如果某特feature的split为空,则忽略。
/**
* Find the best split for a node.
*
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
node: LearningNode): (Split, ImpurityStats) = {
。。。
}
-
-
- calculateImpurityStats
-
计算节点左右子数的增益或者熵。
calculateImpurityStats
gain(增益)= 父node的impurity-左子数的impurity*权重-右子数的impurity*权重。
-
-
- extractMultiClassCategories
-
从离散型数值抽取出多个classLabel,和findSplitsForContinuousFeature对应。
返回离散的分割类别。
-
-
- findSplitsForContinuousFeature
-
对连续特征抽取分割线,比如等分划分特征最小值和最大值之间的距离,划分成N个split,每个split包含一个合理划分连续数值的分割点,分割点是一个double数值。
主要输入参数:每条记录的对应feature值的数组。
返回各个分割的阈值。
-
-
- aggregateSizeForNode
-
计算每个node的统计汇总维度,对于分类模型,总的统计维度=分类类别数*总的bin数(也就是每个特征的可枚举数目)。
-
- 决策树:DecisionTreeClassifier
单个决策树,构造随机森林的参数,设置子树的数目为1,然后调用随机森林算法RandomForest生成决策森林,返回第一个节点。
-
- GBT分类
梯度提升决策树算法。