当前位置:   article > 正文

决策树挑出好西瓜_西瓜决策树实验报告

西瓜决策树实验报告

目录

一、决策树算法利用了信息论的信息熵的计算。

1、理论基础

 2、手工推动一遍ID3算法选择特征的过程

3、在jupyter下实现针对西瓜数据集的ID3算法代码,并输出可视化结果。

二、用sk-learn库对西瓜数据集,分别进行ID3、C4.5和CART的算法代码实现。

1、ID3算法

2、C4.5算法

3、CART算法

三、ID3、C4.5、CART算法的优缺点及比较

1.ID3算法的缺点

2.C4.5算法的优缺点

3.CART算法的特点


一、决策树算法利用了信息论的信息熵的计算。

        决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。在机器学习中,决策树是一个预测模型,他代表的是对象属性与对象值之间的一种映射关系。Entropy = 系统的凌乱程度,使用算法ID3,C4.5和C5.0生成树算法使用熵。最经典的决策树算法有ID3、C4.5、CART,其中ID3算法是最早被提出的,它可以处理离散属性样本的分类,C4.5和CART算法则可以处理更加复杂的分类问题。以下主要介绍ID3算法。

1、理论基础

如果我们能测量数据的复杂度,对比按不同特征分类后的数据复杂度,若按某一特征分类后复杂度减少的更多,那么这个特征即为最佳分类特征。Claude Shannon 定义了熵(entropy)和信息增益(information gain)。用熵来表示信息的复杂度,熵越大,则信息越复杂。

ID3算法的基本流程为:如果某一个特征能比其他特征更好的将训练数据集进行区分,那么将这个特征放在初始结点,依此类推,初始特征确定之后,对于初始特征每个可能的取值建立一个子结点,选择每个子结点所对应的特征,若某个子结点包含的所有样本属于同一类或所有特征对其包含的训练数据的区分能力均小于给定阈值,则该子结点为一个叶结点,其类别与该叶结点的训练数据类别最多的一致。重复上述过程直到特征用完或者所有特征的区分能力均小于给定阈值。如何衡量某个特征对训练数据集的区分能力呢,ID3算法通过信息增益来解决这个问题。

1)信息熵(information entropy)

熵定义为信息的期望值。在信息论与概率统计中,熵是表示随机变量不确定性的度量。如果待分类的事务可能划分在多个分类之中,则符号x_{i}的信息定义为:

\large l(x_{i})=-log_{2}P(x_{i})

其中P(x_{i})为选择该分类的概率。通过上式,我们可以得到所有类别的信息。样本集合D中第i类样本所占的比例P(x_{i})(i=1,2,...,n),n为样本分类的个数 ,信息熵公式如下:

\large H=-\sum_{i=1}^{n}P(x_{i})log_{2}P(x_{i})

其中n是分类的数目。熵越大,随机变量的不确定性就越大。

2)信息增益(information gain)

使用属性a对样本集D进行划分所获得的“信息增益”的计算方法是,用样本集的总信息熵减去属性a的每个分支的信息熵与权重(该分支的样本数除以总样本数)的乘积,通常,信息增益越大,意味着用属性a进行划分所获得的“纯度提升”越大。因此,优先选择信息增益最大的属性来划分。设属性a有V个可能的取值,则属性a的信息增益为:

\large G(D,A)=H(D)-H(D,A)

 2、手工推动一遍ID3算法选择特征的过程

以挑选西瓜为例,以下是西瓜样本集:

  

 1)计算样本集的总信息熵

总信息熵是指样本集未分类之前的熵,共有17个瓜,8个好瓜,9个坏瓜

H=-\frac{8}{17}log_{2}\frac{8}{17}-\frac{9}{17}log_{2}\frac{9}{17}=0.9975

2)计算各个特征的信息熵

以色泽特征为例:

色泽乌黑的熵:H_{11}=-\frac{4}{6}log_{2}\frac{4}{6}-\frac{2}{6}log_{2}\frac{2}{6}

