当前位置:   article > 正文

Spark逻辑回归分类算法-鸢尾花分类

Spark逻辑回归分类算法-鸢尾花分类
package com.dream.ml.features

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, LogisticRegressionTrainingSummary}
import org.apache.spark.ml.feature.{StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}

/**
 * @title: IrisFeaturesDemo
 * @projectName SparkStudy
 * @description: TODO
 * @author MXH
 * @date 2023/9/3 10:44
 */
object IrisFeaturesDemo {
    def main(args: Array[String]): Unit = {
        // 1.创建SparkSQL的运行环境
        val spark: SparkSession = SparkSession.builder()
            .appName(this.getClass.getSimpleName.stripSuffix("$"))
            .master("local[4]")
            .config("spark.sql.shuffle.partitions", 4)
            .getOrCreate()

        // 导入隐式转换
        import spark.implicits._

        // 2.加载鸢尾花数据集iris.data
        val irisSchema: StructType = new StructType()
            .add("sepal_length", dataType = DoubleType, nullable = true)
            .add("sepal_width", dataType = DoubleType, nullable = true)
            .add("petal_length", dataType = DoubleType, nullable = true)
            .add("petal_width", dataType = DoubleType, nullable = true)
            .add("class", dataType = StringType, nullable = true)

        val rawIrisDF: DataFrame = spark.read
            // 查看csv源码来设置options
            .option("sep", ",") // 分隔符
            .option("header", "false") // 默认为false
            .option("inferSchema", "false") // 默认为false
            // 当csv文件首行不是列名称时,需要自定义Schema
            .schema(irisSchema)
            .csv("datas/iris/iris.data")

        /*
        root
         |-- sepal_length: double (nullable = true)
         |-- sepal_width: double (nullable = true)
         |-- petal_length: double (nullable = true)
         |-- petal_width: double (nullable = true)
         |-- class: string (nullable = true)
                 */
        rawIrisDF.printSchema()
        rawIrisDF.show(10,truncate = false)

        // 3.数据转换
        // 3.1 转换1:将萼片长度、宽度及花瓣长度、宽度封装到一个特征向量中
        // https://spark.apache.org/docs/latest/ml-features.html#vectorassembler
        val assembler = new VectorAssembler()
            // 把需要组合的列名称枚举出来进行组合
            // .setInputCols(Array("hour", "mobile", "userFeatures"))
            // 本例中除最后一列不要外,需要其他列
            .setInputCols(rawIrisDF.columns.dropRight(1))
            .setOutputCol("features") // 添加一列,类型为向量

        val df1 = assembler.transform(rawIrisDF)

        /*
        root
         |-- sepal_length: double (nullable = true)
         |-- sepal_width: double (nullable = true)
         |-- petal_length: double (nullable = true)
         |-- petal_width: double (nullable = true)
         |-- class: string (nullable = true)
         |-- features: vector (nullable = true)
         */
        df1.printSchema()
        df1.show(10, truncate = false)

        // 3.2 转换2: 转换类别字符串数据为数值数据
        // https://spark.apache.org/docs/latest/ml-features.html#stringindexer
        val indexer = new StringIndexer()
            .setInputCol("class") // 需要索引化的列名
            .setOutputCol("label") // 数据索引化后列名
            .fit(df1)
        val df2 = indexer.transform(df1)

        /*
        root
        |-- sepal_length: double (nullable = true)
        |-- sepal_width: double (nullable = true)
        |-- petal_length: double (nullable = true)
        |-- petal_width: double (nullable = true)
        |-- class: string (nullable = true)
        |-- features: vector (nullable = true)  // 特征 x
        |-- label: double (nullable = false) // 标签 y
        算法: y = kx + b
        */
        df2.printSchema()
        df2.show(10,truncate = false)

        // 3.3 数据标准化
        // 在实际开发中,特征数据features经常需要进行各个转换操作,比如归一化、标准化和正则化等
        // 为什么要进行归一化、标准化或正则化等数据预处理?原因在于不同维度特征值,值的范围跨度不一样,导致模型异常
        // 比如影响房价的因素有地段、面积、楼层、新旧等特征数据
        // 数据标准化 https://spark.apache.org/docs/latest/ml-features.html#standardscaler
        val scaler = new StandardScaler()
            .setInputCol("features")
            .setOutputCol("scale_features")
            .setWithStd(true)  // 使用标准差缩放
            .setWithMean(false) //使用平均值缩放

        // Compute summary statistics by fitting the StandardScaler.
        val scalerModel = scaler.fit(df2)

        // Normalize each feature to have unit standard deviation.
        val irisDF = scalerModel.transform(df2)

        irisDF.show(10, truncate = false)

        // 4.分类算法
        // https://spark.apache.org/docs/latest/ml-classification-regression.html
        /*
            分类算法有:
            (1)决策树(DecisionTree)分类算法
            (2)朴素贝叶斯(Native Bayes)分类算法-适合构建文本数据特征分类,比如垃圾邮件、情感分析
            (3)逻辑回归(Logistics Regression)分类算法
            (4)线性支持向量机(Linear SVM)分类算法
            (5)神经网络相关分类算法,比如多层感知机算法-》深度学习算法
            (6)集成融合算法,随机森林(RF)分类算法、梯度提升树(GBT)算法
            Classification
                Logistic regression
                Binomial logistic regression
                Multinomial logistic regression
                Decision tree classifier
                Random forest classifier
                Gradient-boosted tree classifier
                Multilayer perceptron classifier
                Linear Support Vector Machine
                One-vs-Rest classifier (a.k.a. One-vs-All)
                Naive Bayes
                Factorization machines classifier
         */

        // 4.1 创建模型
        val lr: LogisticRegression = new LogisticRegression()
                // 设置特征值列名称和标签列名称
                .setFeaturesCol("scale_features")  // x -> 特征
                .setLabelCol("label") // y-> 标签
                // 每个算法都有自己超参数要设置,合理设置,获取较好的模型
                .setMaxIter(20)  // 模型训练迭代次数,默认100
                .setStandardization(true) //是否数据标准化,默认为true
                .setFamily("multinomial") //设置分类属于二分类(标签label只有2个值)还是多分类(大于2个值)
                .setRegParam(0) // 正则化参数,默认值为0.0 优化
                .setElasticNetParam(0) // 弹性化参数,优化

        // 4.2训练模型
        val lrModel: LogisticRegressionModel = lr.fit(irisDF)

        // 4.3评估模型
        println(s"多分类混淆矩阵: ${lrModel.coefficientMatrix}")

        val summary: LogisticRegressionTrainingSummary = lrModel.summary
        // 准确度: 0.9733333333333334
        println(s"准确度: ${summary.accuracy}")

        // 精确度是针对每一个分类的
        println(s"精确度: ${summary.precisionByLabel.mkString(",")}")



        // 关闭环境
        // spark.close() 源码中调用stop()
        spark.stop()
    }
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/547099
推荐阅读
相关标签
  

闽ICP备14008679号