当前位置:   article > 正文

【Spark ML 系列】RandomForestClassifier RandomForestClassificationModel原理用法示例源码分析_随机森林spark代码

随机森林spark代码

Spark ML中RandomForestClassifier RandomForestClassificationModel原理示例源码分析点击这里看全文

原理

Spark ML中的随机森林分类器(RandomForestClassifier)是基于集成学习方法的一种分类模型。它由多个决策树组成,每个决策树都是通过对训练数据进行自助采样(bootstrap)和特征随机选择而生成的。

以下是Spark ML中随机森林分类器的工作原理:

  1. 数据准备:将输入的训练数据划分为若干个随机子样本。对于每个子样本,从原始数据集中有放回地采样相同数量的样本,形成一个新的训练集。同时,对于每个决策树,还会随机选择一部分特征用于构建树。

  2. 决策树的构建:对于每个子样本和随机选择的特征,使用决策树算法(如ID3、C4.5或CART)构建一个决策树模型。决策树的构建过程包括选择最佳的特征进行节点划分、递归地构建子树,直到达到停止条件(如树的深度达到预设值)。

  3. 集成学习:将所有构建好的决策树组合成随机森林模型。在分类问题中,每个决策树会根据样本的特征进行预测,并统计最终的类别投票结果。根据多数表决原则,选择票数最多的类别作为随机森林模型的最终预测结果。

  4. 特征重要性评估:在随机森林中,每个决策树都可以衡量特征的重要性。通过对所有决策树的特征重要性进行平均,得到整个随机森林模型的特征重要性评估。这可以帮助我们了解哪些特征对于分类结果的贡献更大。

  5. 预测:对于新的输入数据,随机森林模型会将该数据传递给每个决策树进行预测,然后根据决策树的投票结果得出最终的分类结果。

随机森林具有以下优点:

  • 可以处理大量的训练数据,并能够处理高维度的特征。
  • 对于缺失数据和噪声具有一定的鲁棒性。
  • 能够评估特征的重要性,用于特征选择和分析。
  • 在训练过程中,可以并行构建多个决策树,加快训练速度。

需要注意的是,随机森林模型的性能和泛化能力与决策树的数量、树的深度、特征选择策略等参数相关。在使用随机森林时,需要根据具体问题和数据集进行参数调优,以获得最佳的分类性能。

方法总结

RandomForestClassifier是Spark ML中用于分类任务的随机森林模型。下面是该类的一些重要方法的总结:

  • fit(dataset: Dataset[_]): RandomForestClassificationModel:使用给定的训练数据集拟合(训练)随机森林模型,并返回一个训练好的RandomForestClassificationModel对象。

  • setFeaturesCol(value: String): RandomForestClassifier:设置输入特征列的名称。

  • setPredictionCol(value: String): RandomForestClassifier:设置预测结果列的名称。

  • setLabelCol(value: String): RandomForestClassifier:设置标签列的名称,即目标变量。

  • setMaxDepth(value: Int): RandomForestClassifier:设置决策树的最大深度。

  • setNumTrees(value: Int): RandomForestClassifier:设置随机森林中决策树的数量。

  • setSubsamplingRate(value: Double): RandomForestClassifier:设置用于训练每个决策树的样本子集的比例。

  • setFeatureSubsetStrategy(value: String): RandomForestClassifier:设置特征子集选择策略,可以是"auto"、“all”、“onethird”、"sqrt"或"log2"之一。

  • setSeed(value: Long): RandomForestClassifier:设置随机数生成器的种子。

  • setImpurity(value: String): RandomForestClassifier:设置不纯度度量方法,可以是"entropy"(熵)或"gini"(基尼指数)之一。

  • setRawPredictionCol(value: String): RandomForestClassifier:设置原始预测结果列的名称。

  • setProbabilityCol(value: String): RandomForestClassifier:设置概率预测结果列的名称。

  • setWeightCol(value: String): RandomForestClassifier:设置样本权重列的名称。

  • setMaxBins(value: Int): RandomForestClassifier:设置连续特征离散化的最大箱数。

  • fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[RandomForestClassificationModel]:使用给定的训练数据集和参数网格搜索拟合多个随机森林模型,并返回一个包含多个训练好的模型的数组。

  • copy(extra: ParamMap): RandomForestClassifier:复制当前实例,可选地带有额外的参数。