色泽青绿的熵:H_{12}=-\frac{3}{6}log_{2}\frac{3}{6}-\frac{3}{6}log_{2}\frac{3}{6}

色泽浅白的熵:H_{13}=-\frac{1}{5}log_{2}\frac{1}{5}-\frac{4}{5}log_{2}\frac{4}{5}

色泽的信息熵:H_{1}=\frac{6}{17}*H_{11}+\frac{6}{17}*H_{12}+\frac{5}{17}*H_{13}=0.8894

色泽的信息增益:G_{1}=H-H_{1}=0.1081

同理根据以上步骤也可以计算出其他几个特征的信息增益,选择信息增益最大的属性作为根节点来进行划分,然后再对每个分支做进一步划分。各个特征的信息增益计算结果如下表所示:

特征信息增益
色泽0.10812516526536531
根蒂0.14267495956679277
敲声0.14078143361499584
纹理0.3805918973682686
脐部0.28915878284167895
触感0.006046489176565584

由上表可知,纹理特征的信息增益最大,即最优选择纹理作为根节点进行分类。

3、在jupyter下实现针对西瓜数据集的ID3算法代码,并输出可视化结果。

1)导入需要的库

  1. #导入模块
  2. import pandas as pd
  3. import numpy as np
  4. from collections import Counter
  5. from math import log2

2)数据获取与处理

  1. def getData(filePath):
  2. data = pd.read_excel(filePath)
  3. return data
  4. def dataDeal(data):
  5. dataList = np.array(data).tolist()
  6. dataSet = [element[1:] for element in dataList]
  7. return dataSet

 getData()通过pandas模块中的read_excel()函数读取样本数据。 dataDeal()函数将dataframe转换为list,并且去掉了编号列。编号列并不是西瓜的属性,事实上,如果把它当做属性,会获得最大的信息增益。

3)获取属性名称

  1. def getLabels(data):
  2. labels = list(data.columns)[1:-1]
  3. return labels

获取属性名称:纹理,色泽,根蒂,敲声,脐部,触感。

4)获取类别标记

  1. def targetClass(dataSet):
  2. classification = set([element[-1] for element in dataSet])
  3. return classification

获取一个样本是否好瓜的标记(是/否)。

5)将分支结点标记为叶结点,选择样本数最多的类作为类标记

  1. def majorityRule(dataSet):
  2. mostKind = Counter([element[-1] for element in dataSet]).most_common(1)
  3. majorityKind = mostKind[0][0]
  4. return majorityKind

6) 计算信息熵

  1. def infoEntropy(dataSet):
  2. classColumnCnt = Counter([element[-1] for element in dataSet])
  3. Ent = 0
  4. for symbol in classColumnCnt:
  5. p_k = classColumnCnt[symbol]/len(dataSet)
  6. Ent = Ent-p_k*log2(p_k)
  7. return Ent

7)子数据集构建

在某一个属性值下的数据,比如纹理为清晰的数据集。

  1. def makeAttributeData(dataSet,value,iColumn):
  2. attributeData = []
  3. for element in dataSet:
  4. if element[iColumn]==value:
  5. row = element[:iColumn]
  6. row.extend(element[iColumn+1:])
  7. attributeData.append(row)
  8. return attributeData

8)计算信息增益:

  1. def infoGain(dataSet,iColumn):
  2. Ent = infoEntropy(dataSet)
  3. tempGain = 0.0
  4. attribute = set([element[iColumn] for element in dataSet])
  5. for value in attribute:
  6. attributeData = makeAttributeData(dataSet,value,iColumn)
  7. tempGain = tempGain+len(attributeData)/len(dataSet)*infoEntropy(attributeData)
  8. Gain = Ent-tempGain
  9. return Gain

9)选择最优属性

  1. def selectOptimalAttribute(dataSet,labels):
  2. bestGain = 0
  3. sequence = 0
  4. for iColumn in range(0,len(labels)):#不计最后的类别列
  5. Gain = infoGain(dataSet,iColumn)
  6. if Gain>bestGain:
  7. bestGain = Gain
  8. sequence = iColumn
  9. print(labels[iColumn],Gain)
  10. return sequence

