赞
踩
package com.dream.ml.features import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, LogisticRegressionTrainingSummary} import org.apache.spark.ml.feature.{StandardScaler, StringIndexer, VectorAssembler} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.types.{DoubleType, StringType, StructType} /** * @title: IrisFeaturesDemo * @projectName SparkStudy * @description: TODO * @author MXH * @date 2023/9/3 10:44 */ object IrisFeaturesDemo { def main(args: Array[String]): Unit = { // 1.创建SparkSQL的运行环境 val spark: SparkSession = SparkSession.builder() .appName(this.getClass.getSimpleName.stripSuffix("$")) .master("local[4]") .config("spark.sql.shuffle.partitions", 4) .getOrCreate() // 导入隐式转换 import spark.implicits._ // 2.加载鸢尾花数据集iris.data val irisSchema: StructType = new StructType() .add("sepal_length", dataType = DoubleType, nullable = true) .add("sepal_width", dataType = DoubleType, nullable = true) .add("petal_length", dataType = DoubleType, nullable = true) .add("petal_width", dataType = DoubleType, nullable = true) .add("class", dataType = StringType, nullable = true) val rawIrisDF: DataFrame = spark.read // 查看csv源码来设置options .option("sep", ",") // 分隔符 .option("header", "false") // 默认为false .option("inferSchema", "false") // 默认为false // 当csv文件首行不是列名称时,需要自定义Schema .schema(irisSchema) .csv("datas/iris/iris.data") /* root |-- sepal_length: double (nullable = true) |-- sepal_width: double (nullable = true) |-- petal_length: double (nullable = true) |-- petal_width: double (nullable = true) |-- class: string (nullable = true) */ rawIrisDF.printSchema() rawIrisDF.show(10,truncate = false) // 3.数据转换 // 3.1 转换1:将萼片长度、宽度及花瓣长度、宽度封装到一个特征向量中 // https://spark.apache.org/docs/latest/ml-features.html#vectorassembler val assembler = new VectorAssembler() // 把需要组合的列名称枚举出来进行组合 // .setInputCols(Array("hour", "mobile", "userFeatures")) // 本例中除最后一列不要外,需要其他列 .setInputCols(rawIrisDF.columns.dropRight(1)) .setOutputCol("features") // 添加一列,类型为向量 val df1 = assembler.transform(rawIrisDF) /* root |-- sepal_length: double (nullable = true) |-- sepal_width: double (nullable = true) |-- petal_length: double (nullable = true) |-- petal_width: double (nullable = true) |-- class: string (nullable = true) |-- features: vector (nullable = true) */ df1.printSchema() df1.show(10, truncate = false) // 3.2 转换2: 转换类别字符串数据为数值数据 // https://spark.apache.org/docs/latest/ml-features.html#stringindexer val indexer = new StringIndexer() .setInputCol("class") // 需要索引化的列名 .setOutputCol("label") // 数据索引化后列名 .fit(df1) val df2 = indexer.transform(df1) /* root |-- sepal_length: double (nullable = true) |-- sepal_width: double (nullable = true) |-- petal_length: double (nullable = true) |-- petal_width: double (nullable = true) |-- class: string (nullable = true) |-- features: vector (nullable = true) // 特征 x |-- label: double (nullable = false) // 标签 y 算法: y = kx + b */ df2.printSchema() df2.show(10,truncate = false) // 3.3 数据标准化 // 在实际开发中,特征数据features经常需要进行各个转换操作,比如归一化、标准化和正则化等 // 为什么要进行归一化、标准化或正则化等数据预处理?原因在于不同维度特征值,值的范围跨度不一样,导致模型异常 // 比如影响房价的因素有地段、面积、楼层、新旧等特征数据 // 数据标准化 https://spark.apache.org/docs/latest/ml-features.html#standardscaler val scaler = new StandardScaler() .setInputCol("features") .setOutputCol("scale_features") .setWithStd(true) // 使用标准差缩放 .setWithMean(false) //使用平均值缩放 // Compute summary statistics by fitting the StandardScaler. val scalerModel = scaler.fit(df2) // Normalize each feature to have unit standard deviation. val irisDF = scalerModel.transform(df2) irisDF.show(10, truncate = false) // 4.分类算法 // https://spark.apache.org/docs/latest/ml-classification-regression.html /* 分类算法有: (1)决策树(DecisionTree)分类算法 (2)朴素贝叶斯(Native Bayes)分类算法-适合构建文本数据特征分类,比如垃圾邮件、情感分析 (3)逻辑回归(Logistics Regression)分类算法 (4)线性支持向量机(Linear SVM)分类算法 (5)神经网络相关分类算法,比如多层感知机算法-》深度学习算法 (6)集成融合算法,随机森林(RF)分类算法、梯度提升树(GBT)算法 Classification Logistic regression Binomial logistic regression Multinomial logistic regression Decision tree classifier Random forest classifier Gradient-boosted tree classifier Multilayer perceptron classifier Linear Support Vector Machine One-vs-Rest classifier (a.k.a. One-vs-All) Naive Bayes Factorization machines classifier */ // 4.1 创建模型 val lr: LogisticRegression = new LogisticRegression() // 设置特征值列名称和标签列名称 .setFeaturesCol("scale_features") // x -> 特征 .setLabelCol("label") // y-> 标签 // 每个算法都有自己超参数要设置,合理设置,获取较好的模型 .setMaxIter(20) // 模型训练迭代次数,默认100 .setStandardization(true) //是否数据标准化,默认为true .setFamily("multinomial") //设置分类属于二分类(标签label只有2个值)还是多分类(大于2个值) .setRegParam(0) // 正则化参数,默认值为0.0 优化 .setElasticNetParam(0) // 弹性化参数,优化 // 4.2训练模型 val lrModel: LogisticRegressionModel = lr.fit(irisDF) // 4.3评估模型 println(s"多分类混淆矩阵: ${lrModel.coefficientMatrix}") val summary: LogisticRegressionTrainingSummary = lrModel.summary // 准确度: 0.9733333333333334 println(s"准确度: ${summary.accuracy}") // 精确度是针对每一个分类的 println(s"精确度: ${summary.precisionByLabel.mkString(",")}") // 关闭环境 // spark.close() 源码中调用stop() spark.stop() } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。