In machine learning, designing an algorithm yourself requires a fair amount of background knowledge.
Using a ready-made machine learning library such as Spark 2.0 ML, on the other hand, requires almost no such background: it works out of the box.
The best way to get started is to walk through a simple, complete, well-structured example.
Earlier articles (each with short, self-contained examples):
Spark 2.0 ML library: feature extraction, transformation, and selection (Scala)
Spark 2.0 ML library: ML Pipelines and cross-validation (Scala)
Spark 2.0 ML library: data analysis methods (Scala)
The code below is adapted from examples found online; it is solid, and I have refined it further.
package change
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql._
import org.apache.spark.sql.SparkSession
/**
* Linear regression
*/
object linearTest {
def main(args: Array[String]): Unit = {
// 0. Build the SparkSession
val spark = SparkSession
.builder()
.master("local") // run locally; otherwise you get "A master URL must be set in your configuration" from org.apache.spark.SparkContext
.appName("test")
.enableHiveSupport()
.getOrCreate() // reuse an existing session or create a new one
spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory for intermediate data; an HDFS path is best in production
import spark.implicits._
//1 Prepare the training samples
val training = spark.createDataFrame(Seq(
(5.601801561245534, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.6949189734965766, -0.32697929564739403, -0.15359663581829275, -0.8951865090520432, 0.2057889391931318, -0.6676656789571533, -0.03553655732400762, 0.14550349954571096, 0.034600542078191854, 0.4223352065067103))),
(0.2577820163584905, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.8386555657374337, -0.1270180511534269, 0.499812362510895, -0.22686625128130267, -0.6452430441812433, 0.18869982177936828, -0.5804648622673358, 0.651931743775642, -0.6555641246242951, 0.17485476357259122))),
(1.5299675726687754, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(-0.13079299081883855, 0.0983382230287082, 0.15347083875928424, 0.45507300685816965, 0.1921083467305864, 0.6361110540492223, 0.7675261182370992, -0.2543488202081907, 0.2927051050236915, 0.680182444769418))))).toDF("label", "features")
training.show(false)
//2 Build the linear regression model
val lr = new LinearRegression()
.setMaxIter(100)
.setRegParam(0.1)
.setElasticNetParam(0.5)
//2 Fit the model on the training samples
val lrModel = lr.fit(training)
//2 Print the model parameters
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
/**
* Coefficients: [0.0,-0.8840148895400428,-4.451571521834594,-0.42090140779272434,0.857395634491616,-1.237347818637769,0.0,0.0,0.0,0.0] Intercept: 3.1417724655192645
*/
println(s"Intercept: ${lrModel.intercept}")
/**
* Intercept: 3.1417724655192645
*/
//4 Prepare the test samples
val test = spark.createDataFrame(Seq(
(5.601801561245534, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.6949189734965766, -0.32697929564739403, -0.15359663581829275, -0.8951865090520432, 0.2057889391931318, -0.6676656789571533, -0.03553655732400762, 0.14550349954571096, 0.034600542078191854, 0.4223352065067103))),
(0.2577820163584905, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(0.8386555657374337, -0.1270180511534269, 0.499812362510895, -0.22686625128130267, -0.6452430441812433, 0.18869982177936828, -0.5804648622673358, 0.651931743775642, -0.6555641246242951, 0.17485476357259122))),
(1.5299675726687754, Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(-0.13079299081883855, 0.0983382230287082, 0.15347083875928424, 0.45507300685816965, 0.1921083467305864, 0.6361110540492223, 0.7675261182370992, -0.2543488202081907, 0.2927051050236915, 0.680182444769418))))).toDF("label", "features")
test.show(false)
/**
* +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
* |label |features |
* +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
* |5.601801561245534 |(10,[0,1,2,3,4,5,6,7,8,9],[0.6949189734965766,-0.32697929564739403,-0.15359663581829275,-0.8951865090520432,0.2057889391931318,-0.6676656789571533,-0.03553655732400762,0.14550349954571096,0.034600542078191854,0.4223352065067103])|
* |0.2577820163584905|(10,[0,1,2,3,4,5,6,7,8,9],[0.8386555657374337,-0.1270180511534269,0.499812362510895,-0.22686625128130267,-0.6452430441812433,0.18869982177936828,-0.5804648622673358,0.651931743775642,-0.6555641246242951,0.17485476357259122]) |
* |1.5299675726687754|(10,[0,1,2,3,4,5,6,7,8,9],[-0.13079299081883855,0.0983382230287082,0.15347083875928424,0.45507300685816965,0.1921083467305864,0.6361110540492223,0.7675261182370992,-0.2543488202081907,0.2927051050236915,0.680182444769418]) |
* +------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
*/
//5 Score the test samples with the model
val test_predict = lrModel.transform(test)
test_predict
.select("label","prediction","features")
.show(false)
/**
* +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
* |label |prediction |features |
* +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
* |5.601801561245534 |5.493935912726037 |(10,[0,1,2,3,4,5,6,7,8,9],[0.6949189734965766,-0.32697929564739403,-0.15359663581829275,-0.8951865090520432,0.2057889391931318,-0.6676656789571533,-0.03553655732400762,0.14550349954571096,0.034600542078191854,0.4223352065067103])|
* |0.2577820163584905|0.33788027718672575|(10,[0,1,2,3,4,5,6,7,8,9],[0.8386555657374337,-0.1270180511534269,0.499812362510895,-0.22686625128130267,-0.6452430441812433,0.18869982177936828,-0.5804648622673358,0.651931743775642,-0.6555641246242951,0.17485476357259122]) |
* |1.5299675726687754|1.557734960360036 |(10,[0,1,2,3,4,5,6,7,8,9],[-0.13079299081883855,0.0983382230287082,0.15347083875928424,0.45507300685816965,0.1921083467305864,0.6361110540492223,0.7675261182370992,-0.2543488202081907,0.2927051050236915,0.680182444769418]) |
* +------------------+-------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
*/
//6 Training summary
val trainingSummary = lrModel.summary
// objective value at each iteration
val objectiveHistory = trainingSummary.objectiveHistory
println(s"numIterations: ${trainingSummary.totalIterations}")
/**
* numIterations: 101
*/
println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
/**
* objectiveHistory: [0.5,0.3343138710353481,0.03498698228406803,0.034394702527331365,0.03361752133820051,0.033440576313009396,0.032492999827506586,0.03209249818672103,0.03201118276878801,0.0318030335506653,0.031556141484809515,0.03146914334471842,0.03132368104987874,0.030906857778152226,0.030829631969772512,0.030792601096269995,0.03075807300477159,0.03064409361649658,0.03057645418974434,0.03048720940080922,0.030450452329432418,0.0303403006892938,0.03022336621283447,0.030105231797686347,0.03005248564337978,0.029952523828252434,0.029901762708870988,0.029901114112460842,0.029897992643680316,0.029897097909156505,0.029892358780083193,0.029890487541861296,0.029883508098656905,0.02986342331315129,0.029846157576330717,0.02983921669719768,0.029837621981381814,0.029832343881027193,0.029818011565517288,0.0298174329753425,0.029816619127868163,0.029815897918569062,0.029815813156609985,0.029815635355907394,0.029814914126549,0.029813735638819686,0.02981357400967502,0.02981340129452729,0.029813363218666296,0.029813104482615992,0.029813066188642295,0.02981290111924657,0.029812867201451012,0.029812730285385426,0.029812706953398726,0.02981259780704471,0.02981258478371474,0.02981249810105761,0.029812492058484363,0.029812414896583955,0.02981239284306545,0.02981217952516655,0.029812093354005524,0.029812078847204722,0.02981204606864486,0.029812029284085127,0.029812008170753846,0.029812001127453244,0.02981198610905457,0.029811978179336476,0.029811968590860403,0.029811960922339894,0.02981195510843637,0.029811951516538388,0.02981194560589678,0.029811931971338676,0.029811927559300986,0.02981192583464405,0.029811923533256,0.02981192147493291,0.029811919101372975,0.0298119178536648,0.029811915692737362,0.02981191417259256,0.029811912340872517,0.02981191111669305,0.02981190922210416,0.029811908328486812,0.029811906376022823,0.029811905682559023,0.02981190386743857,0.029811903165691635,0.02981190159751578,0.029811901021202986,0.02981189985181355,0.02981189892054736,0.029811897724408266,0.02981189698790617,0.02981189562974597,0.029811894938092554,0.029811894064851477]
*/
trainingSummary.residuals.show(false)
/**
* +---------------------+
* |residuals |
* +---------------------+
* |0.1078656485194962 |
* |-0.08009826082823523 |
* |-0.027767387691260526|
* +---------------------+
*/
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
/**
* RMSE: 0.07920807479341203
*/
println(s"r2: ${trainingSummary.r2}")
/**
* r2: 0.998792363204057
*/
//7 Save and load the model (when deploying to a Django server, add this code plus the model files to the view)
lrModel.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel2")
val load_lrModel = LinearRegressionModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel2")
}
}
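The reloaded model scores new data exactly like the freshly trained one. A minimal usage sketch, to be appended at the end of main above (the dense feature vector is made up purely for illustration; the path matches the save call):
// Hypothetical sketch: score one new sample with the reloaded linear regression model
val newSample = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(0.1, -0.2, 0.3, 0.0, 0.5, -0.1, 0.0, 0.2, -0.3, 0.4))
)).toDF("label", "features")
load_lrModel.transform(newSample).select("features", "prediction").show(false)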

package change
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession
object logicTest {
def main(args: Array[String]): Unit = {
// 0. Build the SparkSession
val spark = SparkSession
.builder()
.master("local") // run locally; otherwise you get "A master URL must be set in your configuration" from org.apache.spark.SparkContext
.appName("test")
.enableHiveSupport()
.getOrCreate() // reuse an existing session or create a new one
spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory for intermediate data; an HDFS path is best in production
import spark.implicits._
//1 Prepare the training samples
val training = spark.createDataFrame(Seq(
(1.0, Vectors.sparse(692, Array(10, 20, 30), Array(-1.0, 1.5, 1.3))),
(0.0, Vectors.sparse(692, Array(45, 175, 500), Array(-1.0, 1.5, 1.3))),
(1.0, Vectors.sparse(692, Array(100, 200, 300), Array(-1.0, 1.5, 1.3))))).toDF("label", "features")
training.show(false)
/**
* +-----+----------------------------------+
* |label|features |
* +-----+----------------------------------+
* |1.0 |(692,[10,20,30],[-1.0,1.5,1.3]) |
* |0.0 |(692,[45,175,500],[-1.0,1.5,1.3]) |
* |1.0 |(692,[100,200,300],[-1.0,1.5,1.3])|
* +-----+----------------------------------+
*/
//2 Build the (binomial) logistic regression model
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
//2 Fit the model on the training samples
val lrModel = lr.fit(training)
//2 Print the model parameters
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
/**
* Coefficients: (692,[45,175,500],[0.48944928041408226,-0.32629952027605463,-0.37649944647237077]) Intercept: 1.251662793530725
*/
println(s"Intercept: ${lrModel.intercept}")
/**
* Intercept: 1.251662793530725
*/
//3 Build the multinomial logistic regression model
val mlr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")
//3 Fit the multinomial model
val mlrModel = mlr.fit(training)
//3 Print the multinomial model parameters
println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}")
/**
* Multinomial coefficients: 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... (692 total)
* 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ...
*/
println(s"Multinomial intercepts: ${mlrModel.interceptVector}")
/**
* Multinomial intercepts: [-0.6449310568167714,0.6449310568167714]
*/
//4 Prepare the test samples
val test = spark.createDataFrame(Seq(
(1.0, Vectors.sparse(692, Array(10, 20, 30), Array(-1.0, 1.5, 1.3))),
(0.0, Vectors.sparse(692, Array(45, 175, 500), Array(-1.0, 1.5, 1.3))),
(1.0, Vectors.sparse(692, Array(100, 200, 300), Array(-1.0, 1.5, 1.3))))).toDF("label", "features")
test.show(false)
/**
* +-----+----------------------------------+
* |label|features |
* +-----+----------------------------------+
* |1.0 |(692,[10,20,30],[-1.0,1.5,1.3]) |
* |0.0 |(692,[45,175,500],[-1.0,1.5,1.3]) |
* |1.0 |(692,[100,200,300],[-1.0,1.5,1.3])|
* +-----+----------------------------------+
*/
//5 Score the test samples with the model
val test_predict = lrModel.transform(test)
test_predict
.select("label", "prediction", "probability", "features")
.show(false)
/**
* +-----+----------+----------------------------------------+----------------------------------+
* |label|prediction|probability |features |
* +-----+----------+----------------------------------------+----------------------------------+
* |1.0 |1.0 |[0.22241243403014824,0.7775875659698517]|(692,[10,20,30],[-1.0,1.5,1.3]) |
* |0.0 |0.0 |[0.5539602964649871,0.44603970353501293]|(692,[45,175,500],[-1.0,1.5,1.3]) |
* |1.0 |1.0 |[0.22241243403014824,0.7775875659698517]|(692,[100,200,300],[-1.0,1.5,1.3])|
* +-----+----------+----------------------------------------+----------------------------------+
*/
//6 Training summary
val trainingSummary = lrModel.summary
//6 objective value at each iteration
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(loss => println(loss))
/**
* objectiveHistory:
* 0.6365141682948128
* 0.6212055977633174
* 0.5894552698389314
* 0.5844805633573479
* 0.5761098112571359
* 0.575517297029231
* 0.5754098875805627
* 0.5752562156795122
* 0.5752506337221737
* 0.5752406742715199
* 0.5752404945106846
*/
//6 Compute binary classification metrics
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
//6 ROC curve and AUC
val roc = binarySummary.roc
roc.show(false)
/**
* +---+---+
* |FPR|TPR|
* +---+---+
* |0.0|0.0|
* |0.0|1.0|
* |1.0|1.0|
* |1.0|1.0|
* +---+---+
*/
val AUC = binarySummary.areaUnderROC
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
//6 Set the model threshold
// Compute the F1 score at each candidate threshold, then use the maximum F1 to find and reset the model's best classification threshold.
val fMeasure = binarySummary.fMeasureByThreshold
fMeasure.show(false)
/**
* +-------------------+---------+
* |threshold |F-Measure|
* +-------------------+---------+
* |0.7775875659698517 |1.0 |
* |0.44603970353501293|0.8 |
* +-------------------+---------+
*/
// Get the maximum F1 score
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
// Find the threshold corresponding to the maximum F1 (the best threshold)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).select("threshold").head().getDouble(0)
// Set the model's threshold to the selected best classification threshold
lrModel.setThreshold(bestThreshold)
//7 Save and load the model (when deploying to a Django server, add this code plus the model files to the view)
lrModel.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel")
val load_lrModel = LogisticRegressionModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\lrmodel")
}
}
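The reloaded classifier works the same way, and since setThreshold is called before save, the chosen threshold is stored with the model's parameters. A minimal usage sketch, to be appended at the end of main above (the sparse vector is made up purely for illustration):
// Hypothetical sketch: classify one new sample with the reloaded logistic regression model
val newSample = spark.createDataFrame(Seq(
  (1.0, Vectors.sparse(692, Array(10, 20, 30), Array(-0.5, 1.0, 1.2)))
)).toDF("label", "features")
load_lrModel.transform(newSample).select("probability", "prediction").show(false)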

The saved model files on disk.
Tree-based models are among the easier algorithms to understand.
package tree
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.classification.{ DecisionTreeClassifier, DecisionTreeClassificationModel }
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{ MulticlassClassificationEvaluator, BinaryClassificationEvaluator }
import org.apache.spark.ml.{ Pipeline, PipelineModel }
import org.apache.spark.sql.SparkSession
object tree {
def main(args: Array[String]): Unit = {
// 0. Build the SparkSession
val spark = SparkSession
.builder()
.master("local") // run locally; otherwise you get "A master URL must be set in your configuration" from org.apache.spark.SparkContext
.appName("test")
.enableHiveSupport()
.getOrCreate() // reuse an existing session or create a new one
spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory for intermediate data; an HDFS path is best in production
import spark.implicits._
//1 Load the sample data (libsvm format)
val data = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\sample_libsvm_data.txt")
data.show
//2 Index the label column
val labelIndexer = new StringIndexer().
setInputCol("label").
setOutputCol("indexedLabel").
fit(data)
// Index categorical features so the tree algorithms know which features are categorical
// A feature with more than 4 distinct values is treated as continuous; otherwise it is marked as categorical and index-encoded
val featureIndexer = new VectorIndexer().
setInputCol("features").
setOutputCol("indexedFeatures").
setMaxCategories(4).
fit(data)
//3 Split the data into training and test sets
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
//4 Define the decision tree classifier
val dt = new DecisionTreeClassifier().
setLabelCol("indexedLabel").
setFeaturesCol("indexedFeatures")
//4 Define the random forest classifier
val rf = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setNumTrees(10)
//4 Define the gradient-boosted tree (GBT) classifier
val gbt = new GBTClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setMaxIter(10)
//5 Convert the indexed predictions back to the original labels
val labelConverter = new IndexToString().
setInputCol("prediction").
setOutputCol("predictedLabel").
setLabels(labelIndexer.labels)
//6 Build the Pipelines
val pipeline1 = new Pipeline().
setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
val pipeline2 = new Pipeline().
setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
val pipeline3 = new Pipeline().
setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))
//7 Fit the Pipelines
val model1 = pipeline1.fit(trainingData)
val model2 = pipeline2.fit(trainingData)
val model3 = pipeline3.fit(trainingData)
//8 Score the test data
val predictions = model1.transform(testData)
predictions.show(5)
//8 Show the test results
predictions.select("predictedLabel", "label", "features").show(5)
//9 Classification metrics
// Accuracy
val evaluator1 = new MulticlassClassificationEvaluator().
setLabelCol("indexedLabel").
setPredictionCol("prediction").
setMetricName("accuracy")
val accuracy = evaluator1.evaluate(predictions)
println("Test Error = " + (1.0 - accuracy))
// f1
val evaluator2 = new MulticlassClassificationEvaluator().
setLabelCol("indexedLabel").
setPredictionCol("prediction").
setMetricName("f1")
val f1 = evaluator2.evaluate(predictions)
println("f1 = " + f1)
// Precision
val evaluator3 = new MulticlassClassificationEvaluator().
setLabelCol("indexedLabel").
setPredictionCol("prediction").
setMetricName("weightedPrecision")
val Precision = evaluator3.evaluate(predictions)
println("Precision = " + Precision)
// Recall
val evaluator4 = new MulticlassClassificationEvaluator().
setLabelCol("indexedLabel").
setPredictionCol("prediction").
setMetricName("weightedRecall")
val Recall = evaluator4.evaluate(predictions)
println("Recall = " + Recall)
// AUC (here the hard prediction column is used as the score; the rawPrediction column would give a finer-grained ROC)
val evaluator5 = new BinaryClassificationEvaluator().
setLabelCol("indexedLabel").
setRawPredictionCol("prediction").
setMetricName("areaUnderROC")
val AUC = evaluator5.evaluate(predictions)
println("Test AUC = " + AUC)
// aupr
val evaluator6 = new BinaryClassificationEvaluator().
setLabelCol("indexedLabel").
setRawPredictionCol("prediction").
setMetricName("areaUnderPR")
val aupr = evaluator6.evaluate(predictions)
println("Test aupr = " + aupr)
//10 Print the learned decision tree
val treeModel = model1.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)
//11 Save and load the model
model1.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\dtmodel")
val load_treeModel = PipelineModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\dtmodel")
}
}
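model2 (random forest) and model3 (GBT) can be evaluated with exactly the same evaluators; for the ensembles it is also worth looking at feature importances. A minimal sketch, to be appended inside main above (stage index 2 is the classifier in the pipelines built earlier):
// Hypothetical sketch: inspect the fitted random forest inside pipeline2
import org.apache.spark.ml.classification.RandomForestClassificationModel
val rfModel = model2.stages(2).asInstanceOf[RandomForestClassificationModel]
println("Number of trees: " + rfModel.trees.length)
println("Feature importances: " + rfModel.featureImportances)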

K-means is superficially similar to KNN (both reason about distances to nearby points), but it is an unsupervised clustering algorithm.
package juhe
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.sql.SparkSession
object Kmeans {
def main(args: Array[String]): Unit = {
// 0. Build the SparkSession
val spark = SparkSession
.builder()
.master("local") // run locally; otherwise you get "A master URL must be set in your configuration" from org.apache.spark.SparkContext
.appName("test")
.enableHiveSupport()
.getOrCreate() // reuse an existing session or create a new one
spark.sparkContext.setCheckpointDir("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest") // checkpoint directory for intermediate data; an HDFS path is best in production
import spark.implicits._
// Load the sample data (libsvm format)
val dataset = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\sample_kmeans_data.txt")
dataset.show()
// Train a k-means model
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)
// Evaluate clustering quality: Within Set Sum of Squared Errors (WSSSE)
val WSSSE = model.computeCost(dataset)
println(s"Within Set Sum of Squared Errors = $WSSSE")
// Show the cluster centers
println("Cluster Centers: ")
model.clusterCenters.foreach(println)
// Save and load the model
model.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\kmmodel")
val load_kmModel = KMeansModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\sparkmlTest\\kmmodel")
spark.stop()
}
}
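Cluster assignments come from transform, which appends a prediction column holding the cluster index of each row. A minimal sketch, to be placed before spark.stop() above:
// Hypothetical sketch: assign each row to its nearest cluster centre
val clustered = model.transform(dataset)
clustered.select("features", "prediction").show(false)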

LDA can be used, for example, to group articles by topic.
package juhe
import org.apache.spark.ml.clustering.{LDA, LDAModel}
import org.apache.spark.sql.SparkSession
object ldaTest {
def main(args: Array[String]): Unit = {
// 0. Build the SparkSession
val spark = SparkSession
.builder()
.master("local") // run locally; otherwise you get "A master URL must be set in your configuration" from org.apache.spark.SparkContext
.appName("test")
.enableHiveSupport()
.getOrCreate() // reuse an existing session or create a new one
// 1. Load the sample data (libsvm format)
val dataset = spark.read.format("libsvm").load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\data.txt")
dataset.show()
// 2. Train the LDA model
val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)
val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)
println(s"The lower bound on the log likelihood of the entire corpus: $ll")
println(s"The upper bound on perplexity: $lp")
// 3. Describe the topics
val topics = model.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)
val aa = model.topicsMatrix
model.estimatedDocConcentration
model.getTopicConcentration
// 4. Transform the documents into their topic distributions
val transformed = model.transform(dataset)
transformed.show(false)
transformed.columns
// 5. Save the model
model.save("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\ldamodel")
spark.stop()
}
}
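The listing above only saves the LDA model. With the default online optimizer the fitted model is a LocalLDAModel, so loading it back could look like the sketch below (an assumption-labelled sketch; the path matches the save call above):
// Hypothetical sketch: reload the saved LDA model (default online optimizer => LocalLDAModel)
import org.apache.spark.ml.clustering.LocalLDAModel
val load_ldaModel = LocalLDAModel.load("C:\\LLLLLLLLLLLLLLLLLLL\\BigData_AI\\ldamodel")
load_ldaModel.describeTopics(3).show(false)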
