This part lives mainly in the findSplitsBins function of DecisionTree.scala, which wraps every feature's candidate thresholds into Split objects and then groups them into Bin objects. First, the structure of Split and Bin:
class Split(
    @Since("1.0.0") feature: Int,
    @Since("1.0.0") threshold: Double,
    @Since("1.0.0") featureType: FeatureType,
    @Since("1.0.0") categories: List[Double])

class Bin(
    lowSplit: Split,
    highSplit: Split,
    featureType: FeatureType,
    category: Double)
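To make the two structures concrete, here is a small self-contained sketch. The Split/Bin declarations below are illustrative stand-ins (the real classes live in Spark's org.apache.spark.mllib.tree package and are not constructed by user code); the thresholds are made up.

```scala
// Illustrative re-declarations, NOT the actual Spark classes.
object FeatureType extends Enumeration {
  val Continuous, Categorical = Value
}
import FeatureType._

case class Split(feature: Int, threshold: Double,
                 featureType: FeatureType.Value, categories: List[Double])
case class Bin(lowSplit: Split, highSplit: Split,
               featureType: FeatureType.Value, category: Double)

// Two splits on continuous feature 0: values <= 1.5 go left of s0, etc.
// categories is only used for categorical features, so it stays empty here.
val s0 = Split(feature = 0, threshold = 1.5, featureType = Continuous, categories = Nil)
val s1 = Split(feature = 0, threshold = 3.0, featureType = Continuous, categories = Nil)

// The bin between the two splits covers the value range (1.5, 3.0];
// category is likewise unused for a continuous feature.
val midBin = Bin(lowSplit = s0, highSplit = s1, featureType = Continuous, category = 0.0)
```

So for a continuous feature, each Bin is simply the interval between two adjacent Splits.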
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
  // Calculate the number of samples for approximate quantile calculation.
  val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
  val fraction = if (requiredSamples < metadata.numExamples) {
    requiredSamples.toDouble / metadata.numExamples
  } else {
    1.0
  }
  logDebug("fraction of data used for calculating quantiles = " + fraction)
  input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
} else {
  input.sparkContext.emptyRDD[LabeledPoint]
}
This first filters out the set of continuous features, then computes the required sample count and the sampling fraction, and draws a sample without replacement; if there are no continuous features, an empty RDD is used instead.
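The fraction arithmetic above can be traced without Spark; the sketch below plugs in made-up values for maxBins and numExamples to show the typical case where the 10000 floor dominates.

```scala
// Made-up values standing in for metadata.maxBins and metadata.numExamples.
val maxBins = 32
val numExamples = 1000000L

// 32 * 32 = 1024 < 10000, so the 10000 floor wins here.
val requiredSamples = math.max(maxBins * maxBins, 10000)

// Sample just enough of the data to estimate quantiles; use it all if small.
val fraction =
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples
  else 1.0
```

With a million rows this yields a fraction of 0.01, i.e. roughly 10000 rows are sampled for the approximate quantile computation.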
metadata.quantileStrategy match {
  case Sort =>
    findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
  case MinMax =>
    throw new UnsupportedOperationException("minmax not supported yet.")
  case ApproxHist =>
    throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
Of the quantile strategies, only Sort is implemented here, as noted earlier. The computation below happens in the findSplitsBinsBySorting function, whose arguments are the sampled data, the metadata, and the set of continuous features (these are feature ids starting from 0; see the LabeledPoint constructor).
val continuousSplits = {
  // reduce the parallelism for split computations when there are less
  // continuous features than input partitions. this prevents tasks from
  // being spun up that will definitely do no work.
  val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
  input.flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
    .groupByKey(numPartitions)
    .map { case (k, v) => findSplits(k, v) }
    .collectAsMap()
}
The feature id is the key, and the value is the collection of that feature's values across all samples. Each pair is handed to the findSplits function, which in turn calls findSplitsForContinuousFeature to obtain the Splits for a continuous feature; its arguments are the feature samples, the metadata, and the feature id.
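The flatMap-then-groupByKey pairing above can be mimicked on plain Scala collections, which makes the data movement easier to see. The feature ids and feature vectors below are made up; groupBy plays the role of the RDD's groupByKey.

```scala
// Made-up data: two continuous features (ids 0 and 2) across three samples.
val continuousFeatures = Seq(0, 2)
val points = Seq(
  Array(1.0, 5.0, 9.0),
  Array(2.0, 6.0, 8.0),
  Array(3.0, 7.0, 7.0))

// Emit (featureIndex, value) pairs, then collect all values per feature id,
// mirroring input.flatMap(...).groupByKey(...) in the Spark code.
val grouped: Map[Int, Seq[Double]] =
  points.flatMap(p => continuousFeatures.map(idx => (idx, p(idx))))
    .groupBy(_._1)
    .map { case (k, vs) => (k, vs.map(_._2)) }

// grouped(0) now holds feature 0's column: 1.0, 2.0, 3.0
```

Each entry of `grouped` is exactly what one findSplits call receives: a feature id and that feature's full column of values.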
def findSplitsForContinuousFeature(
    featureSamples: Array[Double],
    metadata: DecisionTreeMetadata,
    featureIndex: Int): Array[Double] = {
  require(metadata.isContinuous(featureIndex),
    "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
  val splits = {
    // a continuous feature has numBins - 1 splits
    val numSplits = metadata.numSplits(featureIndex)
    // count the occurrences of each distinct feature value