
Spark GBDT Source Code Analysis (Gradient Boosted Decision Trees: GBTClassifier for classification, GBTRegressor for regression) (Spark 2.2.0)


Introduction to the GBDT algorithm

【Overview】

GBDT (Gradient Boosted Decision Trees) is an ensemble learning algorithm. Conveniently, Spark ships two implementations of it in MLlib: GBTClassifier and GBTRegressor.

【Spark implementation: computation flow】

       1. If the current implementation is GBTClassifier, check whether the training labels contain any value other than 0 and 1; if so, abort with an exception, otherwise remap 0/1 to -1/+1. If the current implementation is GBTRegressor, the data is left untouched.

       2. Configure the loss function and the impurity measure according to the implementation:

                         GBTClassifier   GBTRegressor
 Loss (loss)             LogLoss         squared (L2), absolute (L1)
 Impurity (impurity)     Gini index      variance of the label column

       3. Fit the first regression-tree model as the initial model and set its weight to 1.

       4. Predict the labels of the training data with the current ensemble.

       5. Re-label the training data with the pseudo-residuals: new label = -loss.gradient(pred, point.label). [Note] gradient is bound to the configured loss; see the loss-function section below.

       6. Feed the re-labeled training data to the regression-tree trainer to fit the next model, and set that model's weight to the step size (stepSize).

       7. Update the running prediction with the new tree: prediction = prediction of the previous iteration + prediction of the current tree * its weight (step size).

       8. Repeat steps 4-7 until the configured maximum number of iterations is reached.

       9. Return the array of tree models and the per-tree weights.

【Note】At prediction time, GBTClassifier converts the result back to 0/1 (shown in the code later).
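
To make the flow concrete, here is a minimal, self-contained sketch of a single boosting round in plain Scala (no Spark). The squared-error loss and the constant "trees" are stand-ins chosen for brevity; this only illustrates the relabel/refit/update cycle, not Spark's actual code.

  object BoostingRoundSketch {
    // A fitted regression tree is modeled as a plain function: feature => prediction.
    type Tree = Double => Double

    // Gradient of the squared (L2) loss used by GBTRegressor: -2 * (y - F(x)).
    def gradient(prediction: Double, label: Double): Double = -2.0 * (label - prediction)

    def main(args: Array[String]): Unit = {
      val data = Array((1.0, 3.0), (2.0, 6.0)) // (feature, label) pairs
      val stepSize = 0.1

      // Step 3: the first tree gets weight 1.0; a constant model keeps the sketch short.
      val firstTree: Tree = _ => 4.0
      var predictions = data.map { case (x, _) => firstTree(x) }

      // Steps 4-5: the negative gradient (pseudo-residual) becomes the new label.
      val pseudoResiduals = data.zip(predictions).map {
        case ((_, y), pred) => -gradient(pred, y) // = 2 * (y - F(x)) for the squared loss
      }

      // Step 6: a real implementation fits a tree to (feature, pseudo-residual);
      // the residual mean stands in for that fit here.
      val nextTree: Tree = _ => pseudoResiduals.sum / pseudoResiduals.length

      // Step 7: additive update, weighted by the step size.
      predictions = data.zip(predictions).map { case ((x, _), prev) => prev + stepSize * nextTree(x) }
      println(predictions.mkString(", ")) // 4.1, 4.1: one small additive step
    }
  }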

Usage example

 

  val gbtClassifier = new GBTClassifier()
    /* label column */
    .setLabelCol("label")
    /* feature column */
    .setFeaturesCol("features")
    /* loss type; classification supports only "logistic" */
    .setLossType("logistic")
    /* maximum tree depth */
    .setMaxDepth(5)
    /* impurity measure */
    .setImpurity("gini")
    /* checkpoint periodically to truncate the driver-side DAG (easing driver stack pressure and
     * fault-tolerance cost) and persist intermediate data */
    .setCheckpointInterval(10)
    /* maximum number of iterations, i.e. the number of trees in the final ensemble */
    .setMaxIter(20)
    .setCacheNodeIds(true)
    .setMaxBins(32)
    .setMaxMemoryInMB(256)
    .setMinInfoGain(0.0)
    .setMinInstancesPerNode(1)
    .setSeed(31L)
    .setStepSize(0.1)
    .setSubsamplingRate(1.0)
  // trainingDF / testDF: your prepared DataFrames with the label and feature columns (placeholders).
  val model: GBTClassificationModel = gbtClassifier.fit(trainingDF)
  model.transform(testDF)

 

Loss functions

There are two families of loss functions:

       1. The classification GBDT losses are wrapped in GBTClassifierParams; only logistic is supported.

       2. The regression GBDT losses are wrapped in GBTRegressorParams; two options are supported: squared (the L2 loss) and absolute (the L1 loss).

1. Classification loss implementation

  【Loss-type validation and instantiation】

  private[ml] object GBTClassifierParams {
    /** The classification implementation only supports the "logistic" loss type. */
    final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
  }

  import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
  // The import above renames LogLoss to OldLogLoss.
  ...
  override private[ml] def getOldLossType: OldLoss = {
    getLossType match {
      case "logistic" => OldLogLoss
      case _ =>
        // Should never happen because of check in setter method.
        throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
    }
  }

【The OldLogLoss implementation】

    LogLoss encapsulates both the gradient computation and the loss-value computation:

  object LogLoss extends Loss {
    /**
     * Gradient computation, used to generate the new labels before each iteration.
     * Method to calculate the loss gradients for the gradient boosting calculation for binary
     * classification.
     * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
     * @param prediction Predicted label.
     * @param label True label.
     * @return Loss gradient
     */
    @Since("1.2.0")
    override def gradient(prediction: Double, label: Double): Double = {
      - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
    }

    /* Compute the prediction error. */
    override private[spark] def computeError(prediction: Double, label: Double): Double = {
      val margin = 2.0 * label * prediction
      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
      2.0 * MLUtils.log1pExp(-margin)
    }
  }
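
As a sanity check of the two formulas above, a small worked example (the formulas are re-typed here because LogLoss is part of Spark's private API; the plain log1p(exp(.)) form is good enough for illustration):

  object LogLossExample {
    def gradient(prediction: Double, label: Double): Double =
      -4.0 * label / (1.0 + math.exp(2.0 * label * prediction))

    def computeError(prediction: Double, label: Double): Double =
      2.0 * math.log1p(math.exp(-2.0 * label * prediction))

    def main(args: Array[String]): Unit = {
      // A positive example (label remapped to +1) the ensemble has not separated yet.
      println(gradient(prediction = 0.0, label = 1.0))     // -4 / (1 + e^0) = -2.0
      println(computeError(prediction = 0.0, label = 1.0)) // 2 * ln(2) ≈ 1.386
      // The same example predicted confidently and correctly: the gradient shrinks toward 0.
      println(gradient(prediction = 3.0, label = 1.0))     // ≈ -0.0099
    }
  }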

 


2. Regression loss implementations

【Loss-type validation and instantiation】

  private[ml] object GBTRegressorParams {
    // The losses below should be lowercase.
    /** Accessor for supported loss settings: squared (L2), absolute (L1) */
    final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
  }

  import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
  ...
  override private[ml] def getOldLossType: OldLoss = {
    getLossType match {
      /* the squared (L2) loss */
      case "squared" => OldSquaredError
      /* the absolute (L1) loss */
      case "absolute" => OldAbsoluteError
      case _ =>
        // Should never happen because of check in setter method.
        throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
    }
  }

【squared implementation】: the L2 loss

  object SquaredError extends Loss {
    /**
     * Method to calculate the gradients for the gradient boosting calculation for least
     * squares error calculation.
     * The gradient with respect to F(x) is: - 2 (y - F(x))
     * @param prediction Predicted label.
     * @param label True label.
     * @return Loss gradient
     */
    @Since("1.2.0")
    override def gradient(prediction: Double, label: Double): Double = {
      - 2.0 * (label - prediction)
    }

    override private[spark] def computeError(prediction: Double, label: Double): Double = {
      val err = label - prediction
      err * err
    }
  }

【absolute implementation】: the L1 loss

  object AbsoluteError extends Loss {
    /**
     * Method to calculate the gradients for the gradient boosting calculation for least
     * absolute error calculation.
     * The gradient with respect to F(x) is: sign(F(x) - y)
     * @param prediction Predicted label.
     * @param label True label.
     * @return Loss gradient
     */
    @Since("1.2.0")
    override def gradient(prediction: Double, label: Double): Double = {
      if (label - prediction < 0) 1.0 else -1.0
    }

    override private[spark] def computeError(prediction: Double, label: Double): Double = {
      val err = label - prediction
      math.abs(err)
    }
  }
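
For intuition, the two regression gradients can be compared on the same residual (re-typed rather than calling the private Spark classes). Note how the L1 gradient stays bounded, which is what makes the absolute loss robust to outliers:

  object RegressionLossExample {
    // Gradient of the squared (L2) loss: -2 (y - F(x)); it scales with the residual.
    def squaredGradient(prediction: Double, label: Double): Double =
      -2.0 * (label - prediction)

    // Gradient of the absolute (L1) loss: sign(F(x) - y); only the direction matters.
    def absoluteGradient(prediction: Double, label: Double): Double =
      if (label - prediction < 0) 1.0 else -1.0

    def main(args: Array[String]): Unit = {
      // A large residual: L2 reacts strongly, L1 stays bounded.
      println(squaredGradient(prediction = 0.0, label = 10.0))  // -20.0
      println(absoluteGradient(prediction = 0.0, label = 10.0)) // -1.0
    }
  }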

 

Split-selection measure (column impurity)

   【Implementation】 By default:

               the GBDT classification implementation uses the Gini index as its split-selection measure;

               the GBDT regression implementation uses the variance of the label column.

    【Note】Neither measure can be changed through the API. To substitute a custom impurity, modify the source below, package it into your project, and set spark.driver.userClassPathFirst=true and spark.executor.userClassPathFirst=true so that your classes take precedence over Spark's.

 The algorithm-to-impurity binding is implemented as follows:

  def defaultStrategy(algo: Algo): Strategy = algo match {
    // For the GBDT classification implementation, the strategy uses Gini as the impurity measure.
    case Algo.Classification =>
      new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
        numClasses = 2)
    // For the GBDT regression implementation, the strategy uses Variance as the impurity measure.
    case Algo.Regression =>
      new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
        numClasses = 0)
  }
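
As an illustration of the substitution described above, a patched defaultStrategy could, for example, swap Gini for Spark's built-in Entropy impurity. This is only a sketch of the modified source; you would still need to rebuild and ship it with userClassPathFirst as noted:

  import org.apache.spark.mllib.tree.configuration.Algo
  import org.apache.spark.mllib.tree.configuration.Algo.{Classification, Regression}
  import org.apache.spark.mllib.tree.configuration.Strategy
  import org.apache.spark.mllib.tree.impurity.{Entropy, Variance}

  // Sketch of a patched binding: Entropy replaces Gini for the classification branch.
  def defaultStrategy(algo: Algo.Algo): Strategy = algo match {
    case Classification =>
      new Strategy(algo = Classification, impurity = Entropy, maxDepth = 10,
        numClasses = 2)
    case Regression =>
      new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
        numClasses = 0)
  }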

1. The Gini index

   The Gini index can be computed in two ways:

   (1) When the class probabilities p_k of a feature are given directly:

    Gini(p) = 1 - \sum_{k=1}^{K} p_k^2

   (2) When the probabilities are not given and must be estimated from class counts, with C_k the samples of class k in dataset D:

    Gini(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|} \right)^2

 【Note】Spark's implementation uses the second, count-based form.

  object Gini extends Impurity {
    /**
     * :: DeveloperApi ::
     * information calculation for multiclass classification
     * @param counts Array[Double] with counts for each label
     * @param totalCount sum of counts for all labels
     * @return information value, or 0 if totalCount = 0
     */
    @Since("1.1.0")
    @DeveloperApi
    override def calculate(counts: Array[Double], totalCount: Double): Double = {
      if (totalCount == 0) {
        return 0
      }
      val numClasses = counts.length
      var impurity = 1.0
      var classIndex = 0
      while (classIndex < numClasses) {
        val freq = counts(classIndex) / totalCount
        impurity -= freq * freq
        classIndex += 1
      }
      impurity
    }
  }
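
A quick check of the count-based formula on a toy split (re-typed, since Gini.calculate is a developer-API internal):

  object GiniExample {
    def gini(counts: Array[Double]): Double = {
      val total = counts.sum
      if (total == 0) 0.0
      else 1.0 - counts.map(c => (c / total) * (c / total)).sum
    }

    def main(args: Array[String]): Unit = {
      println(gini(Array(3.0, 7.0)))   // 1 - (0.3^2 + 0.7^2) = 0.42
      println(gini(Array(5.0, 5.0)))   // maximally impure: 0.5
      println(gini(Array(10.0, 0.0)))  // pure node: 0.0
    }
  }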

2. Variance (of the label column) implementation

  object Variance extends Impurity {
    /**
     * :: DeveloperApi ::
     * variance calculation
     * @param count number of instances
     * @param sum sum of labels
     * @param sumSquares summation of squares of the labels
     * @return information value, or 0 if count = 0
     */
    @Since("1.0.0")
    @DeveloperApi
    override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
      if (count == 0) {
        return 0
      }
      val squaredLoss = sumSquares - (sum * sum) / count
      squaredLoss / count
    }
  }
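
This is the usual identity Var(y) = E[y^2] - (E[y])^2, arranged so that a single pass collecting (count, sum, sumSquares) suffices; a small check:

  object VarianceExample {
    def variance(count: Double, sum: Double, sumSquares: Double): Double =
      if (count == 0) 0.0 else (sumSquares - (sum * sum) / count) / count

    def main(args: Array[String]): Unit = {
      val labels = Array(1.0, 2.0, 3.0, 4.0)
      val v = variance(labels.length, labels.sum, labels.map(x => x * x).sum)
      println(v) // 1.25, the population variance of 1, 2, 3, 4
    }
  }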

 

Model training implementation

【Overview】

          During training, both the classification and the regression implementation call GradientBoostedTrees.run(...), which returns a set of regression decision trees plus one weight per tree. The results are then wrapped into a GBTRegressionModel or a GBTClassificationModel respectively.

During data preparation, the classification implementation checks whether the label column contains any value other than 0 or 1, and aborts with an exception if it does.

【GBTRegressor】 data preparation, hyper-parameter packaging, and training dispatch: source and comments

  override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
    /*
     * Collect the arity of each categorical feature, i.e. whether a column was bucketized or
     * binarized upstream. For a bucketized column the value is the number of buckets (keyed by
     * field index); for a binary column the value is 2.
     */
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    /* Pack each row into a LabeledPoint according to the configured labelCol and featuresCol. */
    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    /* Number of feature columns. */
    val numFeatures = oldDataset.first().features.size
    /* Package the default training strategy (impurity, loss, max depth, iterations, etc.). */
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
    /* Initialize the logging and metrics (timing) collector. */
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    /* Start gradient-boosted training; classification and regression share this entry point,
     * differing only in parameters and label preprocessing. */
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
      $(seed))
    /* Wrap the regression trees, the per-tree weights, and the feature count (used for
     * validation at prediction time) into a model object. */
    val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
    /* Log success. */
    instr.logSuccess(m)
    m
  }

【GBTClassifier】 data preparation, hyper-parameter packaging, and training dispatch: source and comments

  override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
    /* Same as the regression implementation: collect the arity of each categorical feature. */
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
    // 2 classes now. This lets us provide a more precise error message.
    /* Check that the label column only contains 0 or 1; any other value aborts the job. */
    val oldDataset: RDD[LabeledPoint] =
      dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
        case Row(label: Double, features: Vector) =>
          require(label == 0 || label == 1, s"GBTClassifier was given" +
            s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
            s" GBTClassifier currently only supports binary classification.")
          LabeledPoint(label, features)
      }
    /* Same as regression: number of feature columns. */
    val numFeatures = oldDataset.first().features.size
    /* Same as regression: package the boosting strategy, including the impurity measure. */
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
    /* Same as regression: wire up logging and performance metrics. */
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(2)
    /* Same entry point as regression; the classification-specific impurity and loss were
     * already packaged into boostingStrategy above. */
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
      $(seed))
    /* Wrap the regression trees and the per-tree weights into a GBTClassificationModel. */
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
  }

 

【GradientBoostedTrees】 implementation and source comments

 【Overview】Both GBDT implementations call GradientBoostedTrees.run(...) to train their models.

Before training starts, the classification path remaps the label column from {0, 1} to {-1, +1}. Both paths then call GradientBoostedTrees.boost (shown later) to do the actual training.

The implementation of GradientBoostedTrees.run, with comments:

  def run(
      input: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        /* Remap 0/1 to -1/+1 so the classification GBDT can be computed with regression trees. */
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
          seed)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }
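
The remapping is just the affine map y -> 2y - 1, undone at prediction time by thresholding the raw score (see the prediction section):

  object LabelRemapExample {
    def main(args: Array[String]): Unit = {
      // Training direction: {0, 1} -> {-1, +1}, so the LogLoss margin formulas apply directly.
      println(Seq(0.0, 1.0).map(y => (y * 2) - 1))                // List(-1.0, 1.0)
      // Prediction direction: raw ensemble score -> {0, 1}.
      println(Seq(-0.7, 0.3).map(f => if (f > 0.0) 1.0 else 0.0)) // List(0.0, 1.0)
    }
  }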

The implementation of GradientBoostedTrees.boost, which trains the tree ensemble and the per-tree weights:

 

  /**
   * Internal method for performing regression using trees as base learners.
   * @param input training dataset
   * @param validationInput validation dataset, ignored if validate is set to false.
   * @param boostingStrategy boosting parameters
   * @param validate whether or not to use the validation dataset.
   * @param seed Random seed.
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def boost(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean,
      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")
    boostingStrategy.assertValid()

    // Initialize gradient boosting parameters.
    /* Maximum number of iterations. */
    val numIterations = boostingStrategy.numIterations
    /* Container for the trained regression trees, one slot per iteration. */
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    /* Container for the per-tree weights. */
    val baseLearnerWeights = new Array[Double](numIterations)
    /* The configured loss: L1/L2 for regression, LogLoss for classification
     * (see the loss-function section above). */
    val loss = boostingStrategy.loss
    /* Learning rate (step size, 0.1 by default). */
    val learningRate = boostingStrategy.learningRate
    // Prepare strategy for individual trees, which use regression with variance impurity.
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    treeStrategy.assertValid()

    // Cache input: the RDD is iterated many times, so persist it to avoid recomputing its DAG.
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
      true
    } else {
      false
    }

    // Prepare periodic checkpointers: persist intermediate results and truncate the prior DAG.
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // Initialize tree. GBDT is fit sequentially: train the first regression tree with weight 1.0.
    timer.start("building tree 0")
    val firstTree = new DecisionTreeRegressor().setSeed(seed)
    val firstTreeModel = firstTree.train(input, treeStrategy)
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight
    /* Predict the training data and compute the error with the configured loss
     * (see the loss-function section above). */
    var predError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    predErrorCheckpointer.update(predError)
    /* Log the mean prediction error. */
    logDebug("error of gbt = " + predError.values.mean())
    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")

    /* Predict the validation labels and compute their error with the loss function. */
    var validatePredError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    if (validate) validatePredErrorCheckpointer.update(validatePredError)
    /* Mean validation error. */
    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
    /* Index of the best model so far. */
    var bestM = 1

    var m = 1
    /* Whether to stop iterating early. */
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      // Update data with pseudo-residuals.
      /* Take the negative gradient of the loss at the previous prediction as the new label
       * (gradient formulas in the loss-function section above). */
      val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }

      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      /* Initialize and fit the next regression tree. */
      val dt = new DecisionTreeRegressor().setSeed(seed + m)
      val model = dt.train(data, treeStrategy)
      timer.stop(s"building tree $m")
      // Update partial model.
      /* Store the fitted tree in the model container. */
      baseLearners(m) = model
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      // Technically, the weight should be optimized for the particular loss.
      // However, the behavior should be reasonable, though not optimal.
      /* The learning rate (step size) becomes the tree's weight; predictions are later combined
       * as prediction = previous ensemble prediction + current tree prediction * weight. */
      baseLearnerWeights(m) = learningRate
      /* Update the running prediction with the new tree (prediction = previous prediction +
       * current tree prediction * weight) and recompute the error with the configured loss. */
      predError = updatePredictionError(
        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
      predErrorCheckpointer.update(predError)
      logDebug("error of gbt = " + predError.values.mean())

      // Early stopping to avoid overfitting. In the current GBDT code path validate is
      // hard-wired to false (see run above), so the block below does not execute.
      if (validate) {
        // Stop training early if
        // 1. Reduction in error is less than the validationTol or
        // 2. If the error increases, that is if the model is overfit.
        // We want the model returned corresponding to the best validation error.
        /* Predict the validation labels and compute their prediction error. */
        validatePredError = updatePredictionError(
          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
        validatePredErrorCheckpointer.update(validatePredError)
        /* Mean validation error. */
        val currentValidateError = validatePredError.values.mean()
        /* By default validationTol = 1e-5; if the improvement over the best error so far falls
         * below the threshold, stop early. */
        if (bestValidateError - currentValidateError < validationTol * Math.max(
            currentValidateError, 0.01)) {
          doneLearning = true
        } else if (currentValidateError < bestValidateError) {
          /* The current model beats the best so far; record its index as the new best. */
          bestValidateError = currentValidateError
          bestM = m + 1
        }
      }
      m += 1
    }

    timer.stop("total")
    logInfo("Internal timing for DecisionTree:")
    logInfo(s"$timer")
    /* Delete all persisted intermediate data. */
    predErrorCheckpointer.deleteAllCheckpoints()
    validatePredErrorCheckpointer.deleteAllCheckpoints()
    if (persistedInput) input.unpersist()
    /* Return the trees and their weights (all equal to the step size, except the first, which is 1). */
    if (validate) {
      /* If early stopping was enabled, trim the unused slots at the end of the containers. */
      (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
    } else {
      (baseLearners, baseLearnerWeights)
    }
  }
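
Putting the returned arrays together, the ensemble prediction is F(x) = Σ_i w_i · tree_i(x), with w_0 = 1.0 and w_i = stepSize for i ≥ 1. A minimal sketch with trees modeled as plain functions:

  object EnsemblePredictionSketch {
    type Tree = Double => Double // stand-in for DecisionTreeRegressionModel prediction

    def predict(trees: Array[Tree], weights: Array[Double], x: Double): Double =
      trees.zip(weights).map { case (tree, w) => w * tree(x) }.sum

    def main(args: Array[String]): Unit = {
      val trees: Array[Tree] = Array(_ => 4.0, _ => 0.8, _ => 0.5)
      val weights = Array(1.0, 0.1, 0.1) // first tree weight 1.0, the rest = stepSize
      println(predict(trees, weights, x = 0.0)) // 4.0 + 0.1*0.8 + 0.1*0.5 = 4.13
    }
  }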

 

Prediction

【Regression implementation】

  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    /* Broadcast the model. */
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    /* UDF that performs the prediction. */
    val predictUDF = udf { (features: Any) =>
      /* Delegate to predict below. */
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    /* Append the prediction as a new column of the DataFrame. */
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

  override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
    // Classifies by thresholding sum of weighted tree predictions
    /* Prediction of each individual tree. */
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    /* Dot product (blas.ddot) of the per-tree predictions with the tree weights. */
    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
  }
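
blas.ddot(n, x, 1, y, 1) is just the dot product Σ_i x_i · y_i; an equivalent plain-Scala form of the last line, for readers unfamiliar with BLAS:

  // Equivalent to blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1).
  def weightedSum(treePredictions: Array[Double], treeWeights: Array[Double]): Double =
    treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum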

【Classification implementation】

  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    /* Broadcast the model. */
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    /* UDF that performs the prediction. */
    val predictUDF = udf { (features: Any) =>
      /* Delegate to predict below. */
      bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    /* Append the prediction as a new column of the DataFrame. */
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }

  override protected def predict(features: Vector): Double = {
    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
    // Classifies by thresholding sum of weighted tree predictions
    /* Prediction of each individual tree. */
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    /* Dot product of the per-tree predictions with the tree weights, yielding a raw score. */
    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    /* Training remapped the labels to {-1, +1}, so threshold the raw score at 0 to map back to {0, 1}. */
    if (prediction > 0.0) 1.0 else 0.0
  }
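
Spark 2.2's GBTClassificationModel only exposes this hard 0/1 label. Under the LogLoss formulation above, however, the raw score F(x) corresponds to the class-1 probability 1 / (1 + exp(-2 F(x))), so a caller who needs soft scores could apply that link function themselves. The helper below is a sketch, not part of Spark's API:

  // Hypothetical helper, NOT part of Spark 2.2's API: maps the raw ensemble
  // score F(x) to a probability under the LogLoss (±1 label) formulation.
  def rawScoreToProbability(rawScore: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * rawScore))

  // rawScoreToProbability(0.0) == 0.5; large positive scores approach 1.0.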

 

 
