当前位置:   article > 正文

逻辑回归算法原理及Spark MLlib调用实例(Scala/Java/python)_regparam

regparam

逻辑回归

算法原理:

        逻辑回归是一个流行的二分类问题预测方法。它是Generalized Linear models 的一个特殊应用以预测结果概率。它是一个线性模型如下列方程所示,其中损失函数为逻辑损失:

 

        对于二分类问题,算法产出一个二值逻辑回归模型。给定一个新数据,由x表示,则模型通过下列逻辑方程来预测:

 

        其中 。默认情况下,如果 ,结果为正,否则为负。和线性SVMs不同,逻辑回归的原始输出有概率解释(x为正的概率)。

        二分类逻辑回归可以扩展为多分类逻辑回归来训练和预测多类别分类问题。如一个分类问题有K种可能结果,我们可以选取其中一种结果作为“中心点“,其他K-1个结果分别视为中心点结果的对立点。在spark.mllib中,取第一个类别为中心点类别。

*目前spark.ml逻辑回归工具仅支持二分类问题,多分类回归将在未来完善。

*当使用无拦截的连续非零列训练LogisticRegressionModel时,Spark MLlib为连续非零列输出零系数。这种处理不同于libsvm与R glmnet相似。

参数:

elasticNetParam:

类型:双精度型。

含义:弹性网络混合参数,范围[0,1]。

featuresCol:

类型:字符串型。

含义:特征列名。

fitIntercept:

类型:布尔型。

含义:是否训练拦截对象。

labelCol:

类型:字符串型。

含义:标签列名。

maxIter:

类型:整数型。

含义:最多迭代次数(>=0)。

predictionCol:

类型:字符串型。

含义:预测结果列名。

probabilityCol:

类型:字符串型。

含义:用以预测类别条件概率的列名。

regParam:

类型:双精度型。

含义:正则化参数(>=0)。

standardization:

类型:布尔型。

含义:训练模型前是否需要对训练特征进行标准化处理。

threshold:

类型:双精度型。

含义:二分类预测的阀值,范围[0,1]。

thresholds:

类型:双精度数组型。

含义:多分类预测的阀值,以调整预测结果在各个类别的概率。

tol:

类型:双精度型。

含义:迭代算法的收敛性。

weightCol:

类型:字符串型。

含义:列权重。

示例:

       下面的例子展示如何训练使用弹性网络正则化的逻辑回归模型。elasticNetParam对应于 ,regParam对应于 。

Scala:

import org.apache.spark.ml.classification.LogisticRegression

// Load training data
val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model
val lrModel = lr.fit(training)

// Print the coefficients and intercept for logistic regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Java:

  1. import org.apache.spark.ml.classification.LogisticRegression;
  2. import org.apache.spark.ml.classification.LogisticRegressionModel;
  3. import org.apache.spark.sql.Dataset;
  4. import org.apache.spark.sql.Row;
  5. import org.apache.spark.sql.SparkSession;
  6. // Load training data
  7. Dataset<Row> training = spark.read().format("libsvm")
  8. .load("data/mllib/sample_libsvm_data.txt");
  9. LogisticRegression lr = new LogisticRegression()
  10. .setMaxIter(10)
  11. .setRegParam(0.3)
  12. .setElasticNetParam(0.8);
  13. // Fit the model
  14. LogisticRegressionModel lrModel = lr.fit(training);
  15. // Print the coefficients and intercept for logistic regression
  16. System.out.println("Coefficients: "
  17. + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
Python:

  1. from pyspark.ml.classification import LogisticRegression
  2. # Load training data
  3. training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
  4. lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
  5. # Fit the model
  6. lrModel = lr.fit(training)
  7. # Print the coefficients and intercept for logistic regression
  8. print("Coefficients: " + str(lrModel.coefficients))
  9. print("Intercept: " + str(lrModel.intercept))

        spark.ml逻辑回归工具同样支持提取模总结。LogisticRegressionTrainingSummary提供LogisticRegressionModel的总结。目前仅支持二分类问题,所以总结必须明确投掷到BinaryLogisticRegressionTrainingSummary。支持多分类问题后可能有所改善。

继续上面的例子:

Scala:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load the data stored in LIBSVM format as a DataFrame.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)
Java:

  1. import org.apache.spark.ml.Pipeline;
  2. import org.apache.spark.ml.PipelineModel;
  3. import org.apache.spark.ml.PipelineStage;
  4. import org.apache.spark.ml.classification.DecisionTreeClassifier;
  5. import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
  6. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
  7. import org.apache.spark.ml.feature.*;
  8. import org.apache.spark.sql.Dataset;
  9. import org.apache.spark.sql.Row;
  10. import org.apache.spark.sql.SparkSession;
  11. // Load the data stored in LIBSVM format as a DataFrame.
  12. Dataset<Row> data = spark
  13. .read()
  14. .format("libsvm")
  15. .load("data/mllib/sample_libsvm_data.txt");
  16. // Index labels, adding metadata to the label column.
  17. // Fit on whole dataset to include all labels in index.
  18. StringIndexerModel labelIndexer = new StringIndexer()
  19. .setInputCol("label")
  20. .setOutputCol("indexedLabel")
  21. .fit(data);
  22. // Automatically identify categorical features, and index them.
  23. VectorIndexerModel featureIndexer = new VectorIndexer()
  24. .setInputCol("features")
  25. .setOutputCol("indexedFeatures")
  26. .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  27. .fit(data);
  28. // Split the data into training and test sets (30% held out for testing).
  29. Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  30. Dataset<Row> trainingData = splits[0];
  31. Dataset<Row> testData = splits[1];
  32. // Train a DecisionTree model.
  33. DecisionTreeClassifier dt = new DecisionTreeClassifier()
  34. .setLabelCol("indexedLabel")
  35. .setFeaturesCol("indexedFeatures");
  36. // Convert indexed labels back to original labels.
  37. IndexToString labelConverter = new IndexToString()
  38. .setInputCol("prediction")
  39. .setOutputCol("predictedLabel")
  40. .setLabels(labelIndexer.labels());
  41. // Chain indexers and tree in a Pipeline.
  42. Pipeline pipeline = new Pipeline()
  43. .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});
  44. // Train model. This also runs the indexers.
  45. PipelineModel model = pipeline.fit(trainingData);
  46. // Make predictions.
  47. Dataset<Row> predictions = model.transform(testData);
  48. // Select example rows to display.
  49. predictions.select("predictedLabel", "label", "features").show(5);
  50. // Select (prediction, true label) and compute test error.
  51. MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  52. .setLabelCol("indexedLabel")
  53. .setPredictionCol("prediction")
  54. .setMetricName("accuracy");
  55. double accuracy = evaluator.evaluate(predictions);
  56. System.out.println("Test Error = " + (1.0 - accuracy));
  57. DecisionTreeClassificationModel treeModel =
  58. (DecisionTreeClassificationModel) (model.stages()[2]);
  59. System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/547131
推荐阅读
相关标签
  

闽ICP备14008679号