当前位置:   article > 正文

机器学习之决策树分类算法_人工智能题学生买电脑题决策树

人工智能题学生买电脑题决策树

决策树和随机森林:决策树和随机森林都是非线性有监督的分类模型。

决策树是一种树形结构,树内部每个节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶子节点代表一个分类类别。通过训练数据构建决策树,可以对未知数据进行分类,

随机森林是由多个决策树组成,随机森林中每一棵决策树之间没有关联,在得到一个随机森林后,当有新的样本进入的时候,随机森林中的每一棵决策树分别进行判断,分析出该样本属于哪一类,然后最后看哪一类被选择最多,就预测该样本属于这一类。

一、认识决策树

术语:
根节点:最顶层的分类条件
叶节点:代表每一个类别号
中间节点:中间分类条件
分支:代表每一个条件的输出
二叉树:每一个节点上有两个分支
多叉树:每一个节点上至少有两个分支

二、决策树分类原则

如下图数据集:

要按照前4列的信息,使用决策树预测车祸的发生,如何选择根节点呢?

通过以上发现,只有使用天气作为根节点时,决策树的高度相对低而且树的两边能将数据分类的更彻底(其他列作为根节点时,树两边分类不纯粹,都有天气)。

决策树的生成原则:数据不断分裂的递归过程,每一次分裂,尽可能让类别一样的数据在树的一边,当树的叶子节点的数据都是一类的时候,则停止分类。这样分类的数据,每个节点两边的数据不同,将相同的数据分类到树的一侧,能将数据分类的更纯粹。减少树的高度和训练决策树的迭代次数注意训练决策树的数据集要离散化,不然有可能造成训练出来的树有些节点的分支特别多,容易造成过拟合。

三、选择分类条件

下图:

上图中箱子①中有100个红球。箱子②中有50个红球和50个黑球。箱子③中有10个红球和30个篮球,60个绿球。箱子④中各个颜色均有10中球。发现箱子①中球类单一,信息量少,比较纯粹,箱子④中,球的类别最多,相对①来说比较混乱,信息量大。

如何量化以上每个箱子中信息的纯粹和混乱(信息量的大小)指标,可以使用信息熵或者基尼系数。

1). 信息熵:信息熵是香农在1948年提出来量化信息信息量的指标,熵的定义如下:

 

其中,n代表当前类别有多少类别,pi代表当前类别中某一类别的概率。例如下图,

计算“是否购买电脑”这列的信息熵,当前类别“是否购买电脑”有2个类别,分别是“是”和“否”,那么“是否购买电脑”类别的信息熵如下:

 

 

通过以上计算可以得到,某个类别下信息量越多,熵越大,信息量越少,熵越小。假设“是否购买电脑”这列下只有“否”这个信息类别,那么“是否购买电脑”这列的信息熵为:

 

上图中,如果按照“年龄”,“收入层次”,“学生”,“信用等级”列使用决策树来预测“是否购买电脑”。如何选择决策树的根节点分类条件,就是找到某列作为分类条件时,使“是否购买电脑”这列分类的更彻底,也就是找到在某个列作为分类条件下时,“是否购买电脑”信息熵相对于没有这个分类条件时信息熵降低最大(降低最大,就是熵越低,分类越彻底),这个条件就是分类节点的分类条件。这里要使用到条件熵和信息增益。

条件熵:在某个分类条件下某个类别的信息熵叫做条件熵,类似于条件概率,在知道Y的情况下,X的不确定性。条件熵一般使用 表示,代表在Y条件下,X的信息熵。上图中假设在“年龄”条件下,“是否购买电脑”的信息熵为:“年龄”列每个类别下对应的“是否购买电脑”信息熵的和。

H(是否购买电脑|年龄)=H(是否购买电脑|青少年)+H(是否购买电脑|中年)+H(是否购买电脑|老年)

信息增益:代表熵的变化程度。分类前的信息熵减去分类后的信息熵。如特征Y对训练集D的信息增益为

在“年龄”条件下,“是否购买电脑”的信息增益为:

g(是否购买电脑,年龄)=H(是否购买电脑)-H(是否购买电脑,年龄)=0.94-0.69=0.25