10)采用递归的方式建立决策树

  1. def createTree(dataSet,labels):
  2. classification = targetClass(dataSet) #获取类别种类(集合去重)
  3. if len(classification) == 1:
  4. return list(classification)[0]
  5. if len(labels) == 1:
  6. return majorityRule(dataSet)#返回样本种类较多的类别
  7. sequence = selectOptimalAttribute(dataSet,labels)
  8. print(labels)
  9. optimalAttribute = labels[sequence]
  10. del(labels[sequence])
  11. myTree = {optimalAttribute:{}}
  12. attribute = set([element[sequence] for element in dataSet])
  13. for value in attribute:
  14. print(myTree)
  15. print(value)
  16. subLabels = labels[:]
  17. myTree[optimalAttribute][value] = \
  18. createTree(makeAttributeData(dataSet,value,sequence),subLabels)
  19. return myTree

11)主函数

  1. def main():
  2. filePath = 'D:\西瓜数据集.xls'
  3. data = getData(filePath)
  4. dataSet = dataDeal(data)
  5. labels = getLabels(data)
  6. myTree = createTree(dataSet,labels)
  7. return myTree

12)生成树并打印树

  1. if __name__ == '__main__':
  2. myTree = main()
  3. print(myTree)

 13)根据输出的生成树绘制成可视化图像

  1. #绘制可视化树
  2. import matplotlib.pylab as plt
  3. import matplotlib
  4. # 能够显示中文
  5. matplotlib.rcParams['font.sans-serif'] = ['SimHei']
  6. matplotlib.rcParams['font.serif'] = ['SimHei']
  7. # 分叉节点,也就是决策节点
  8. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  9. # 叶子节点
  10. leafNode = dict(boxstyle="round4", fc="0.8")
  11. # 箭头样式
  12. arrow_args = dict(arrowstyle="<-")
  13. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  14. """
  15. 绘制一个节点
  16. :param nodeTxt: 描述该节点的文本信息
  17. :param centerPt: 文本的坐标
  18. :param parentPt: 点的坐标,这里也是指父节点的坐标
  19. :param nodeType: 节点类型,分为叶子节点和决策节点
  20. :return:
  21. """
  22. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  23. xytext=centerPt, textcoords='axes fraction',
  24. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
  25. def getNumLeafs(myTree):
  26. """
  27. 获取叶节点的数目
  28. :param myTree:
  29. :return:
  30. """
  31. # 统计叶子节点的总数
  32. numLeafs = 0
  33. # 得到当前第一个key,也就是根节点
  34. firstStr = list(myTree.keys())[0]
  35. # 得到第一个key对应的内容
  36. secondDict = myTree[firstStr]
  37. # 递归遍历叶子节点
  38. for key in secondDict.keys():
  39. # 如果key对应的是一个字典,就递归调用
  40. if type(secondDict[key]).__name__ == 'dict':
  41. numLeafs += getNumLeafs(secondDict[key])
  42. # 不是的话,说明此时是一个叶子节点
  43. else:
  44. numLeafs += 1
  45. return numLeafs
  46. def getTreeDepth(myTree):
  47. """
  48. 得到数的深度层数
  49. :param myTree:
  50. :return:
  51. """
  52. # 用来保存最大层数
  53. maxDepth = 0
  54. # 得到根节点
  55. firstStr = list(myTree.keys())[0]
  56. # 得到key对应的内容
  57. secondDic = myTree[firstStr]
  58. # 遍历所有子节点
  59. for key in secondDic.keys():
  60. # 如果该节点是字典,就递归调用
  61. if type(secondDic[key]).__name__ == 'dict':
  62. # 子节点的深度加1
  63. thisDepth = 1 + getTreeDepth(secondDic[key])
  64. # 说明此时是叶子节点
  65. else:
  66. thisDepth = 1
  67. # 替换最大层数
  68. if thisDepth > maxDepth:
  69. maxDepth = thisDepth
  70. return maxDepth
  71. def plotMidText(cntrPt, parentPt, txtString):
  72. """
  73. 计算出父节点和子节点的中间位置,填充信息
  74. :param cntrPt: 子节点坐标
  75. :param parentPt: 父节点坐标
  76. :param txtString: 填充的文本信息
  77. :return:
  78. """
  79. # 计算x轴的中间位置
  80. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  81. # 计算y轴的中间位置
  82. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  83. # 进行绘制
  84. createPlot.ax1.text(xMid, yMid, txtString)
  85. def plotTree(myTree, parentPt, nodeTxt):
  86. """
  87. 绘制出树的所有节点,递归绘制
  88. :param myTree: 树
  89. :param parentPt: 父节点的坐标
  90. :param nodeTxt: 节点的文本信息
  91. :return:
  92. """
  93. # 计算叶子节点数
  94. numLeafs = getNumLeafs(myTree=myTree)
  95. # 计算树的深度
  96. depth = getTreeDepth(myTree=myTree)
  97. # 得到根节点的信息内容
  98. firstStr = list(myTree.keys())[0]
  99. # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
  100. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  101. # 绘制该节点与父节点的联系
  102. plotMidText(cntrPt, parentPt, nodeTxt)
  103. # 绘制该节点
  104. plotNode(firstStr, cntrPt, parentPt, decisionNode)
  105. # 得到当前根节点对应的子树
  106. secondDict = myTree[firstStr]
  107. # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
  108. plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  109. # 循环遍历所有的key
  110. for key in secondDict.keys():
  111. # 如果当前的key是字典的话,代表还有子树,则递归遍历
  112. if isinstance(secondDict[key], dict):
  113. plotTree(secondDict[key], cntrPt, str(key))
  114. else:
  115. # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
  116. plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
  117. # 打开注释可以观察叶子节点的坐标变化
  118. # print((plotTree.xOff, plotTree.yOff), secondDict[key])
  119. # 绘制叶子节点
  120. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  121. # 绘制叶子节点和父节点的中间连线内容
  122. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  123. # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
  124. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
  125. def createPlot(inTree):
  126. """
  127. 需要绘制的决策树
  128. :param inTree: 决策树字典
  129. :return:
  130. """
  131. # 创建一个图像
  132. fig = plt.figure(1, facecolor='white')
  133. fig.clf()
  134. axprops = dict(xticks=[], yticks=[])
  135. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  136. # 计算出决策树的总宽度
  137. plotTree.totalW = float(getNumLeafs(inTree))
  138. # 计算出决策树的总深度
  139. plotTree.totalD = float(getTreeDepth(inTree))
  140. # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
  141. plotTree.xOff = -0.5/plotTree.totalW
  142. # 初始的y轴偏移量,每次向下或者向上移动1/D
  143. plotTree.yOff = 1.0
  144. # 调用函数进行绘制节点图像
  145. plotTree(inTree, (0.5, 1.0), '')
  146. # 绘制
  147. plt.show()
  148. if __name__ == '__main__':
  149. createPlot(mytree)