这些方法允许您设置和调整随机森林模型的各种参数,以及在训练过程中控制模型的行为。通过适当选择和设置这些参数,可以优化模型的性能和预测准确度。

示例

以下是使用RandomForestClassifier进行分类任务的示例代码:

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{
   IndexToString, StringIndexer, VectorAssembler}
import org.apache.spark.ml.Pipeline

// 读取训练数据集
val data = spark.read.format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("path/to/training_data.csv")

// 创建特征向量列
val featureColumns = Array("feature1", "feature2", "feature3")
val assembler = new VectorAssembler()
  .setInputCols(featureColumns)
  .setOutputCol("features")

val assembledData = assembler.transform(data)

// 对标签进行索引化
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(assembledData)

// 拆分数据集为训练集和测试集
val Array(trainingData, testData) = assembledData.randomSplit(Array(0.7, 0.3))

// 创建随机森林分类器
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setNumTrees(10)

// 将索引化的标签转换回原始标签
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// 构建Pipeline
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, rf, labelConverter))

// 训练模型
val model = pipeline.fit(trainingData)

// 在测试集上进行预测
val predictions = model.transform(testData)

// 评估模型性能
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

val accuracy = evaluator.evaluate(predictions)
println("Accuracy: " + accuracy)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

在这个示例中,首先加载训练数据集,并创建特征向量列。然后对标签进行索引化,并将数据集拆分为训练集和测试集。接下来,创建一个RandomForestClassifier对象,并设置相关参数。然后,使用Pipeline构建一个包含数据转换和模型训练的流水线。通过调用fit方法来训练模型。

最后,在测试集上进行预测并评估模型的性能。在这个示例中,我们使用了多分类准确度(accuracy)作为评估指标。

中文源码

class RandomForestClassifier

/**
 * 随机森林(Random Forest)分类学习算法。
 * 支持二进制和多类标签,以及连续和分类特征。
 */
@Since("1.4.0")
class RandomForestClassifier @Since("1.4.0") (
    @Since("1.4.0") override val uid: String)
  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
  with RandomForestClassifierParams with DefaultParamsWritable {
   

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("rfc"))

  // 为了与Java API兼容性,重写父trait中的参数设置方法。

  // TreeClassifierParams中的参数:

  /** 设置树的最大深度 */
  @Since("1.4.0")
  override def setMaxDepth(value: Int): this.type = set(maxDepth, value)

  /** 设置每个节点的最大分箱数 */
  @Since("1.4.0")
  override def setMaxBins(value: Int): this.type = set(maxBins, value)

  /** 设置每个节点的最小实例数 */
  @Since("1.4.0")
  override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

  /** 设置节点分裂所需的最小信息增益 */
  @Since("1.4.0")
  override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

  /** 设置算法使用的内存上限 */
  @Since("1.4.0")
  override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)

  /** 设置是否缓存节点ID */
  @Since("1.4.0")
  override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)

  /**
   * 设置检查点的频率,即多少次迭代进行一次缓存检查点。
   * 仅在设置了cacheNodeIds为true并且在SparkContext中设置了检查点目录时才会使用。
   * 必须至少为1。
   * 默认值为10。
   */
  @Since("1.4.0")
  override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

  /** 设置不纯度度量方法 */
  @Since("1.4.0")
  override def setImpurity(value: String): this.type = set(impurity, value)

  // TreeEnsembleParams中的参数:

  /** 设置子采样率 */
  @Since("1.4.0")
  override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)

  /** 设置随机种子 */
  @Since("1.4.0")
  override def setSeed(value: Long): this.type = set(seed, value)

  // RandomForestParams中的参数:

  /** 设置树的数量 */
  @Since("1.4.0")
  override def setNumTrees(value: Int): this.type = set(numTrees, value)

  /** 设置特征子集策略 */
  @Since("1.4.0")
  override def setFeatureSubsetStrategy(value: String): this.type =
    set(featureSubsetStrategy, value)

  override protected def train(
      dataset: Dataset[_]): RandomForestClassificationModel = instrumented {
    instr =>
    instr.logPipelineStage(this)
    instr.logDataset(dataset)
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    val numClasses: Int = getNumClasses(dataset)

    if (isDefined(thresholds)) {
   
      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
        ".train() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${
     $(thresholds).length}")
    }

    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
    val strategy =
      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

    instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
      impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
      minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

    val trees = RandomForest
      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
      .map(_.asInstanceOf[DecisionTreeClassificationModel])

    val numFeatures = oldDataset.first().features.size
    instr.logNumClasses(numClasses)
    instr.logNumFeatures(numFeatures)
    new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114

