当前位置:   article > 正文

机器学习(4)_决策树生成的基本流程,其三个停止条件式什么?

决策树生成的基本流程,其三个停止条件式什么?

目录

4-1决策树基本流程

4-2信息增益划分

4-3其他属性划分准则

4-4决策树剪枝

4-5缺失值处理

4-6代码实战


4-1决策树基本流程

决策树模型        

基本流程

一个典型的递归过程,策略:分而治之,自根至叶的递归过程

递归最重要的就是停止条件。

三种停止条件:

1.当前结点包含的样本完全属于同一类别,无需划分

2.当前属性集为空,或是所有样本在所有属性上取值相同,无法划分

3.当前结点包含的样本集合为空,不能划分

基本算法

4-2信息增益划分

决策树的提出在很多程度上受到了信息论的启发,所有很多准则以信息论作为判断。

”信息熵“是度量样本集合纯度最常用的一种指标

信息增益

信息增益直接以信息熵为基础,计算当前划分对信息熵所造成的变化

举例说明

4-3其他属性划分准则

信息增益实际上是偏好于分支数目更多的分类(分类越多,越精确),其实有缺陷。提出了增益率

先找信息增益高的,再找增益率高的。C4.5决策树算法采用增益率来选择划分最优属性。

基尼指数

CART决策树采用“基尼指数”来选择划分属性

直观来说,Gini反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率,因此Gini越小,数据集的纯度越高

研究表明:划分选择的各种准则虽然对决策树的尺寸有较大影响,但对泛化性能的影响很有限

例如信息增益于基尼指数产生的结果,仅在约2%的情况下不同

剪枝方法和程度对决策树泛化性能的影响更为显著

在数据带噪时甚至可能将泛化性能提升25%

4-4决策树剪枝

剪枝是决策树学习算法对付“过拟合”的主要手段。决策树剪枝的基本策略有“预剪枝”和“后剪枝”。

预剪枝:在决策树生产过程中,对每个结点在划分前先进行估计,若当前节点的划分不能带来决策树泛化程度的提升,则停止划分

后剪枝:先从训练集生成一颗完整的决策树,然后自底向上的对非叶结点进行考察,若将该结点对应的子树替换成叶结点可以为决策树带来泛化性能的提高,则将该子树替换成叶结点

4-5缺失值处理

现实世界,经常会遇见不完整的样本

基本思路:“样本赋权,权重划分”

举例说明:

4-6代码实战

数据集

代码实战

导包

  1. import numpy as np
  2. import pandas as pd
  3. import sklearn.tree as st
  4. import math
  5. import matplotlib
  6. import os
  7. import matplotlib.pyplot as plt
  8. data = pd.read_csv('xigua.csv',header=None)
  9. data

  1. def calcEntropy(dataSet):
  2. mD = len(dataSet)
  3. dataLabelList = [x[-1] for x in dataSet]
  4. dataLabelSet = set(dataLabelList)
  5. ent = 0
  6. for label in dataLabelSet:
  7. mDv = dataLabelList.count(label)
  8. prop = float(mDv) / mD
  9. ent = ent - prop * np.math.log(prop, 2)
  10. return ent

拆分数据集

  1. # index - 要拆分的特征的下标
  2. # feature - 要拆分的特征
  3. # 返回值 - dataSet中index所在特征为feature,且去掉index一列的集合
  4. def splitDataSet(dataSet, index, feature):
  5. splitedDataSet = []
  6. mD = len(dataSet)
  7. for data in dataSet:
  8. if(data[index] == feature):
  9. sliceTmp = data[:index]
  10. sliceTmp.extend(data[index + 1:])
  11. splitedDataSet.append(sliceTmp)
  12. return splitedDataSet

