当前位置:   article > 正文

决策树模型 - (ID3算法、C4.5算法) - Python代码实现_用id3算法建立嫁不嫁的决策树模型

用id3算法建立嫁不嫁的决策树模型

目录

算法简介

信息熵(Entropy)

信息增益(Information gain) - ID3算法

信息增益率(gain ratio) - C4.5算法

源数据

代码实现 - ID3算法

代码实现 - C4.5算法

画决策树代码-treePlotter


算法简介

决策数(Decision Tree)在机器学习中也是比较常见的一种算法,属于监督学习中的一种。其中ID3算法是以信息熵和信息增益作为衡量标准的分类算法。

信息熵(Entropy)

熵的概念主要是指信息的混乱程度,变量的不确定性越大,熵的值也就越大,熵的公式可以表示为:

信息增益(Information gain) - ID3算法

信息增益指的是根据特征划分数据前后熵的变化,可以用下面的公式表示: 

根据不同特征分类后熵的变化不同,信息增益也不同,信息增益越大,区分样本的能力越强,越具有代表性。 这是一种自顶向下的贪心策略,即在ID3中根据“最大信息增益”原则选择特征。

ID3采用信息增益来选择特征,存在一个缺点,它一般会优先选择有较多属性值的特征,因为属性值多的特征会有相对较大的信息增益。(这是因为:信息增益反映的给定一个条件以后不确定性减少的程度,必然是分得越细的数据集确定性更高,也就是条件熵越小,信息增益越大)。

信息增益率(gain ratio) - C4.5算法

为了避免ID3的不足,C4.5中是用信息增益率(gain ratio)来作为选择分支的准则。对于有较多属性值的特征,信息增益率的分母Split information(S,A),我们称之为分裂信息,会稀释掉它对特征选择的影响。分裂信息(公式1)和信息增益率(公式2)的计算如下所示。

                         

源数据

收入身高长相体型是否见面
一般
一般
一般一般一般一般
一般
一般

这是一位单身女性根据对方的一些基本条件,判断是否去约会的数据,此处展示前五行。我们要通过这位女士历史的数据建立决策树模型,使得尽量给这位女性推送她比较愿意约会的异性信息。

