
Using a Decision Tree in Spark MLlib to Predict Marital Status from Other Attributes


Introduction

This project attempts to train a decision tree model on Spark that predicts a person's marital status from their other attributes.

The project is based on the open UCI dataset adult.data.
GitHub: AdultBase - Truedick23

Setup

  • Spark version: spark-2.3.1-bin-hadoop2.7
  • Language: Scala 2.11.8
  • Dataset: Adult Data Set
  • build.sbt contents: make sure that scalaVersion, the imported Spark jars, and the libraryDependencies entries correspond exactly; Spark 2.3 only supports Scala 2.11.
name := "AdultBase"

version := "0.1"

scalaVersion := "2.11.8"

// https://mvnrepository.com/artifact/org.apache.spark/spark-core
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.3.1"

// https://mvnrepository.com/artifact/org.apache.spark/spark-streaming
libraryDependencies += "org.apache.spark" %% "spark-streaming" % "2.3.1"

// https://mvnrepository.com/artifact/org.apache.spark/spark-mllib
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.3.1"

// https://mvnrepository.com/artifact/org.apache.spark/spark-mllib-local
libraryDependencies += "org.apache.spark" %% "spark-mllib-local" % "2.3.1"

// https://mvnrepository.com/artifact/org.scalanlp/breeze-viz
libraryDependencies += "org.scalanlp" %% "breeze-viz" % "0.13.2"

Reading and lightly formatting the data

Since we are not writing this in the Spark shell, we first need to create a SparkContext to read the data:

import org.apache.spark.SparkContext

val sc = new SparkContext("local[*]", "AdultData")
val raw_data = sc.textFile("./data/machine-learning-databases/adult.data")
val data = raw_data.map(line => line.split(", ")).filter(fields => fields.length == 15)
data.cache()
  • Pay special attention to the filter: it keeps only rows that split into exactly 15 fields, which avoids ArrayIndexOutOfBoundsException on blank or malformed lines.
  • cache() keeps data in memory, so Spark's lazy evaluation does not re-parse the file for every later action.
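As a quick sanity check (not part of the original post), comparing the raw and filtered counts shows how many lines were dropped; counting is an action, so it also materializes the cache:

// count() triggers the parsing above and pins `data` in memory via cache().
println(s"raw lines: ${raw_data.count()}, valid rows: ${data.count()}")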

Index-encoding the data

First, extract the values of each feature; distinct returns a collection of the unique values:

    // Column layout of adult.data: 0 age, 1 workclass, 2 fnlwgt, 3 education, 4 education-num,
    // 5 marital-status, 6 occupation, 7 relationship, 8 race, 9 sex, 10 capital-gain,
    // 11 capital-loss, 12 hours-per-week, 13 native-country, 14 income
    val number_set = data.map(fields => fields(2).toInt).collect().toSet   // fnlwgt values (not used below)
    val education_types = data.map(fields => fields(3)).distinct.collect()
    val marriage_types = data.map(fields => fields(5)).distinct.collect()
    val family_condition_types = data.map(fields => fields(7)).distinct.collect()
    val occupation_category_types = data.map(fields => fields(1)).distinct.collect()
    val occupation_types = data.map(fields => fields(6)).distinct.collect()
    val racial_types = data.map(fields => fields(8)).distinct.collect()
    val nationality_types = data.map(fields => fields(13)).distinct.collect()
    println(marriage_types.length)   // prints 7: the number of marital-status categories

Define a helper function that builds a value-to-index map from an array of distinct values:

def acquireDict(types: Array[String]): Map[String, Int] = {
  var idx = 0
  var dict: Map[String, Int] = Map()
  for (item <- types) {
    dict += (item -> idx)
    idx += 1
  }
  dict
}
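As an aside (not in the original post), Scala can build the same mapping in one line with zipWithIndex, which pairs every value with its position:

// Equivalent to the loop above: each distinct value is mapped to its array index.
def acquireDict(types: Array[String]): Map[String, Int] =
  types.zipWithIndex.toMap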

Call the helper to generate a dictionary for each categorical feature:

    val education_dict = acquireDict(education_types)
    val marriage_dict = acquireDict(marriage_types)
    val family_condition_dict = acquireDict(family_condition_types)
    val occupation_category_dict = acquireDict(occupation_category_types)
    val occupation_dict = acquireDict(occupation_types)
    val racial_dict = acquireDict(racial_types)
    val nationality_dict = acquireDict(nationality_types)
    val sex_dict = Map("Male" -> 1, "Female" -> 0)

Construct LabeledPoint objects, the input format the decision tree expects:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data_set = data.map { fields =>
  val number = fields(2).toInt   // fnlwgt, not used in the feature vector
  val education = education_dict(fields(3))
  val marriage = marriage_dict(fields(5))
  val family_condition = family_condition_dict(fields(7))
  val occupation_category = occupation_category_dict(fields(1))
  val occupation = occupation_dict(fields(6))
  val sex = sex_dict(fields(9))
  val race = racial_dict(fields(8))
  val nationality = nationality_dict(fields(13))
  val featureVector = Vectors.dense(education, occupation, occupation_category, sex, family_condition, race, nationality)
  val label = marriage
  LabeledPoint(label, featureVector)
}

As the code shows, marital status is used as the class label, while the other attributes (education, relationship, workclass, occupation, sex, race, native country) serve as features, stored together in a Vector. Let's print a few rows to check the format:

data_set.take(10).foreach(println)

The output:

(3.0,[11.0,3.0,1.0,1.0,0.0,4.0,2.0])
(4.0,[11.0,11.0,4.0,1.0,4.0,4.0,2.0])
(2.0,[1.0,9.0,6.0,1.0,0.0,4.0,2.0])
(4.0,[4.0,9.0,6.0,1.0,4.0,3.0,2.0])
(4.0,[11.0,1.0,6.0,0.0,1.0,3.0,31.0])

We use randomSplit to divide data_set randomly into three parts, used respectively for training, cross-validation, and testing, and cache them for convenience:

// randomSplit also takes an optional seed argument if a reproducible split is needed.
val Array(trainData, cvData, testData) = data_set.randomSplit(Array(0.8, 0.1, 0.1))
trainData.cache()
cvData.cache()
testData.cache()

Now we build the decision tree model; trainClassifier requires six arguments:

  • The first argument is the training data.
  • The second, numClasses, is the number of class labels; here it is the number of distinct marital-status values (7).
  • The third, categoricalFeaturesInfo, maps the index of a categorical feature to its number of categories; features not listed are treated as continuous. The code below simply passes an empty map (see the sketch after this list).
  • The fourth, impurity, selects the split criterion. For classification it is either "gini" (Gini impurity, 1 − Σᵢ pᵢ²) or "entropy" (Shannon entropy, −Σᵢ pᵢ log₂ pᵢ), where pᵢ is the fraction of examples of class i at a node; for regression the only choice is "variance".
  • The fifth, maxDepth, limits the maximum depth of the tree. A deeper tree fits the training data more closely, but it costs more to train and is prone to overfitting.
  • The sixth, maxBins, is the maximum number of bins used when searching for splits. More bins allow finer splits but increase the amount of computation, and the value must be at least the number of categories of any categorical feature.
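All seven features here are really categorical indices, and passing an empty map makes the tree treat them as ordered numbers. As a hedged sketch (not what the original code does), the map could instead declare the category count of each feature, keyed by its position in the feature vector built above:

// Hypothetical alternative to the empty Map[Int, Int]() passed below:
// feature position in the vector -> number of distinct categories.
val categoricalFeaturesInfo = Map(
  0 -> education_types.length,           // education
  1 -> occupation_types.length,          // occupation
  2 -> occupation_category_types.length, // workclass
  3 -> 2,                                // sex
  4 -> family_condition_types.length,    // relationship
  5 -> racial_types.length,              // race
  6 -> nationality_types.length          // native country
)

If such a map is used, maxBins must be at least as large as the biggest category count (here the native-country feature).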

For a first training attempt, we set the parameters as follows:

import org.apache.spark.mllib.tree.DecisionTree

val model = DecisionTree.
  trainClassifier(trainData, 7, Map[Int, Int](), "entropy", 10, 100)

Next we build tuples holding each prediction and its true label on the cross-validation set, and use MulticlassMetrics to evaluate the model and obtain its accuracy:

import org.apache.spark.mllib.evaluation.MulticlassMetrics

val predictionsAndLabels = cvData.map(example =>
  (model.predict(example.features), example.label)
)
val metrics = new MulticlassMetrics(predictionsAndLabels)
// metrics.precision (no argument) is deprecated in Spark 2.x; metrics.accuracy gives the same overall value.
println(metrics.precision)

The resulting accuracy:

0.8082788671023965
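The single accuracy number hides which classes are confused with one another. As an optional check (not in the original post), the same MulticlassMetrics object also exposes a confusion matrix and per-class precision and recall:

// Rows of the confusion matrix are true labels, columns are predicted labels.
println(metrics.confusionMatrix)
// Per-class precision and recall for each marital-status index.
metrics.labels.foreach { l =>
  println(s"label $l: precision=${metrics.precision(l)}, recall=${metrics.recall(l)}")
}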

Considering that we only have a little over thirty thousand rows, this level of accuracy is already quite respectable. Next, we use a triple nested loop to search for better parameter settings:

val evaluations =
  for (impurity <- Array("gini", "entropy");
       depth <- Array(1, 10, 25);
       bins <- Array(10, 50, 150))
  yield {
    val _model = DecisionTree.
      trainClassifier(trainData, 7, Map[Int, Int](), impurity, depth, bins)
    val _predictionsAndLabels = cvData.map(example =>
      (_model.predict(example.features), example.label)
    )
    val _accuracy = new MulticlassMetrics(_predictionsAndLabels).precision
    ((depth, bins, impurity), _accuracy)
  }

evaluations.sortBy(_._2).reverse.foreach(println)

The results:

((10,150,entropy),0.8085365853658537)
((10,50,entropy),0.8085365853658537)
((10,10,entropy),0.8042682926829269)
((10,150,gini),0.8021341463414634)
((10,50,gini),0.8021341463414634)
((10,10,gini),0.8009146341463415)
((25,10,gini),0.7969512195121952)
((25,150,entropy),0.7957317073170732)
((25,50,entropy),0.7957317073170732)
((25,10,entropy),0.7942073170731707)
((25,150,gini),0.7905487804878049)
((25,50,gini),0.7905487804878049)
((1,150,entropy),0.7024390243902439)
((1,50,entropy),0.7024390243902439)
((1,10,entropy),0.7024390243902439)
((1,150,gini),0.7024390243902439)
((1,50,gini),0.7024390243902439)
((1,10,gini),0.7024390243902439)

We can see that the (10, 150, entropy) combination performs reasonably well, although the accuracy is still modest; I will try to improve the model later.
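The testData split created earlier has not been used yet. As a minimal sketch (not in the original post), a final check could retrain with the chosen settings (entropy, depth 10, 150 bins) on the training and cross-validation data together and evaluate once on the held-out test set:

// Retrain with the best hyperparameters found on the cross-validation set.
val bestModel = DecisionTree.
  trainClassifier(trainData.union(cvData), 7, Map[Int, Int](), "entropy", 10, 150)
// Evaluate exactly once on the untouched test split.
val testMetrics = new MulticlassMetrics(
  testData.map(example => (bestModel.predict(example.features), example.label))
)
println(testMetrics.precision)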
