
Spark MLlib Study Notes (1) -- Pipelines


Basic Concepts

DataFrame

The ML API uses the DataFrame from Spark SQL as its dataset abstraction. A DataFrame can hold a wide variety of data types, such as text, feature vectors, true labels, and predictions.
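
As a quick illustration (a minimal sketch, assuming a SparkSession named `spark` is already available, e.g. in spark-shell, just like in the examples later in this post), a DataFrame mixing a text column with a numeric label column could be built like this:

```scala
// Minimal sketch: a DataFrame holding an id, raw text, and a numeric label.
// Assumes an existing SparkSession named `spark`, as in the examples below.
val df = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0)
)).toDF("id", "text", "label")

df.printSchema()  // id: long, text: string, label: double
df.show()
```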

Transformers

A Transformer is an algorithm that can transform one DataFrame into another DataFrame. For example, a model is a Transformer that converts a DataFrame with features into a DataFrame with predictions.

Transformers include feature transformers and learned models. A Transformer implements the method transform(), which generally converts one DataFrame into another by appending one or more columns (see the short sketch after the list below).

  • A feature transformer reads one column of a DataFrame, maps it to a new column, and outputs a new DataFrame with the mapped column appended.
  • A learned model reads the column containing feature vectors, predicts a label for each feature vector, and outputs a new DataFrame with the predicted labels appended as a column.
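
A minimal sketch of a feature transformer, using Tokenizer (again assuming an existing SparkSession named `spark`): transform() returns a new DataFrame with a mapped column appended, while the input DataFrame itself is left unchanged.

```scala
import org.apache.spark.ml.feature.Tokenizer

// Assumes an existing SparkSession named `spark`.
val df = spark.createDataFrame(Seq(
  (0L, "spark mllib pipelines")
)).toDF("id", "text")

// A feature transformer: reads the "text" column and maps it to a "words" column.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

// transform() outputs a new DataFrame with the mapped column appended.
val tokenized = tokenizer.transform(df)
tokenized.show(false)
```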

Estimators

An Estimator is an algorithm that can be fit on a DataFrame to produce a Transformer. For example, a learning algorithm is an Estimator that trains on a DataFrame and produces a model.

An Estimator implements the method fit().
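
A minimal sketch of the fit() contract (assuming a SparkSession named `spark`); the first example in the Examples section below walks through this in much more detail:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors

// Assumes an existing SparkSession named `spark`.
val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0))
)).toDF("label", "features")

val lr = new LogisticRegression().setMaxIter(10)  // Estimator
val model = lr.fit(training)                      // the fitted model is a Transformer

// The fitted model can now transform data, appending prediction columns.
model.transform(training).select("features", "prediction").show()
```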

Pipeline

A Pipeline chains multiple Transformers and Estimators together to specify a machine learning workflow.

For example, a simple text-document processing workflow might require the following steps:

  1. Split each document's text into words
  2. Convert the words into a numerical feature vector
  3. Learn a prediction model from the feature vectors and labels

These steps form a machine learning workflow, i.e., a Pipeline, which consists of a sequence of PipelineStages run in a specific order.

Examples

Estimator, Transformer, and Param

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Row

// Prepare training data from a list of (label, features) tuples.
val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

// Create a LogisticRegression instance. This instance is an Estimator.
val lr = new LogisticRegression()
// Print out the parameters, documentation, and any default values.
println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n")

// We may set parameters using setter methods.
lr.setMaxIter(10)
  .setRegParam(0.01)

// Learn a LogisticRegression model. This uses the parameters stored in lr.
val model1 = lr.fit(training)
// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}")

// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
  .put(lr.maxIter, 30)  // Specify 1 Param. This overwrites the original maxIter.
  .put(lr.regParam -> 0.1, lr.threshold -> 0.55)  // Specify multiple Params.

// One can also combine ParamMaps.
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability")  // Change output column name.
val paramMapCombined = paramMap ++ paramMap2

// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training, paramMapCombined)
println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}")

// Prepare test data.
val test = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(-1.0, 1.5, 1.3)),
  (0.0, Vectors.dense(3.0, 2.0, -0.1)),
  (1.0, Vectors.dense(0.0, 2.2, -1.5))
)).toDF("label", "features")

// Make predictions on test data using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test)
  .select("features", "label", "myProbability", "prediction")
  .collect()
  .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
    println(s"($features, $label) -> prob=$prob, prediction=$prediction")
  }
```
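
One point worth noting about the run above: parameters passed to fit() through a ParamMap take precedence over values set earlier with setter methods. So model2 is trained with maxIter = 30, regParam = 0.1, threshold = 0.55, and its probability output column renamed to myProbability, even though lr itself was configured with setMaxIter(10) and setRegParam(0.01).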

Pipeline

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row

// Prepare training documents from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.001)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

// Fit the pipeline to training documents.
val model = pipeline.fit(training)

// Now we can optionally save the fitted pipeline to disk.
model.write.overwrite().save("/tmp/spark-logistic-regression-model")

// We can also save this unfit pipeline to disk.
pipeline.write.overwrite().save("/tmp/unfit-lr-model")

// And load it back in during production.
val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")

// Prepare test documents, which are unlabeled (id, text) tuples.
val test = spark.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "spark hadoop spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")

// Make predictions on test documents.
model.transform(test)
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }
```
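
Because a fitted PipelineModel persists all of its stages, PipelineModel.load("/tmp/spark-logistic-regression-model") brings back the tokenizer, the HashingTF stage, and the trained logistic regression model together, so sameModel.transform(test) would produce the same predictions as model.transform(test).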

 