代码实现 - ID3算法

  1. from math import log
  2. import operator
  3. import numpy as np
  4. import pandas as pd
  5. from pandas import DataFrame,Series
  6. # 计算数据的熵(entropy)-原始熵
  7. def dataentropy(data, feat):
  8. lendata=len(data) # 数据条数
  9. labelCounts={} # 数据中不同类别的条数
  10. for featVec in data:
  11. category=featVec[-1] # 每行数据的最后一个字(叶子节点)
  12. if category not in labelCounts.keys():
  13. labelCounts[category]=0
  14. labelCounts[category]+=1 # 统计有多少个类以及每个类的数量
  15. entropy=0
  16. for key in labelCounts:
  17. prob=float(labelCounts[key])/lendata # 计算单个类的熵值
  18. entropy-=prob*log(prob,2) # 累加每个类的熵值
  19. return entropy
  20. # 处理后导入数据数据
  21. def Importdata(datafile):
  22. dataa = pd.read_excel(datafile)#datafile是excel文件,所以用read_excel,如果是csv文件则用read_csv
  23. #将文本中不可直接使用的文本变量替换成数字
  24. productDict={'高':1,'一般':2,'低':3, '帅':1, '丑':3, '胖':3, '瘦':1, '是':1, '否':0}
  25. dataa['income'] = dataa['收入'].map(productDict)#将每一列中的数据按照字典规定的转化成数字
  26. dataa['hight'] = dataa['身高'].map(productDict)
  27. dataa['look'] = dataa['长相'].map(productDict)
  28. dataa['shape'] = dataa['体型'].map(productDict)
  29. dataa['is_meet'] = dataa['是否见面'].map(productDict)
  30. data = dataa.iloc[:,5:].values.tolist()#取量化后的几列,去掉文本列
  31. b = dataa.iloc[0:0,5:-1]
  32. labels = b.columns.values.tolist()#将标题中的值存入列表中
  33. return data,labels
  34. # 按某个特征value分类后的数据
  35. def splitData(data,i,value):
  36. splitData=[]
  37. for featVec in data:
  38. if featVec[i]==value:
  39. rfv =featVec[:i]
  40. rfv.extend(featVec[i+1:])
  41. splitData.append(rfv)
  42. return splitData
  43. # 选择最优的分类特征
  44. def BestSplit(data):
  45. numFea = len(data[0])-1#计算一共有多少个特征,因为最后一列一般是分类结果,所以需要-1
  46. baseEnt = dataentropy(data,-1) # 定义初始的熵,用于对比分类后信息增益的变化
  47. bestInfo = 0
  48. bestFeat = -1
  49. for i in range(numFea):
  50. featList = [rowdata[i] for rowdata in data]
  51. uniqueVals = set(featList)
  52. newEnt = 0
  53. for value in uniqueVals:
  54. subData = splitData(data,i,value)#获取按照特征value分类后的数据
  55. prob =len(subData)/float(len(data))
  56. newEnt +=prob*dataentropy(subData,i) # 按特征分类后计算得到的熵
  57. info = baseEnt - newEnt # 原始熵与按特征分类后的熵的差值,即信息增益
  58. if (info>bestInfo): # 若按某特征划分后,若infoGain大于bestInf,则infoGain对应的特征分类区分样本的能力更强,更具有代表性。
  59. bestInfo=info #将infoGain赋值给bestInf,如果出现比infoGain更大的信息增益,说明还有更好地特征分类
  60. bestFeat = i #将最大的信息增益对应的特征下标赋给bestFea,返回最佳分类特征
  61. return bestFeat
  62. #按分类后类别数量排序,取数量较大的
  63. def majorityCnt(classList):
  64. c_count={}
  65. for i in classList:
  66. if i not in c_count.keys():
  67. c_count[i]=0
  68. c_count[i]+=1
  69. ClassCount = sorted(c_count.items(),key=operator.itemgetter(1),reverse=True)#按照统计量降序排序
  70. return ClassCount[0][0]#reverse=True表示降序,因此取[0][0],即最大值
  71. #建树
  72. def createTree(data,labels):
  73. classList = [rowdata[-1] for rowdata in data] # 取每一行的最后一列,分类结果(1/0)
  74. if classList.count(classList[0])==len(classList):
  75. return classList[0]
  76. if len(data[0])==1:
  77. return majorityCnt(classList)
  78. bestFeat = BestSplit(data) #根据信息增益选择最优特征
  79. bestLab = labels[bestFeat]
  80. myTree = {bestLab:{}} #分类结果以字典形式保存
  81. del(labels[bestFeat])
  82. featValues = [rowdata[bestFeat] for rowdata in data]
  83. uniqueVals = set(featValues)
  84. for value in uniqueVals:
  85. subLabels = labels[:]
  86. myTree[bestLab][value] = createTree(splitData(data,bestFeat,value),subLabels)
  87. return myTree
  88. if __name__=='__main__':
  89. datafile = u'E:\\pythondata\\tree.xlsx'#文件所在位置,u为防止路径中有中文名称
  90. data, labels=Importdata(datafile) # 导入数据
  91. print(createTree(data, labels)) # 输出决策树模型结果

运行结果:

{'hight': {1: {'look': {1: {'income': {1: {'shape': {1: 1, 2: 1}}, 2: 1, 3: {'shape': {1: 1, 2: 0}}}}, 2: 1, 3: {'income': {1: 1, 2: 0}}}}, 2: {'income': {1: 1, 2: {'look': {1: 1, 2: 0}}, 3: 0}}, 3: {'look': {1: {'shape': {3: 0, 1: 1}}, 2: 0, 3: 0}}}}

对应的决策树:

代码实现 - C4.5算法

C4.5算法和ID3算法逻辑很相似,只是ID3算法是用信息增益来选择特征,而C4.5算法是用的信息增益率,因此对代码的影响也只有BestSplit(data)函数的定义部分,只需要加一个信息增益率的计算即可,BestSplit(data)函数定义代码更改后如下:

  1. # 选择最优的分类特征
  2. def BestSplit(data):
  3. numFea = len(data[0])-1#计算一共有多少个特征,因为最后一列一般是分类结果,所以需要-1
  4. baseEnt = dataentropy(data,-1) # 定义初始的熵,用于对比分类后信息增益的变化
  5. bestGainRate = 0
  6. bestFeat = -1
  7. for i in range(numFea):
  8. featList = [rowdata[i] for rowdata in data]
  9. uniqueVals = set(featList)
  10. newEnt = 0
  11. for value in uniqueVals:
  12. subData = splitData(data,i,value)#获取按照特征value分类后的数据
  13. prob =len(subData)/float(len(data))
  14. newEnt +=prob*dataentropy(subData,i) # 按特征分类后计算得到的熵
  15. info = baseEnt - newEnt # 原始熵与按特征分类后的熵的差值,即信息增益
  16. splitonfo = dataentropy(subData,i) #分裂信息
  17. if splitonfo == 0:#若特征值相同(eg:长相这一特征的值都是帅),即splitonfo和info均为0,则跳过该特征
  18. continue
  19. GainRate = info/splitonfo #计算信息增益率
  20. if (GainRate>bestGainRate): # 若按某特征划分后,若infoGain大于bestInf,则infoGain对应的特征分类区分样本的能力更强,更具有代表性。
  21. bestGainRate=GainRate #将infoGain赋值给bestInf,如果出现比infoGain更大的信息增益,说明还有更好地特征分类
  22. bestFeat = i #将最大的信息增益对应的特征下标赋给bestFea,返回最佳分类特征
  23. return bestFeat