object RandomForestClassifier

object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
   
  /** 支持的不纯度度量方法:熵(entropy)、基尼指数(gini) */
  @Since("1.4.0")
  final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities

  /** 支持的特征子集策略:自动选择(auto)、全部(all)、三分之一(onethird)、平方根(sqrt)、对数(log2) */
  @Since("1.4.0")
  final val supportedFeatureSubsetStrategies: Array[String] =
    TreeEnsembleParams.supportedFeatureSubsetStrategies

  /** 加载模型 */
  @Since("2.0.0")
  override def load(path: String): RandomForestClassifier = super.load(path)
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

这部分代码定义了RandomForestClassifier对象,提供了一些静态方法和常量:

  • supportedImpurities支持的不纯度度量方法,包括熵(entropy)和基尼指数(gini)。
  • supportedFeatureSubsetStrategies:支持的特征子集策略,包括自动选择(auto)、全部(all)、三分之一(onethird)、平方根(sqrt)和对数(log2)。
  • load方法:用于加载模型。

class RandomForestClassificationModel

/**
 * 用于分类的随机森林(Random Forest)模型。
 *
 * @param _trees  集成中的决策树数组。
 *                注意:这些树的父节点为null。
 */
@Since("1.4.0")
class RandomForestClassificationModel private[ml] (
    @Since("1.5.0") override val uid: String,
    private val _trees: Array[DecisionTreeClassificationModel],
    @Since("1.6.0") override val numFeatures: Int,
    @Since("1.5.0") override val numClasses: Int)
  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
  with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
  with MLWritable with Serializable {
   

  require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")

  /**
   * 构造随机森林分类模型,所有树的权重相等。
   *
   * @param trees  组成模型的决策树数组
   */
  private[ml] def this(
      trees: Array[DecisionTreeClassificationModel],
      numFeatures: Int,
      numClasses: Int) =
    this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)

  @Since("1.4.0")
  override def trees: Array[DecisionTreeClassificationModel] = _trees

  // 注意:我们可能会在以后添加根据树性能进行加权的支持。
  private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)

  @Since("1.4.0")
  override def treeWeights: Array[Double] = _treeWeights

  /**
   * 将模型应用于数据集,生成预测结果的转换操作。
   *
   * @param dataset  输入的数据集
   * @return         包含预测结果的新DataFrame
   */
  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
   
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    val predictUDF = udf {
    (features: Any) =>
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

  /**
   * 根据输入特征向量生成原始预测结果。
   *
   * @param features  输入的特征向量
   * @return          原始预测结果向量
   */
  override protected def predictRaw(features: Vector): Vector = {
   
    // TODO: 当我们添加通用的Bagging类时,将在那里处理:SPARK-7128
    // 使用多数表决进行分类。
    // 目前忽略树权重,因为都是1.0。
    val votes = Array.fill[Double](numClasses)(0.0)
    _trees.view.foreach {
    tree =>
      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
      val total = classCounts.sum
      if (total 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/619940
推荐阅读
相关标签
  

闽ICP备14008679号