当前位置:   article > 正文

Spark XGBoost算法demo(import ml.dmlc.xgboost4j.scala.spark.XGBoost)

import ml.dmlc.xgboost4j.scala.spark.XGBoost demo

1.运行环境配置

     该算法需要在Linux环境下运行,使用的版本为:Spark 2.4.0、Scala 2.11、XGBoost4J / XGBoost4J-Spark 0.72。

2.maven配置

<properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <scala.version>2.11</scala.version>
    <spark.version>2.4.0</spark.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_${scala.version}</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j</artifactId>
        <version>0.72</version>
    </dependency>
    <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j-spark</artifactId>
        <version>0.72</version>
    </dependency>
</dependencies>
<build>
    <plugins>
        <plugin>
            <groupId>org.scala-tools</groupId>
            <artifactId>maven-scala-plugin</artifactId>
            <version>2.15.2</version>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>testCompile</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
        <plugin>
            <artifactId>maven-compiler-plugin</artifactId>
            <version>3.6.0</version>
            <configuration>
                <source>1.8</source>
                <target>1.8</target>
            </configuration>
        </plugin>
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-surefire-plugin</artifactId>
            <version>2.19</version>
            <configuration>
                <skip>true</skip>
            </configuration>
        </plugin>
        <!-- 打出jar包引用关联包 -->
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-jar-plugin</artifactId>
            <version>2.4</version>
            <configuration>
                <archive>
                    <manifest>
                        <addClasspath>true</addClasspath>
                        <classpathPrefix>lib/</classpathPrefix>
                        <!--<mainClass>com.caxs.artemis.model.schedule.ModelInvoke</mainClass>-->
                    </manifest>
                </archive>
            </configuration>
        </plugin>
        <!-- 将依赖包放到lib文件夹中 -->
        <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-dependency-plugin</artifactId>
            <executions>
                <execution>
                    <id>copy</id>
                    <phase>package</phase>
                    <goals>
                        <goal>copy-dependencies</goal>
                    </goals>
                    <configuration>
                        <outputDirectory>
                            ${project.build.directory}/lib
                        </outputDirectory>
                    </configuration>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>

