
Decision Tree Code Implementation


Contents

I. Preface

II. Code Implementation Logic

        Tree construction module

        Best-feature selection module

        Entropy calculation module

        Dataset splitting module

        Majority-class module

III. Visualization Extension

IV. Results + Complete Code


I. Preface

        This article assumes some prior familiarity with decision trees. For background, see Decision Tree Principles (an overview of the decision tree algorithm: entropy, information gain, information gain ratio, the Gini index, pruning, and regression/classification tasks).

II. Code Implementation Logic

        Tree construction module

        (As anyone who has studied data structures knows, recursion is the natural way to build a tree.)

        1. Decide whether the node needs further splitting: if all samples at the current node share the same label, stop and return that label; if all features have been used up and the samples are still not fully separated, decide the class by majority vote.

        2. Select the best feature and make it the root of the current subtree.

        3. Create one branch for each distinct value of that feature.

        4. Remove the chosen feature's column from the dataset (i.e., update the dataset).

        5. Recurse over the feature values to build each subtree.

        6. Return the tree.

        Note: the tree is stored as nested dictionaries, and featLabels records the chosen split features in the order they are used (see the worked example after the code below).

        

def createTree(dataset, labels, featLabels):
    '''
    :param dataset: the data set (last column is the class label)
    :param labels: the list of feature names
    :param featLabels: records the chosen split features, in order
    :return: the decision tree as a nested dict (or a class label at a leaf)
    '''
    classList = [example[-1] for example in dataset]  # labels of all samples at this node
    if classList.count(classList[0]) == len(classList):  # all labels identical: return a leaf
        return classList[0]
    if len(dataset[0]) == 1:  # only the label column is left
        return majorityCnt(classList)  # fall back to the majority class
    bestFeature = chooseBestFeatureToSplit(dataset)  # index of the best feature
    bestFeatureLabel = labels[bestFeature]
    featLabels.append(bestFeatureLabel)
    myTree = {bestFeatureLabel: {}}  # store the tree as nested dicts
    del labels[bestFeature]  # this feature is used up
    featValue = [example[bestFeature] for example in dataset]  # column of the chosen feature
    uniqueVals = set(featValue)  # one branch per distinct value
    for value in uniqueVals:  # recurse on each branch
        sublabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataset, bestFeature, value), sublabels, featLabels)
    return myTree
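        For a concrete picture of the nested-dict representation, here is what createTree returns on the loan dataset from Section IV (a quick check of mine, assuming the helper modules below; branch order may vary with set iteration, and the result is easy to verify by hand since F3-HOME has the largest information gain):

dataSet, labels = createDataSet()
featLabels = []
myTree = createTree(dataSet, labels, featLabels)
print(myTree)      # {'F3-HOME': {0: {'F2-WORK': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
print(featLabels)  # ['F3-HOME', 'F2-WORK']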

        Best-feature selection module

        Every feature is tried in turn, and the one with the largest information gain is chosen.

        

def chooseBestFeatureToSplit(dataset):  # the core: entropy-based feature selection
    numFeatures = len(dataset[0]) - 1  # number of features (row length minus the label column)
    baseEntropy = calcShannonEnt(dataset)  # entropy before any split
    bestInfoGain = 0  # best information gain so far
    bestFeature = -1  # index of the best feature so far
    for i in range(numFeatures):
        featList = [example[i] for example in dataset]  # the i-th feature column
        uniqueVals = set(featList)  # distinct values of this feature
        newEntropy = 0
        for val in uniqueVals:
            subDataSet = splitDataSet(dataset, i, val)
            prob = len(subDataSet) / float(len(dataset))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after splitting on feature i
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
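        As a sanity check, the information gain can be worked out by hand for one feature. The sketch below is my own, using the createDataSet dataset from Section IV:

dataSet, _ = createDataSet()  # 15 samples, 9 'yes' / 6 'no'
# baseEntropy = -(9/15)*log2(9/15) - (6/15)*log2(6/15) ≈ 0.971
# F3-HOME = 1: 6 samples, all 'yes'        -> entropy 0
# F3-HOME = 0: 9 samples, 3 'yes', 6 'no'  -> entropy ≈ 0.918
# newEntropy = (6/15)*0 + (9/15)*0.918 ≈ 0.551
# infoGain = 0.971 - 0.551 ≈ 0.420, the largest of the four features
print(chooseBestFeatureToSplit(dataSet))  # 2, i.e. F3-HOME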

        Entropy calculation module

        Count each label's frequency to get the class probabilities, then accumulate -p * log2(p).

def calcShannonEnt(dataset):  # Shannon entropy of the class labels
    numexamples = len(dataset)
    labelCount = {}
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCount.keys():
            labelCount[currentlabel] = 0
        labelCount[currentlabel] += 1
    shannonEnt = 0
    for key in labelCount:
        prop = float(labelCount[key]) / numexamples  # class probability
        shannonEnt -= prop * log(prop, 2)  # accumulate -p * log2(p)
    return shannonEnt
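        A quick hand check of the formula, using the label counts of the Section IV dataset (9 'yes', 6 'no'):

from math import log
p_yes, p_no = 9 / 15.0, 6 / 15.0
H = -p_yes * log(p_yes, 2) - p_no * log(p_no, 2)
print(round(H, 3))  # 0.971, which matches calcShannonEnt(dataSet)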

        Dataset splitting module

        After each split the data must be carved up, which includes removing the column of the feature that was just used.

def splitDataSet(dataset, axis, val):  # keep rows whose feature `axis` equals `val`, dropping that column
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == val:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  # slice and concatenate to cut out column `axis`
            retDataSet.append(reducedFeatVec)
    return retDataSet
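        A tiny usage example (toy rows of my own, not the post's dataset): keeping the rows with value 1 in column 0 and dropping that column:

rows = [[0, 1, 'no'], [1, 1, 'yes'], [1, 0, 'yes']]
print(splitDataSet(rows, 0, 1))  # [[1, 'yes'], [0, 'yes']]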

        Majority-class module

        When all the features are used up and the samples still cannot be fully separated, fall back to majority vote.

def majorityCnt(classList):  # the majority class among the current samples
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by count, descending
    return sortedClassCount[0][0]
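        For example, majorityCnt(['yes', 'no', 'yes']) returns 'yes'. As an aside, collections.Counter gives the same behavior in one call (an alternative, not what this post uses):

from collections import Counter
print(Counter(['yes', 'no', 'yes']).most_common(1)[0][0])  # 'yes'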

III. Visualization Extension

        This part is not the focus; what matters is grasping the idea of building the tree recursively.

        

def getNumLeafs(myTree):  # count the leaf nodes
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # depth of the tree
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # draw a node with an arrow to its parent
    arrow_args = dict(arrowstyle="<-")
    font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args, fontproperties=font)

def plotMidText(cntrPt, parentPt, txtString):  # label the edge between parent and child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # recursively draw the tree
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # create the figure
    fig.clf()  # clear it
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # hide the x and y axes
    plotTree.totalW = float(getNumLeafs(inTree))  # total number of leaf nodes
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0  # x/y offsets
    plotTree(inTree, (0.5, 1.0), '')  # draw the tree
    plt.show()

IV. Results + Complete Code

# -*- coding: UTF-8 -*-
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator

def createDataSet():
    # toy loan-approval dataset; the last column is the class label
    dataSet = [[0, 0, 0, 0, 'no'],
               [0, 0, 0, 1, 'no'],
               [0, 1, 0, 1, 'yes'],
               [0, 1, 1, 0, 'yes'],
               [0, 0, 0, 0, 'no'],
               [1, 0, 0, 0, 'no'],
               [1, 0, 0, 1, 'no'],
               [1, 1, 1, 1, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [1, 0, 1, 2, 'yes'],
               [2, 0, 1, 2, 'yes'],
               [2, 0, 1, 1, 'yes'],
               [2, 1, 0, 1, 'yes'],
               [2, 1, 0, 2, 'yes'],
               [2, 0, 0, 0, 'no']]
    labels = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']
    return dataSet, labels

def createTree(dataset, labels, featLabels):
    '''
    :param dataset: the data set (last column is the class label)
    :param labels: the list of feature names
    :param featLabels: records the chosen split features, in order
    :return: the decision tree as a nested dict (or a class label at a leaf)
    '''
    classList = [example[-1] for example in dataset]  # labels of all samples at this node
    if classList.count(classList[0]) == len(classList):  # all labels identical: return a leaf
        return classList[0]
    if len(dataset[0]) == 1:  # only the label column is left
        return majorityCnt(classList)  # fall back to the majority class
    bestFeature = chooseBestFeatureToSplit(dataset)  # index of the best feature
    bestFeatureLabel = labels[bestFeature]
    featLabels.append(bestFeatureLabel)
    myTree = {bestFeatureLabel: {}}  # store the tree as nested dicts
    del labels[bestFeature]  # this feature is used up
    featValue = [example[bestFeature] for example in dataset]  # column of the chosen feature
    uniqueVals = set(featValue)  # one branch per distinct value
    for value in uniqueVals:  # recurse on each branch
        sublabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataset, bestFeature, value), sublabels, featLabels)
    return myTree

def majorityCnt(classList):  # the majority class among the current samples
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by count, descending
    return sortedClassCount[0][0]

def chooseBestFeatureToSplit(dataset):  # the core: entropy-based feature selection
    numFeatures = len(dataset[0]) - 1  # number of features (row length minus the label column)
    baseEntropy = calcShannonEnt(dataset)  # entropy before any split
    bestInfoGain = 0  # best information gain so far
    bestFeature = -1  # index of the best feature so far
    for i in range(numFeatures):
        featList = [example[i] for example in dataset]  # the i-th feature column
        uniqueVals = set(featList)  # distinct values of this feature
        newEntropy = 0
        for val in uniqueVals:
            subDataSet = splitDataSet(dataset, i, val)
            prob = len(subDataSet) / float(len(dataset))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after splitting on feature i
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def splitDataSet(dataset, axis, val):  # keep rows whose feature `axis` equals `val`, dropping that column
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == val:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])  # slice and concatenate to cut out column `axis`
            retDataSet.append(reducedFeatVec)
    return retDataSet

def calcShannonEnt(dataset):  # Shannon entropy of the class labels
    numexamples = len(dataset)
    labelCount = {}
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCount.keys():
            labelCount[currentlabel] = 0
        labelCount[currentlabel] += 1
    shannonEnt = 0
    for key in labelCount:
        prop = float(labelCount[key]) / numexamples  # class probability
        shannonEnt -= prop * log(prop, 2)  # accumulate -p * log2(p)
    return shannonEnt

def getNumLeafs(myTree):  # count the leaf nodes
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # depth of the tree
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # draw a node with an arrow to its parent
    arrow_args = dict(arrowstyle="<-")
    font = FontProperties(fname=r"c:\windows\fonts\simsunb.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args, fontproperties=font)

def plotMidText(cntrPt, parentPt, txtString):  # label the edge between parent and child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # recursively draw the tree
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # create the figure
    fig.clf()  # clear it
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # hide the x and y axes
    plotTree.totalW = float(getNumLeafs(inTree))  # total number of leaf nodes
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0  # x/y offsets
    plotTree(inTree, (0.5, 1.0), '')  # draw the tree
    plt.show()

if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(featLabels)
    createPlot(myTree)

 

        Only two features are needed to build the tree.
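        The expected console output of print(featLabels), consistent with the hand-computed information gains above, is:

['F3-HOME', 'F2-WORK']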

        Visualization result: the decision tree diagram drawn by createPlot.
