Spark MLlib Source Code Analysis: Random Forest (Part 2)

Related articles
Spark Source Code Analysis: Random Forest (Part 1)
Spark Source Code Analysis: Random Forest (Part 3)
Spark Source Code Analysis: Random Forest (Part 4)
Spark Source Code Analysis: Random Forest (Part 5)
Spark Source Code Analysis: DecisionTree and GBDT

4. Feature Processing

This part lives mainly in the findSplitsBins function of DecisionTree.scala: every feature is wrapped into Splits, which are then binned into Bins. We first describe the structure of Split and Bin.

4.1. Data Structures

4.1.1. Split
class Split(
    @Since("1.0.0") feature: Int,
    @Since("1.0.0") threshold: Double,
    @Since("1.0.0") featureType: FeatureType,
    @Since("1.0.0") categories: List[Double])
  • feature: the feature id
  • threshold: the threshold (for continuous features)
  • featureType: continuous feature (Continuous) / categorical feature (Categorical)
  • categories: array of category values, used only for categorical features; holds all the feature values covered by this split
4.1.2. Bin
class Bin(
    lowSplit: Split, 
    highSplit: Split, 
    featureType: FeatureType, 
    category: Double)
  • lowSplit/highSplit: the lower and upper bound splits
  • featureType: continuous feature (Continuous) / categorical feature (Categorical)
  • category: the feature value, for categorical features
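
To make the Split/Bin relationship concrete, here is a standalone sketch. The case classes and the FeatureType stand-in are simplified stand-ins for the MLlib classes above, and makeBins is a hypothetical helper: for a continuous feature, adjacent splits are paired into bins, with sentinel splits padding the two ends, so numBins = numSplits + 1.

```scala
// Simplified stand-ins for the MLlib Split/Bin classes shown above
// (FeatureType is modeled as a plain String here); a sketch, not the real API.
object FeatureType {
  val Continuous = "Continuous"
  val Categorical = "Categorical"
}

case class Split(feature: Int, threshold: Double, featureType: String,
                 categories: List[Double])

case class Bin(lowSplit: Split, highSplit: Split, featureType: String,
               category: Double)

object SplitBinDemo {
  // Pair adjacent splits into bins, padding the ends with -inf/+inf sentinels,
  // so that every bin has both a lower and an upper bound.
  def makeBins(thresholds: Array[Double]): List[Bin] = {
    val splits = thresholds.map(t => Split(0, t, FeatureType.Continuous, Nil)).toList
    val low = Split(0, Double.MinValue, FeatureType.Continuous, Nil)
    val high = Split(0, Double.MaxValue, FeatureType.Continuous, Nil)
    val padded = (low +: splits) :+ high
    padded.sliding(2).map { case Seq(lo, hi) =>
      Bin(lo, hi, FeatureType.Continuous, 0.0)
    }.toList
  }

  def main(args: Array[String]): Unit = {
    // Two thresholds (2.0, 5.0) give three bins: (-inf,2.0], (2.0,5.0], (5.0,+inf)
    println(makeBins(Array(2.0, 5.0)).length) // 3
  }
}
```

This illustrates why a continuous feature with numBins bins needs only numBins - 1 real splits: the outermost bounds are sentinels, not data-derived thresholds.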

4.2. Continuous Feature Processing

4.2.1. Sampling
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
      // Calculate the number of samples for approximate quantile calculation.
      val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
      val fraction = if (requiredSamples < metadata.numExamples) {
        requiredSamples.toDouble / metadata.numExamples
      } else {
        1.0
      }
      logDebug("fraction of data used for calculating quantiles = " + fraction)
      input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
    } else {
      input.sparkContext.emptyRDD[LabeledPoint]
    }

First the set of continuous features is selected, then the required sample size and the sampling fraction are computed, and a without-replacement sample is drawn; if there are no continuous features, an empty RDD is returned.
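
The fraction arithmetic above can be checked in isolation. In this sketch, SampleFraction and its fraction helper are hypothetical names; the maxBins and numExamples parameters mirror the metadata fields used in the excerpt.

```scala
// Standalone sketch of the sampling-fraction computation above:
// the number of samples needed for approximate quantiles is
// max(maxBins^2, 10000), capped at the full dataset.
object SampleFraction {
  def fraction(maxBins: Int, numExamples: Long): Double = {
    val requiredSamples = math.max(maxBins * maxBins, 10000)
    if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples
    else 1.0 // fewer rows than required: use the whole dataset
  }

  def main(args: Array[String]): Unit = {
    // With the default maxBins = 32: max(1024, 10000) = 10000 required samples,
    // so a dataset of 1,000,000 rows is sampled at 1%.
    println(SampleFraction.fraction(32, 1000000L)) // 0.01
    // A small dataset is used in full.
    println(SampleFraction.fraction(32, 5000L))    // 1.0
  }
}
```

Note that for the default maxBins = 32 the quadratic term is irrelevant; it only kicks in once maxBins exceeds 100.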

4.2.2. 计算Split
metadata.quantileStrategy match {
      case Sort =>
        findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
      case MinMax =>
        throw new UnsupportedOperationException("minmax not supported yet.")
      case ApproxHist =>
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }

Of the quantile strategies, only Sort is implemented here (as noted in the previous part); MinMax and ApproxHist throw UnsupportedOperationException. The computation below happens in the findSplitsBinsBySorting function, whose arguments are the sampled data, the metadata, and the set of continuous features (these are feature ids starting from 0; see the LabeledPoint constructor).

val continuousSplits = {
    // reduce the parallelism for split computations when there are less
    // continuous features than input partitions. this prevents tasks from
    // being spun up that will definitely do no work.
    val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
    input.flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
         .groupByKey(numPartitions)
         .map { case (k, v) => findSplits(k, v) }
         .collectAsMap()
    }

The feature id is the key, and the value is the collection of that feature's values across all sampled points; each group is passed to the findSplits function, which in turn calls findSplitsForContinuousFeature to obtain the Splits for a continuous feature. Its arguments are the sampled values, the metadata, and the feature id.
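
What the flatMap + groupByKey pipeline produces can be sketched with plain Scala collections instead of an RDD (GroupByFeature and group are hypothetical names; groupBy plays the role of groupByKey here):

```scala
// Sketch of the shuffle above with local collections: each sampled point is
// exploded into (featureId, value) pairs, then all values of one feature are
// gathered under that feature's id.
object GroupByFeature {
  def group(points: Seq[Array[Double]], continuousFeatures: Seq[Int])
      : Map[Int, Seq[Double]] =
    points
      .flatMap(p => continuousFeatures.map(idx => (idx, p(idx))))
      .groupBy(_._1)                                  // local analogue of groupByKey
      .map { case (idx, pairs) => idx -> pairs.map(_._2) }

  def main(args: Array[String]): Unit = {
    val points = Seq(Array(1.0, 10.0), Array(2.0, 20.0), Array(3.0, 30.0))
    // Feature 1's values are all collected under key 1.
    println(GroupByFeature.group(points, Seq(0, 1))(1)) // List(10.0, 20.0, 30.0)
  }
}
```

After this grouping, each task holds the complete value list for one feature, which is why the split computation for that feature can proceed independently.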

def findSplitsForContinuousFeature(
    featureSamples: Array[Double],
    metadata: DecisionTreeMetadata,
    featureIndex: Int): Array[Double] = {
  require(metadata.isContinuous(featureIndex),
    "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

  val splits = {
    // a continuous feature has numBins - 1 splits
    val numSplits = metadata.numSplits(featureIndex)
    // count the occurrences of each feature value
    // get count for 
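
The excerpt above is truncated; it continues by counting value occurrences and then walking the sorted (value, count) pairs. The following standalone sketch mirrors that core loop (ContinuousSplits is a hypothetical name, and the real MLlib code has additional edge-case handling): a split boundary is emitted whenever the running count is closest to the next multiple of stride = numSamples / (numSplits + 1), so the resulting bins hold roughly equal numbers of samples.

```scala
// Simplified sketch of split selection for one continuous feature:
// approximate equi-depth quantiles over the sampled values.
object ContinuousSplits {
  def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
    // count occurrences of each distinct value, sorted ascending by value
    val valueCounts = featureSamples
      .groupBy(identity)
      .map { case (v, occurrences) => (v, occurrences.length) }
      .toArray
      .sortBy(_._1)

    if (valueCounts.length <= numSplits + 1) {
      // few distinct values: every value boundary becomes a split
      valueCounts.map(_._1).dropRight(1)
    } else {
      // aim for one split every `stride` samples
      val stride = featureSamples.length.toDouble / (numSplits + 1)
      val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
      var index = 1
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // place the split at whichever boundary is closer to the target count
        if (math.abs(previousCount - targetCount) <
            math.abs(currentCount - targetCount)) {
          splits += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      splits.toArray
    }
  }

  def main(args: Array[String]): Unit = {
    // 12 samples, 2 splits requested -> boundaries near every 4th sample.
    val samples = Array(1.0, 1.0, 2.0, 2.0, 3.0, 3.0,
                        4.0, 4.0, 5.0, 5.0, 6.0, 6.0)
    println(ContinuousSplits.findSplits(samples, 2).mkString(",")) // 2.0,4.0
  }
}
```

The gap comparison (previous vs. current count against the target) is what makes the selection robust to heavily repeated values: a value that accounts for many samples is not skipped past the target, it simply absorbs it.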