结果输出:

二、用sk-learn库对西瓜数据集,分别进行ID3、C4.5和CART的算法代码实现。

1、ID3算法

通过sk-learn库实现ID3主要是调用sklearn内置的决策树的库和画图工具

其代码实现如下:

  1. # 导入库
  2. import pandas as pd
  3. from sklearn import tree
  4. import graphviz
  5. import numpy as np
  6. #导入数据并读取
  7. df = pd.read_excel('D:\西瓜数据集.xls')
  8. df.head(10)
  9. #将特征值化为数字
  10. df['色泽']=df['色泽'].map({'浅白':1,'青绿':2,'乌黑':3})
  11. df['根蒂']=df['根蒂'].map({'稍蜷':1,'蜷缩':2,'硬挺':3})
  12. df['敲声']=df['敲声'].map({'清脆':1,'浊响':2,'沉闷':3})
  13. df['纹理']=df['纹理'].map({'清晰':1,'稍糊':2,'模糊':3})
  14. df['脐部']=df['脐部'].map({'平坦':1,'稍凹':2,'凹陷':3})
  15. df['触感'] = np.where(df['触感']=="硬滑",1,2)
  16. df['好瓜'] = np.where(df['好瓜']=="是",1,0)
  17. x_train=df[['色泽','根蒂','敲声','纹理','脐部','触感']]
  18. y_train=df['好瓜']
  19. print(df)
  20. id3=tree.DecisionTreeClassifier(criterion='entropy')
  21. id3=id3.fit(x_train,y_train)
  22. print(id3)
  23. id3=tree.DecisionTreeClassifier(criterion='entropy')
  24. id3=id3.fit(x_train,y_train)
  25. labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
  26. dot_data = tree.export_graphviz(id3
  27. ,feature_names=labels
  28. ,class_names=["好瓜","坏瓜"]
  29. ,filled=True
  30. ,rounded=True
  31. )
  32. graph = graphviz.Source(dot_data)
  33. graph