选择最优特征

  1. # 返回值 - 最好的特征的下标
  2. def chooseBestFeature(dataSet):
  3. entD = calcEntropy(dataSet)
  4. mD = len(dataSet)
  5. featureNumber = len(dataSet[0]) - 1
  6. maxGain = -100
  7. maxIndex = -1
  8. for i in range(featureNumber):
  9. entDCopy = entD
  10. featureI = [x[i] for x in dataSet]
  11. featureSet = set(featureI)
  12. for feature in featureSet:
  13. splitedDataSet = splitDataSet(dataSet, i, feature) # 拆分数据集
  14. mDv = len(splitedDataSet)
  15. entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
  16. if(maxIndex == -1):
  17. maxGain = entDCopy
  18. maxIndex = i
  19. elif(maxGain < entDCopy):
  20. maxGain = entDCopy
  21. maxIndex = i
  22. return maxIndex

寻找最多作为标签

  1. # 返回值 - 标签
  2. def mainLabel(labelList):
  3. labelRec = labelList[0]
  4. maxLabelCount = -1
  5. labelSet = set(labelList)
  6. for label in labelSet:
  7. if(labelList.count(label) > maxLabelCount):
  8. maxLabelCount = labelList.count(label)
  9. labelRec = label
  10. return labelRec

生成树

  1. def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
  2. labelList = [x[-1] for x in dataSet]
  3. if(len(dataSet) == 0):
  4. return mainLabel(labelListParent)
  5. elif(len(dataSet[0]) == 1): #没有可划分的属性了
  6. return mainLabel(labelList) #选出最多的label作为该数据集的标签
  7. elif(labelList.count(labelList[0]) == len(labelList)): # 全部都属于同一个Label
  8. return labelList[0]
  9. bestFeatureIndex = chooseBestFeature(dataSet)
  10. bestFeatureName = featureNames.pop(bestFeatureIndex)
  11. myTree = {bestFeatureName: {}}
  12. featureList = featureNamesSet.pop(bestFeatureIndex)
  13. featureSet = set(featureList)
  14. for feature in featureSet:
  15. featureNamesNext = featureNames[:]
  16. featureNamesSetNext = featureNamesSet[:][:]
  17. splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
  18. myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
  19. return myTree