由以上可知,按照“记录ID”,“年龄”,“收入层次”,“学生”,“信用等级”列使用决策树来预测“是否购买电脑”,选择分类根分类条件时步骤:

a.计算“是否购买电脑”的信息熵

b.计算在已知各个列的条件熵--H(是够购买电脑|年龄),H(是够购买电脑|收入层次),H(是够购买电脑|是否学生),H(是够购买电脑|信用等级)

c.求各个条件下的信息增益,选择信息增益大的作为分类条件。选择中间节点时,以此类推。

在构建决策树时,选择信息增益大的属性作为分类节点的方法也叫ID3分类算法。

2).基尼系数:基尼系数也可以表示样本的混乱程度。公式如下:

 

 

基尼系数越小代表信息越纯,类别越少,基尼系数越大,代表信息越混乱,类别越多。基尼增益的计算和信息增益相同。假设某列只有一类值,这列的基尼系数为0。

四、信息增益率--也叫C4.5算法

在上图中,如果将“记录ID”也作为分类条件的话,由于“记录ID”对于“是否购买电脑”列的条件熵为0,可以得到“是否购买电脑”在“记录ID”这个分类条件下信息增益最大。如果选择“记录ID”作为分类条件,容易造成分支特别多,对已有记录ID的数据可以分类出结果,对于新的记录ID有可能不能成功的分类出结果。

使用信息增益来筛选分类条件,更倾向于选择更混杂的属性。容易出现过拟合问题。可以使用信息增益率来解决这个问题。

信息增益率的公式:gr(D,A) = g(D,A)/H(A),在某个条件下信息增益除以这个条件的信息熵。

五、使用决策树来做回归或者预测值

如上图,使用学历、收入、身高、行业使用决策树来预测收到的邮件数。可以将邮件数分为几类(也可以按照其他列,将邮件数分类),比如邮件数<=23封属于A类,邮件数大于23<邮件数<=30为B类,A类中取邮件的平均数,B类中也取邮件的平均数。就是可以将某些列作为分类条件划分邮件数的类别,再取邮件数的平均数,这样可以使用决策树来预测大概值的范围。

六、决策树预剪枝和后剪枝(降低过拟合问题)

决策树对训练集有很好的分类能力,但是对于未知的测试集未必有好的分类能力,导致模型的泛化能力弱,可能发生过拟合问题,为了防止过拟合问题的出现,可以对决策树进行剪枝。剪枝分为预剪枝和后剪枝。

  • 预剪枝:就是在构建决策树的时候提前停止。比如指定树的深度最大为3,那么训练出来决策树的高度就是3,预剪枝主要是建立某些规则限制决策树的生长,降低了过拟合的风险,降低了建树的时间,但是有可能带来欠拟合问题。
  • 后剪枝:后剪枝是一种全局的优化方法,在决策树构建好之后,然后才开始进行剪枝。后剪枝的过程就是删除一些子树,这个叶子节点的标识类别通过大多数原则来确定,即属于这个叶子节点下大多数样本所属的类别就是该叶子节点的标识。选择减掉哪些子树时,可以计算没有减掉子树之前的误差和减掉子树之后的误差,如果相差不大,可以将子树减掉。一般使用后剪枝得到的结果比较好。

剪枝可以降低过拟合问题,如下图:

当来一条数据年龄为中年,信用高,孩子个数是4个时,没有办法分类。可以通过剪枝,降低过拟合问题。

七、随机森林

随机森林是由多个决策树组成。是用随机的方式建立一个森林,里面由很多决策树组成。随机森林中每一棵决策树之间都是没有关联的。得到随机森林之后,对于一个样本输入时,森林中的每一棵决策树都进行判断,看看这个样本属于哪一类,最终哪一类得到的结果最多,该输入的预测值就是哪一类。

随机森林中的决策树生成过程是对样本数据进行行采样和列采样,可以指定随机森林中的树的个数和属性个数,这样当训练集很大的时候,随机选取数据集的一部分,生成一棵树,重复上面过程,可以生成一堆形态各异的树,这些决策树构成随机森林。

随机森林中的每个决策树可以分布式的训练,解决了单棵决策树在数据量大的情况下预算量大的问题。当训练样本中出现异常数据时,决策树的抗干扰能力差,对于随机森林来说也解决了模型的抗干扰能力。