训练并进行可视化,DecisionTreeClassifier的参数为entropy时是ID3算法,默认是CART算法,没有C4.5算法。

输出显示:

 注:在终端使用pip命令安装Graphviz之后绘图报错,主要参考了No module named 'graphviz' 问题的解决_a1208896581的博客-CSDN博客

判定分类结束的依据是,若按某特征分类后出现了最终类(好瓜或坏瓜),则判定分类结束。使用这种方法,在数据比较大,特征比较多的情况下,很容易造成过拟合,于是需进行决策树枝剪,一般枝剪方法是当按某一特征分类后的熵小于设定值时,停止分类。

2、C4.5算法

C4.5算法的基本流程与ID3类似,都是通过构造决策树进行分类,但C4.5算法进行特征选择时不是通过计算信息增益完成的,而是通过信息增益比来进行特征选择。

特征A对训练数据集D的信息增益比G_{R}(D,A)定义为其信息增益G(D,A)与训练数据集D关于特征A的值的熵H_{A}(D)之比,即

G_{R}(D,A)=\frac{G(D,A)}{H_{A}(D)}

其中G(D,A)为本特征的信息增益,为H_{A}(D)本特征的信息熵。

代码实现如下:

  1. ## 实现C4.5算法
  2. def chooseBestFeatureToSplit_4(dataSet, labels):
  3. """
  4. 选择最好的数据集划分特征,根据信息增益值来计算
  5. :param dataSet:
  6. :return:
  7. """
  8. # 得到数据的特征值总数
  9. numFeatures = len(dataSet[0]) - 1
  10. # 计算出基础信息熵
  11. baseEntropy = calcShannonEnt(dataSet)
  12. # 基础信息增益为0.0
  13. bestInfoGain = 0.0
  14. # 最好的特征值
  15. bestFeature = -1
  16. # 对每个特征值进行求信息熵
  17. for i in range(numFeatures):
  18. # 得到数据集中所有的当前特征值列表
  19. featList = [example[i] for example in dataSet]
  20. # 将当前特征唯一化,也就是说当前特征值中共有多少种
  21. uniqueVals = set(featList)
  22. # 新的熵,代表当前特征值的熵
  23. newEntropy = 0.0
  24. # 遍历现在有的特征的可能性
  25. for value in uniqueVals:
  26. # 在全部数据集的当前特征位置上,找到该特征值等于当前值的集合
  27. subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)
  28. # 计算出权重
  29. prob = len(subDataSet) / float(len(dataSet))
  30. # 计算出当前特征值的熵
  31. newEntropy += prob * calcShannonEnt(subDataSet)
  32. # 计算出“信息增益”
  33. infoGain = baseEntropy - newEntropy
  34. infoGain = infoGain/newEntropy
  35. #print('当前特征值为:' + labels[i] + ',对应的信息增益值为:' + str(infoGain)+"i等于"+str(i))
  36. #如果当前的信息增益比原来的大
  37. if infoGain > bestInfoGain:
  38. # 最好的信息增益
  39. bestInfoGain = infoGain
  40. # 新的最好的用来划分的特征值
  41. bestFeature = i
  42. #print('信息增益最大的特征为:' + labels[bestFeature])
  43. return bestFeature
  44. #判断各个样本集的各个属性是否一致
  45. def judgeEqualLabels(dataSet):
  46. """
  47. 判断数据集的各个属性集是否完全一致
  48. :param dataSet:
  49. :return:
  50. """
  51. # 计算出样本集中共有多少个属性,最后一个为类别
  52. feature_leng = len(dataSet[0]) - 1
  53. # 计算出共有多少个数据
  54. data_leng = len(dataSet)
  55. # 标记每个属性中第一个属性值是什么
  56. first_feature = ''
  57. # 各个属性集是否完全一致
  58. is_equal = True
  59. # 遍历全部属性
  60. for i in range(feature_leng):
  61. # 得到第一个样本的第i个属性
  62. first_feature = dataSet[0][i]
  63. # 与样本集中所有的数据进行对比,看看在该属性上是否都一致
  64. for _ in range(1, data_leng):
  65. # 如果发现不相等的,则直接返回False
  66. if first_feature != dataSet[_][i]:
  67. return False
  68. return is_equal
  69. #创建树的函数
  70. def createTree_4(dataSet, labels):
  71. """
  72. 创建决策树
  73. :param dataSet: 数据集
  74. :param labels: 特征标签
  75. :return:
  76. """
  77. # 拿到所有数据集的分类标签
  78. classList = [example[-1] for example in dataSet]
  79. # 统计第一个标签出现的次数,与总标签个数比较,如果相等则说明当前列表中全部都是一种标签,此时停止划分
  80. if classList.count(classList[0]) == len(classList):
  81. return classList[0]
  82. # 计算第一行有多少个数据,如果只有一个的话说明所有的特征属性都遍历完了,剩下的一个就是类别标签,或者所有的样本在全部属性上都一致
  83. if len(dataSet[0]) == 1 or judgeEqualLabels(dataSet):
  84. # 返回剩下标签中出现次数较多的那个
  85. return majorityCnt(classList)
  86. # 选择最好的划分特征,得到该特征的下标
  87. bestFeat = chooseBestFeatureToSplit_4(dataSet=dataSet, labels=labels)
  88. print(bestFeat)
  89. # 得到最好特征的名称
  90. bestFeatLabel = labels[bestFeat]
  91. print(bestFeatLabel)
  92. # 使用一个字典来存储树结构,分叉处为划分的特征名称
  93. myTree = {bestFeatLabel: {}}
  94. # 将本次划分的特征值从列表中删除掉
  95. del(labels[bestFeat])
  96. # 得到当前特征标签的所有可能值
  97. featValues = [example[bestFeat] for example in dataSet]
  98. # 唯一化,去掉重复的特征值
  99. uniqueVals = set(featValues)
  100. # 遍历所有的特征值
  101. for value in uniqueVals:
  102. # 得到剩下的特征标签
  103. subLabels = labels[:]
  104. subTree = createTree(splitDataSet(dataSet=dataSet, axis=bestFeat, value=value), subLabels)
  105. # 递归调用,将数据集中该特征等于当前特征值的所有数据划分到当前节点下,递归调用时需要先将当前的特征去除掉
  106. myTree[bestFeatLabel][value] = subTree
  107. return myTree
  108. #调用函数,看一下字典型的树
  109. mytree_4=createTree_4(data,labels)
  110. print(mytree_4)