初始化

  1. # 返回值
  2. # dataSet 数据集
  3. # featureNames 标签
  4. # featureNamesSet 列标签
  5. def readWatermelonDataSet():
  6. dataSet = data.values.tolist()
  7. featureNames =['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
  8. #获取featureNamesSet
  9. featureNamesSet = []
  10. for i in range(len(dataSet[0]) - 1):
  11. col = [x[i] for x in dataSet]
  12. colSet = set(col)
  13. featureNamesSet.append(list(colSet))
  14. return dataSet, featureNames, featureNamesSet

画图

  1. # 能够显示中文
  2. matplotlib.rcParams['font.sans-serif'] = ['SimHei']
  3. matplotlib.rcParams['font.serif'] = ['SimHei']
  4. # 分叉节点,也就是决策节点
  5. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  6. # 叶子节点
  7. leafNode = dict(boxstyle="round4", fc="0.8")
  8. # 箭头样式
  9. arrow_args = dict(arrowstyle="<-")
  10. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  11. """
  12. 绘制一个节点
  13. :param nodeTxt: 描述该节点的文本信息
  14. :param centerPt: 文本的坐标
  15. :param parentPt: 点的坐标,这里也是指父节点的坐标
  16. :param nodeType: 节点类型,分为叶子节点和决策节点
  17. :return:
  18. """
  19. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  20. xytext=centerPt, textcoords='axes fraction',
  21. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
  22. def getNumLeafs(myTree):
  23. """
  24. 获取叶节点的数目
  25. :param myTree:
  26. :return:
  27. """
  28. # 统计叶子节点的总数
  29. numLeafs = 0
  30. # 得到当前第一个key,也就是根节点
  31. firstStr = list(myTree.keys())[0]
  32. # 得到第一个key对应的内容
  33. secondDict = myTree[firstStr]
  34. # 递归遍历叶子节点
  35. for key in secondDict.keys():
  36. # 如果key对应的是一个字典,就递归调用
  37. if type(secondDict[key]).__name__ == 'dict':
  38. numLeafs += getNumLeafs(secondDict[key])
  39. # 不是的话,说明此时是一个叶子节点
  40. else:
  41. numLeafs += 1
  42. return numLeafs
  43. def getTreeDepth(myTree):
  44. """
  45. 得到数的深度层数
  46. :param myTree:
  47. :return:
  48. """
  49. # 用来保存最大层数
  50. maxDepth = 0
  51. # 得到根节点
  52. firstStr = list(myTree.keys())[0]
  53. # 得到key对应的内容
  54. secondDic = myTree[firstStr]
  55. # 遍历所有子节点
  56. for key in secondDic.keys():
  57. # 如果该节点是字典,就递归调用
  58. if type(secondDic[key]).__name__ == 'dict':
  59. # 子节点的深度加1
  60. thisDepth = 1 + getTreeDepth(secondDic[key])
  61. # 说明此时是叶子节点
  62. else:
  63. thisDepth = 1
  64. # 替换最大层数
  65. if thisDepth > maxDepth:
  66. maxDepth = thisDepth
  67. return maxDepth
  68. def plotMidText(cntrPt, parentPt, txtString):
  69. """
  70. 计算出父节点和子节点的中间位置,填充信息
  71. :param cntrPt: 子节点坐标
  72. :param parentPt: 父节点坐标
  73. :param txtString: 填充的文本信息
  74. :return:
  75. """
  76. # 计算x轴的中间位置
  77. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  78. # 计算y轴的中间位置
  79. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  80. # 进行绘制
  81. createPlot.ax1.text(xMid, yMid, txtString)
  82. def plotTree(myTree, parentPt, nodeTxt):
  83. """
  84. 绘制出树的所有节点,递归绘制
  85. :param myTree: 树
  86. :param parentPt: 父节点的坐标
  87. :param nodeTxt: 节点的文本信息
  88. :return:
  89. """
  90. # 计算叶子节点数
  91. numLeafs = getNumLeafs(myTree=myTree)
  92. # 计算树的深度
  93. depth = getTreeDepth(myTree=myTree)
  94. # 得到根节点的信息内容
  95. firstStr = list(myTree.keys())[0]
  96. # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
  97. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  98. # 绘制该节点与父节点的联系
  99. plotMidText(cntrPt, parentPt, nodeTxt)
  100. # 绘制该节点
  101. plotNode(firstStr, cntrPt, parentPt, decisionNode)
  102. # 得到当前根节点对应的子树
  103. secondDict = myTree[firstStr]
  104. # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
  105. plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  106. # 循环遍历所有的key
  107. for key in secondDict.keys():
  108. # 如果当前的key是字典的话,代表还有子树,则递归遍历
  109. if isinstance(secondDict[key], dict):
  110. plotTree(secondDict[key], cntrPt, str(key))
  111. else:
  112. # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
  113. plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
  114. # 打开注释可以观察叶子节点的坐标变化
  115. # print((plotTree.xOff, plotTree.yOff), secondDict[key])
  116. # 绘制叶子节点
  117. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  118. # 绘制叶子节点和父节点的中间连线内容
  119. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  120. # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
  121. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
  122. def createPlot(inTree):
  123. """
  124. 需要绘制的决策树
  125. :param inTree: 决策树字典
  126. :return:
  127. """
  128. # 创建一个图像
  129. fig = plt.figure(1, facecolor='white')
  130. fig.clf()
  131. axprops = dict(xticks=[], yticks=[])
  132. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  133. # 计算出决策树的总宽度
  134. plotTree.totalW = float(getNumLeafs(inTree))
  135. # 计算出决策树的总深度
  136. plotTree.totalD = float(getTreeDepth(inTree))
  137. # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
  138. plotTree.xOff = -0.5/plotTree.totalW
  139. # 初始的y轴偏移量,每次向下或者向上移动1/D
  140. plotTree.yOff = 1.0
  141. # 调用函数进行绘制节点图像
  142. plotTree(inTree, (0.5, 1.0), '')
  143. # 绘制
  144. plt.show()

结果

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/617865
推荐阅读
相关标签
  

闽ICP备14008679号