八、决策树案例

  1. package com.lxk.rf
  2. import org.apache.spark.mllib.tree.DecisionTree
  3. import org.apache.spark.mllib.util.MLUtils
  4. import org.apache.spark.{SparkContext, SparkConf}
  5. /**
  6. * 决策树
  7. */
  8. object ClassificationDecisionTree {
  9. def main(args: Array[String]): Unit = {
  10. val conf = new SparkConf()
  11. conf.setAppName("analysItem")
  12. conf.setMaster("local[3]")
  13. val sc = new SparkContext(conf)
  14. val data = MLUtils.loadLibSVMFile(sc, "汽车数据样本.txt")
  15. // Split the data into training and test sets (30% held out for testing)
  16. val splits = data.randomSplit(Array(0.7, 0.3))
  17. val (trainingData, testData) = (splits(0), splits(1))
  18. //指明分类的类别
  19. val numClasses = 2
  20. //指定离散变量,未指明的都当作连续变量处理
  21. //某列下有1,2,3类别 处理时候要自定为4类,虽然没有0,但是程序默认从0开始分类
  22. //这里天气维度有3类,但是要指明4,这里是个坑,后面以此类推
  23. val categoricalFeaturesInfo = Map[Int, Int](0 -> 4, 1 -> 4, 2 -> 3, 3 -> 3)
  24. //设定评判标准 "gini"/"entropy"
  25. val impurity = "entropy"
  26. //树的最大深度,太深运算量大也没有必要 剪枝 防止模型的过拟合!!!
  27. val maxDepth = 3
  28. //设置离散化程度,连续数据需要离散化,分成32个区间,默认其实就是32,分割的区间保证数量差不多 这个参数也可以进行剪枝
  29. val maxBins = 32
  30. //生成模型
  31. val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  32. impurity, maxDepth, maxBins)
  33. //测试
  34. val labelAndPreds = testData.map { point =>
  35. val prediction = model.predict(point.features)
  36. (point.label, prediction)
  37. }
  38. val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
  39. println("Test Error = " + testErr)
  40. println("Learned classification tree model:\n" + model.toDebugString)
  41. }
  42. }

测试结果:

  1. 2019-09-26 09:31:09 INFO RandomForest:54 - init: 2.10599667
  2. total: 2.582645014
  3. findSplits: 1.895945487
  4. findBestSplits: 0.443541994
  5. chooseSplits: 0.437320511
  6. Test Error = 0.05280419809773696
  7. Learned classification tree model:
  8. DecisionTreeModel classifier of depth 3 with 15 nodes
  9. If (feature 3 in {2.0})
  10. If (feature 2 in {2.0})
  11. If (feature 1 in {3.0,2.0})
  12. Predict: 0.0
  13. Else (feature 1 not in {3.0,2.0})
  14. Predict: 0.0
  15. Else (feature 2 not in {2.0})
  16. If (feature 4 <= 58.5)
  17. Predict: 0.0
  18. Else (feature 4 > 58.5)
  19. Predict: 1.0
  20. Else (feature 3 not in {2.0})
  21. If (feature 4 <= 58.5)
  22. If (feature 2 in {2.0})
  23. Predict: 0.0
  24. Else (feature 2 not in {2.0})
  25. Predict: 1.0
  26. Else (feature 4 > 58.5)
  27. If (feature 0 in {1.0})
  28. Predict: 1.0
  29. Else (feature 0 not in {1.0})
  30. Predict: 1.0
  31. 2019-09-26 09:31:09 INFO SparkContext:54 - Invoking stop() from shutdown hook
  32. 2019-09-26 09:31:09 INFO AbstractConnector:318 - Stopped Spark@4ecccf52{HTTP/1.1,[http/1.1]}{0.0.0.0:4040}
  33. 2019-09-26 09:31:09 INFO SparkUI:54 - Stopped Spark web UI at http://192.168.0.102:4040
  34. 2019-09-26 09:31:09 INFO MapOutputTrackerMasterEndpoint:54 - MapOutputTrackerMasterEndpoint stopped!
  35. 2019-09-26 09:31:09 INFO MemoryStore:54 - MemoryStore cleared
  36. 2019-09-26 09:31:09 INFO BlockManager:54 - BlockManager stopped
  37. 2019-09-26 09:31:09 INFO BlockManagerMaster:54 - BlockManagerMaster stopped
  38. 2019-09-26 09:31:09 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint:54 - OutputCommitCoordinator stopped!
  39. 2019-09-26 09:31:09 INFO SparkContext:54 - Successfully stopped SparkContext
  40. 2019-09-26 09:31:09 INFO ShutdownHookManager:54 - Shutdown hook called
  41. 2019-09-26 09:31:09 INFO ShutdownHookManager:54 - Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-c86d3935-8a47-4ff4-a0c9-86d53f1885c9
  42. Process finished with exit code 0