运行结果:

{'hight': {1: {'look': {1: {'income': {1: {'shape': {0: 0, 1: 1}}, 2: 1, 3: {'shape': {0: 0, 1: 1}}}}, 2: 1, 3: {'shape': {0: 0, 1: 1}}}}, 2: {'shape': {0: 0, 1: 1}}, 3: {'shape': {1: 0, 3: {'look': {0: 0, 1: 1}}}}}}

画决策树代码-treePlotter

决策树可以代码实现的,不需要按照运行结果一点一点手动画图。

  1. import treePlotter
  2. treePlotter.createPlot(myTree)

其中treePlotter模块是如下一段代码,可以保存为.py文件,放在Python/Lib/site-package目录下,然后用的时候import 【文件名】就可以了。

treePlotter模块代码:

  1. #绘制决策树
  2. import matplotlib.pyplot as plt
  3. # 定义文本框和箭头格式,boxstyle用于指定边框类型,color表示填充色
  4. decisionNode = dict(boxstyle="round4", color='#ccccff') #定义判断结点为圆角长方形,填充浅蓝色
  5. leafNode = dict(boxstyle="circle", color='#66ff99') #定义叶结点为圆形,填充绿色
  6. arrow_args = dict(arrowstyle="<-", color='ffcc00') #定义箭头及颜色
  7. #绘制带箭头的注释
  8. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  9. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  10. xytext=centerPt, textcoords='axes fraction',
  11. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
  12. #计算叶结点数
  13. def getNumLeafs(myTree):
  14. numLeafs = 0
  15. firstStr = myTree.keys()[0]
  16. secondDict = myTree[firstStr]
  17. for key in secondDict.keys():
  18. if type(secondDict[key]).__name__ == 'dict':
  19. numLeafs += getNumLeafs(secondDict[key])
  20. else:
  21. numLeafs += 1
  22. return numLeafs
  23. #计算树的层数
  24. def getTreeDepth(myTree):
  25. maxDepth = 0
  26. firstStr = myTree.keys()[0]
  27. secondDict = myTree[firstStr]
  28. for key in secondDict.keys():
  29. if type(secondDict[key]).__name__ == 'dict':
  30. thisDepth = 1 + getTreeDepth(secondDict[key])
  31. else:
  32. thisDepth = 1
  33. if thisDepth > maxDepth:
  34. maxDepth = thisDepth
  35. return maxDepth
  36. #在父子结点间填充文本信息
  37. def plotMidText(cntrPt, parentPt, txtString):
  38. xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
  39. yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
  40. createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
  41. def plotTree(myTree, parentPt, nodeTxt):
  42. numLeafs = getNumLeafs(myTree)
  43. depth = getTreeDepth(myTree)
  44. firstStr = myTree.keys()[0]
  45. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
  46. plotMidText(cntrPt, parentPt, nodeTxt) #在父子结点间填充文本信息
  47. plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注释
  48. secondDict = myTree[firstStr]
  49. plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
  50. for key in secondDict.keys():
  51. if type(secondDict[key]).__name__ == 'dict':
  52. plotTree(secondDict[key], cntrPt, str(key))
  53. else:
  54. plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
  55. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  56. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  57. plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
  58. def createPlot(inTree):
  59. fig = plt.figure(1, facecolor='white')
  60. fig.clf()
  61. axprops = dict(xticks=[], yticks=[])
  62. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  63. plotTree.totalW = float(getNumLeafs(inTree))
  64. plotTree.totalD = float(getTreeDepth(inTree))
  65. plotTree.xOff = -0.5 / plotTree.totalW;
  66. plotTree.yOff = 1.0;
  67. plotTree(inTree, (0.5, 1.0), '')
  68. plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/398777
推荐阅读
相关标签
  

闽ICP备14008679号