绘制决策树

  1. #绘制可视化树
  2. import matplotlib.pylab as plt
  3. import matplotlib
  4. # 能够显示中文
  5. matplotlib.rcParams['font.sans-serif'] = ['SimHei']
  6. matplotlib.rcParams['font.serif'] = ['SimHei']
  7. # 分叉节点,也就是决策节点
  8. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  9. # 叶子节点
  10. leafNode = dict(boxstyle="round4", fc="0.8")
  11. # 箭头样式
  12. arrow_args = dict(arrowstyle="<-")
  13. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  14. """
  15. 绘制一个节点
  16. :param nodeTxt: 描述该节点的文本信息
  17. :param centerPt: 文本的坐标
  18. :param parentPt: 点的坐标,这里也是指父节点的坐标
  19. :param nodeType: 节点类型,分为叶子节点和决策节点
  20. :return:
  21. """
  22. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  23. xytext=centerPt, textcoords='axes fraction',
  24. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
  25. def getNumLeafs(myTree):
  26. """
  27. 获取叶节点的数目
  28. :param myTree:
  29. :return:
  30. """
  31. # 统计叶子节点的总数
  32. numLeafs = 0
  33. # 得到当前第一个key,也就是根节点
  34. firstStr = list(myTree.keys())[0]
  35. # 得到第一个key对应的内容
  36. secondDict = myTree[firstStr]
  37. # 递归遍历叶子节点
  38. for key in secondDict.keys():
  39. # 如果key对应的是一个字典,就递归调用
  40. if type(secondDict[key]).__name__ == 'dict':
  41. numLeafs += getNumLeafs(secondDict[key])
  42. # 不是的话,说明此时是一个叶子节点
  43. else:
  44. numLeafs += 1
  45. return numLeafs
  46. def getTreeDepth(myTree):
  47. """
  48. 得到数的深度层数
  49. :param myTree:
  50. :return:
  51. """
  52. # 用来保存最大层数
  53. maxDepth = 0
  54. # 得到根节点
  55. firstStr = list(myTree.keys())[0]
  56. # 得到key对应的内容
  57. secondDic = myTree[firstStr]
  58. # 遍历所有子节点
  59. for key in secondDic.keys():
  60. # 如果该节点是字典,就递归调用
  61. if type(secondDic[key]).__name__ == 'dict':
  62. # 子节点的深度加1
  63. thisDepth = 1 + getTreeDepth(secondDic[key])
  64. # 说明此时是叶子节点
  65. else:
  66. thisDepth = 1
  67. # 替换最大层数
  68. if thisDepth > maxDepth:
  69. maxDepth = thisDepth
  70. return maxDepth
  71. def plotMidText(cntrPt, parentPt, txtString):
  72. """
  73. 计算出父节点和子节点的中间位置,填充信息
  74. :param cntrPt: 子节点坐标
  75. :param parentPt: 父节点坐标
  76. :param txtString: 填充的文本信息
  77. :return:
  78. """
  79. # 计算x轴的中间位置
  80. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  81. # 计算y轴的中间位置
  82. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  83. # 进行绘制
  84. createPlot.ax1.text(xMid, yMid, txtString)
  85. def plotTree(myTree, parentPt, nodeTxt):
  86. """
  87. 绘制出树的所有节点,递归绘制
  88. :param myTree: 树
  89. :param parentPt: 父节点的坐标
  90. :param nodeTxt: 节点的文本信息
  91. :return:
  92. """
  93. # 计算叶子节点数
  94. numLeafs = getNumLeafs(myTree=myTree)
  95. # 计算树的深度
  96. depth = getTreeDepth(myTree=myTree)
  97. # 得到根节点的信息内容
  98. firstStr = list(myTree.keys())[0]
  99. # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
  100. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  101. # 绘制该节点与父节点的联系
  102. plotMidText(cntrPt, parentPt, nodeTxt)
  103. # 绘制该节点
  104. plotNode(firstStr, cntrPt, parentPt, decisionNode)
  105. # 得到当前根节点对应的子树
  106. secondDict = myTree[firstStr]
  107. # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
  108. plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  109. # 循环遍历所有的key
  110. for key in secondDict.keys():
  111. # 如果当前的key是字典的话,代表还有子树,则递归遍历
  112. if isinstance(secondDict[key], dict):
  113. plotTree(secondDict[key], cntrPt, str(key))
  114. else:
  115. # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
  116. plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
  117. # 打开注释可以观察叶子节点的坐标变化
  118. # print((plotTree.xOff, plotTree.yOff), secondDict[key])
  119. # 绘制叶子节点
  120. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  121. # 绘制叶子节点和父节点的中间连线内容
  122. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  123. # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
  124. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
  125. def createPlot(inTree):
  126. """
  127. 需要绘制的决策树
  128. :param inTree: 决策树字典
  129. :return:
  130. """
  131. # 创建一个图像
  132. fig = plt.figure(1, facecolor='white')
  133. fig.clf()
  134. axprops = dict(xticks=[], yticks=[])
  135. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  136. # 计算出决策树的总宽度
  137. plotTree.totalW = float(getNumLeafs(inTree))
  138. # 计算出决策树的总深度
  139. plotTree.totalD = float(getTreeDepth(inTree))
  140. # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
  141. plotTree.xOff = -0.5/plotTree.totalW
  142. # 初始的y轴偏移量,每次向下或者向上移动1/D
  143. plotTree.yOff = 1.0
  144. # 调用函数进行绘制节点图像
  145. plotTree(inTree, (0.5, 1.0), '')
  146. # 绘制
  147. plt.show()
  148. if __name__ == '__main__':
  149. createPlot(mytree)