九、随机森林案例

 

  1. package com.bjsxt.rf
  2. import org.apache.spark.{SparkContext, SparkConf}
  3. import org.apache.spark.mllib.util.MLUtils
  4. import org.apache.spark.mllib.tree.RandomForest
  5. /**
  6. * 随机森林
  7. *
  8. */
  9. object ClassificationRandomForest {
  10. def main(args: Array[String]): Unit = {
  11. val conf = new SparkConf()
  12. conf.setAppName("analysItem")
  13. conf.setMaster("local[3]")
  14. val sc = new SparkContext(conf)
  15. //读取数据
  16. val data = MLUtils.loadLibSVMFile(sc, "汽车数据样本.txt")
  17. //将样本按73的比例分成
  18. val splits = data.randomSplit(Array(0.7, 0.3))
  19. val (trainingData, testData) = (splits(0), splits(1))
  20. //分类数
  21. val numClasses = 2
  22. // categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量
  23. val categoricalFeaturesInfo = Map[Int, Int](0 -> 4, 1 -> 4, 2 -> 3, 3 -> 3)
  24. //树的个数
  25. val numTrees = 3
  26. //特征子集采样策略,auto 表示算法自主选取
  27. //"auto"根据特征数量在4个中进行选择
  28. // 1all 全部特征 。2:sqrt 把特征数量开根号后随机选择的 。 3:log2 取对数个。 4:onethird 三分之一
  29. val featureSubsetStrategy = "auto"
  30. //纯度计算 "gini"/"entropy"
  31. val impurity = "entropy"
  32. //树的最大层次
  33. val maxDepth = 3
  34. //特征最大装箱数,即连续数据离散化的区间
  35. val maxBins = 32
  36. //训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象
  37. val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  38. numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
  39. // //打印模型
  40. // println(model.toDebugString)
  41. //保存模型
  42. //model.save(sc,"汽车保险")
  43. //在测试集上进行测试
  44. val count = testData.map { point =>
  45. val prediction = model.predict(point.features)
  46. // Math.abs(prediction-point.label)
  47. (prediction, point.label)
  48. }.filter(r => r._1 != r._2).count()
  49. println("Test Error = " + count.toDouble / testData.count().toDouble)
  50. println("model " + model.toDebugString)
  51. }
  52. }