3.运行demo

  1. package spark
  2. import ml.dmlc.xgboost4j.scala.spark.XGBoost
  3. import org.apache.spark.ml.feature.VectorAssembler
  4. import org.apache.spark.sql.{DataFrame, SparkSession}
  5. /**
  6. * author :test-abc
  7. * date :Created in 2019/9/3 11:04
  8. * description:xgboost demo
  9. * modified By:
  10. */
  11. object XgboostDemo {
  12. def main(args: Array[String]): Unit = {
  13. val spark: SparkSession = SparkSession.builder()
  14. .appName("SparkSql")
  15. // .master("local[2]")
  16. .getOrCreate()
  17. //准备示例数据,将数据转为dataframe
  18. import spark.implicits._
  19. val dataList: List[(Int, Double, Double, Double, Double, Double, Double)] = List(
  20. (0, 8.9255, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
  21. (1, 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
  22. (0, 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
  23. (1, 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
  24. (1, 9.8369, -1.4834, 12.8746, 6.6375, 12.2772, 2.4486),
  25. (1, 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
  26. (0, 11.8091, -0.0832, 9.3494, 4.2916, 11.1355, -8.0198),
  27. (0, 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
  28. (0, 16.1071, 2.4426, 13.9307, 5.6327, 8.8014, 6.163),
  29. (1, 12.5088, 1.9743, 8.896, 5.4508, 13.6043, -16.2859),
  30. (0, 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
  31. (0, 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
  32. (0, 8.7671, -4.6154, 9.7242, 7.4242, 9.0254, 1.4247),
  33. (1, 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
  34. (0, 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
  35. (0, 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
  36. (0, 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
  37. (1, 8.4199, -1.8128, 8.1202, 5.3955, 9.7184, -17.839),
  38. (0, 4.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
  39. (1, 4.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))
  40. val inputDF: DataFrame = dataList.toDF("label", "feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
  41. //将需要转换的列合并为向量列
  42. val transCols: Array[String] = Array("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
  43. val assembler: VectorAssembler = new VectorAssembler().setInputCols(transCols).setOutputCol("features")
  44. val xGBoostTrainInput: DataFrame = assembler.transform(inputDF).select("features","label")
  45. xGBoostTrainInput.show(10)
  46. // val paramMap = List(
  47. // "eta" -> 0.01, //学习率
  48. // "gamma" -> 0.1, //用于控制是否后剪枝的参数,越大越保守,一般0.1、0.2这样子。
  49. // "lambda" -> 2, //控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。
  50. // "subsample" -> 0.8, //随机采样训练样本
  51. // "colsample_bytree" -> 0.8, //生成树时进行的列采样
  52. // "max_depth" -> 5, //构建树的深度,越大越容易过拟合
  53. // "min_child_weight" -> 5,
  54. // "objective" -> "multi:softprob", //定义学习任务及相应的学习目标
  55. // "eval_metric" -> "merror",
  56. // "num_class" -> 21
  57. // ).toMap
  58. val paramMap = List(
  59. "colsample_bytree" -> 1,
  60. "eta" -> 0.05f, //就是学习率
  61. "max_depth" -> 8, //树的最大深度
  62. "min_child_weight" -> 5, //
  63. "n_estimators" -> 120,
  64. "subsample" -> 0.7
  65. ).toMap
  66. //模型训练
  67. val xgBoostModel = XGBoost.trainWithDataFrame(xGBoostTrainInput, paramMap, round = 10, nWorkers = 3,
  68. useExternalMemory = true, featureCol = "features", labelCol = "label")
  69. //准备预测数据
  70. val testList: List[( Double, Double, Double, Double, Double, Double)] = List(
  71. ( 8.9225, -6.7863, 11.9081, 5.093, 11.4607, -9.2834),
  72. ( 11.5006, -4.1473, 13.8588, 5.389, 12.3622, 7.0433),
  73. ( 8.6093, -2.7457, 12.0805, 7.8928, 10.5825, -9.0837),
  74. ( 11.0604, -2.1518, 8.9522, 7.1957, 12.5846, -1.8361),
  75. ( 9.8369, -11.4834, 12.8746, 6.6375, 12.2772, 2.4486),
  76. ( 11.4763, -2.3182, 12.608, 8.6264, 10.9621, 3.5609),
  77. ( 11.8091, -10.0832, 9.3494, 4.2916, 11.1355, -8.0198),
  78. ( 13.558, -7.9881, 13.8776, 7.5985, 8.6543, 0.831),
  79. ( 16.1071, 1.4426, 13.9307, 5.6327, 8.8014, 6.163),
  80. ( 12.5088, 2.9743, 8.896, 5.4508, 13.6043, -16.2859),
  81. ( 5.0702, -0.5447, 9.59, 4.2987, 12.391, -18.8687),
  82. ( 12.7188, -7.975, 10.3757, 9.0101, 12.857, -12.0852),
  83. ( 8.7671, -4.6154, 8.7242, 7.4242, 9.0254, 1.4247),
  84. ( 16.3699, 1.5934, 16.7395, 7.333, 12.145, 5.9004),
  85. ( 13.808, 5.0514, 17.2611, 8.512, 12.8517, -9.1622),
  86. ( 3.9416, 2.6562, 13.3633, 6.8895, 12.2806, -16.162),
  87. ( 5.0615, 0.2689, 15.1325, 3.6587, 13.5276, -6.5477),
  88. ( 8.4199, -1.8128, 9.1202, 5.3955, 9.7184, -17.839),
  89. ( 5.875, 1.2646, 11.919, 8.465, 10.7203, -0.6707),
  90. ( 5.409, -0.7863, 15.1828, 8.0631, 11.2831, -0.7356))
  91. val testDf: DataFrame = testList.toDF("feature1", "feature2", "feature3", "feature4", "feature5", "feature6")
  92. //将测试数据集转为向量
  93. val xGBoostTestInput: DataFrame = assembler.transform(testDf).select("features")
  94. xGBoostTestInput.show(10)
  95. //模型预测
  96. val output: DataFrame = xgBoostModel.transform(xGBoostTestInput)
  97. output.show()
  98. spark.close()
  99. }
  100. }

运行结果为:(原文中的运行结果截图在转载时未能保留)

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号