结果输出:

3、CART算法

CART算法构造的是二叉决策树,决策树构造出来后同样需要剪枝,才能更好的应用于未知数据的分类。CART算法在构造决策树时通过基尼系数来进行特征选择。

基尼指数(Gini index)

基尼指数是指在样本集中随机抽取两个样本,其类别标记不一致的概率。Gini(D)越小,纯度越高。其计算公式如下:

Gini(D)=\sum_{k=1}^{|y|}\sum_{k'\neq k}^{}p_{k}p_{k'}=1-\sum_{k=1}^{|y|}p_{k}^{2}

在候选特征集合中,选择划分之后使得基尼指数最小的特征作为最优划分属性。

代码实现跟ID3类似,只需要将DecisionTreeClassifier函数的参数criterion的值改为gini:

  1. clf = tree.DecisionTreeClassifier(criterion="gini") #实例化
  2. clf = clf.fit(x_train, y_train)
  3. score = clf.score(x_test, y_test)
  4. print(score)

绘制决策树:

  1. #实现决策树的可视化
  2. labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
  3. gini_data = tree.export_graphviz(gini,feature_names=labels,class_names=["好瓜","坏瓜"],filled=True,rounded=True)
  4. gini_graph = graphviz.Source(gini_data)
  5. gini_graph