测试结果:

  1. 2019-09-26 09:56:49 INFO RandomForest:54 - Internal timing for DecisionTree:
  2. 2019-09-26 09:56:49 INFO RandomForest:54 - init: 1.531397382
  3. total: 2.022783966
  4. findSplits: 1.264719234
  5. findBestSplits: 0.456304254
  6. chooseSplits: 0.450956262
  7. 2019-09-26 09:56:49 INFO Executor:54 - Finished task 1.0 in stage 12.0 (TID 23). 746 bytes result sent to driver
  8. 2019-09-26 09:56:49 INFO TaskSetManager:54 - Finished task 1.0 in stage 12.0 (TID 23) in 36 ms on localhost (executor driver) (2/2)
  9. 2019-09-26 09:56:49 INFO TaskSchedulerImpl:54 - Removed TaskSet 12.0, whose tasks have all completed, from pool
  10. 2019-09-26 09:56:49 INFO DAGScheduler:54 - ResultStage 12 (count at ClassificationRandomForest.scala:50) finished in 0.046 s
  11. 2019-09-26 09:56:49 INFO DAGScheduler:54 - Job 8 finished: count at ClassificationRandomForest.scala:50, took 0.052419 s
  12. Test Error = 0.03447704293054865
  13. model TreeEnsembleModel classifier with 3 trees
  14. Tree 0:
  15. If (feature 2 in {2.0})
  16. If (feature 4 <= 64.5)
  17. If (feature 0 in {1.0,2.0})
  18. Predict: 0.0
  19. Else (feature 0 not in {1.0,2.0})
  20. Predict: 0.0
  21. Else (feature 4 > 64.5)
  22. If (feature 1 in {3.0})
  23. Predict: 0.0
  24. Else (feature 1 not in {3.0})
  25. Predict: 1.0
  26. Else (feature 2 not in {2.0})
  27. If (feature 4 <= 58.5)
  28. If (feature 0 in {1.0,2.0})
  29. Predict: 0.0
  30. Else (feature 0 not in {1.0,2.0})
  31. Predict: 1.0
  32. Else (feature 4 > 58.5)
  33. If (feature 0 in {1.0})
  34. Predict: 1.0
  35. Else (feature 0 not in {1.0})
  36. Predict: 1.0
  37. Tree 1:
  38. If (feature 2 in {2.0})
  39. If (feature 0 in {1.0,2.0})
  40. If (feature 3 in {2.0})
  41. Predict: 0.0
  42. Else (feature 3 not in {2.0})
  43. Predict: 0.0
  44. Else (feature 0 not in {1.0,2.0})
  45. If (feature 4 <= 56.5)
  46. Predict: 0.0
  47. Else (feature 4 > 56.5)
  48. Predict: 1.0
  49. Else (feature 2 not in {2.0})
  50. If (feature 4 <= 56.5)
  51. If (feature 3 in {2.0})
  52. Predict: 0.0
  53. Else (feature 3 not in {2.0})
  54. Predict: 1.0
  55. Else (feature 4 > 56.5)
  56. If (feature 3 in {2.0})
  57. Predict: 1.0
  58. Else (feature 3 not in {2.0})
  59. Predict: 1.0
  60. Tree 2:
  61. If (feature 4 <= 58.5)
  62. If (feature 2 in {2.0})
  63. If (feature 1 in {3.0,2.0})
  64. Predict: 0.0
  65. Else (feature 1 not in {3.0,2.0})
  66. Predict: 0.0
  67. Else (feature 2 not in {2.0})
  68. If (feature 0 in {1.0})
  69. Predict: 0.0
  70. Else (feature 0 not in {1.0})
  71. Predict: 1.0
  72. Else (feature 4 > 58.5)
  73. If (feature 3 in {2.0})
  74. If (feature 2 in {2.0})
  75. Predict: 0.0
  76. Else (feature 2 not in {2.0})
  77. Predict: 1.0
  78. Else (feature 3 not in {2.0})
  79. If (feature 2 in {2.0})
  80. Predict: 1.0
  81. Else (feature 2 not in {2.0})
  82. Predict: 1.0
  83. 2019-09-26 09:56:49 INFO SparkUI:54 - Stopped Spark web UI at http://192.168.0.102:4040
  84. 2019-09-26 09:56:49 INFO MapOutputTrackerMasterEndpoint:54 - MapOutputTrackerMasterEndpoint stopped!
  85. 2019-09-26 09:56:49 INFO MemoryStore:54 - MemoryStore cleared
  86. 2019-09-26 09:56:49 INFO BlockManager:54 - BlockManager stopped
  87. 2019-09-26 09:56:49 INFO BlockManagerMaster:54 - BlockManagerMaster stopped
  88. 2019-09-26 09:56:49 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint:54 - OutputCommitCoordinator stopped!
  89. 2019-09-26 09:56:49 INFO SparkContext:54 - Successfully stopped SparkContext
  90. 2019-09-26 09:56:49 INFO ShutdownHookManager:54 - Shutdown hook called
  91. 2019-09-26 09:56:49 INFO ShutdownHookManager:54 - Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-596489e0-efaa-46d0-a73e-e52a9f2387a9
  92. Process finished with exit code 0

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/901509
推荐阅读
相关标签
  

闽ICP备14008679号