结果可视化显示:

三、ID3、C4.5、CART算法的优缺点及比较

1.ID3算法的特点

1)ID3算法使用信息增益来判断特征重要性程度,信息增益越大,重要性程度越大,但是其在计算类别数较多的特征的信息增益时结果往往不准确。

2)ID3算法只能对描述属性为离散型属性的数据集构造决策树 ,无法处理连续型变量。

2.C4.5算法的特点

C4.5主要是ID3算法的改进,采用信息增益率进行特征选择。在处理连续值时,C4.5采用单点离散化的思想,用信息增益率来进行连续值特征的属性值选择。

优点:产生的分类规则易于理解,准确率较高。

缺点:首先,C4.5时间耗费大,在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。其次,C4.5没解决回归问题。

3.CART算法的特点

CART算法效率比较高的另外一个原因是它构建的树都是二叉树,简化了树结构。CART对于分类问题选用基尼指数作为损失函数,对于回归问题使用平方误差作为损失函数,一个结点里的预测值采用的是这个结点里数据结果的平均数。

优点:CART解决了ID3和C4.5都没有解决的回归问题。

注:每次决策树分叉时,所有的特征都是随机排序的,随机种子就是random_state如果你的max_features小于你总特征数n_features,那么每个分叉必须采样,随机性很大。即使你的max_features = n_features,表现相同的分叉还是会选第一个,所以依然有随机性,sklearn的算法大多有random_state,如果需要复盘或是需要模型稳定不变必须设置。

参考文献

决策树挑出好西瓜(基于ID3、CART)_鹤引的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/354629
推荐阅读
相关标签
  

闽ICP备14008679号