当前位置:   article > 正文

机器学习十大算法之决策树---C4.5算法

机器学习十大算法之决策树---C4.5算法

对于决策树来说,主要有两种算法:ID3算法C4.5算法。C4.5算法是对ID3的改进。

Contents

     1. C4.5的基本认识

     2. 信息增益与信息增益率

     3. C4.5算法的Python实现

一、C4.5的基本认识 

C4.5主要是在ID3的基础上改进,ID3选择(属性)树节点是选择信息增益值最大的属性作为节点。而C4.5引入了新概念“信息增益率”,C4.5是选择信息增益率最大的属性作为树节点。 
 

二、信息增益&信息增益率 

信息增益

以上公式是求信息增益率(ID3的知识点) 
信息增益率 

信息增益率

信息增益率是在求出信息增益值在除以这里写图片描述。 

例如下面公式为求属性为“outlook”的这里写图片描述值: 

这里写图片描述

三、C4.5的完整代码

  1. from numpy import *
  2. from scipy import *
  3. from math import log
  4. import operator
  5. #计算给定数据的香浓熵:
  6. def calcShannonEnt(dataSet):
  7. numEntries = len(dataSet)
  8. labelCounts = {} #类别字典(类别的名称为键,该类别的个数为值)
  9. for featVec in dataSet:
  10. currentLabel = featVec[-1]
  11. if currentLabel not in labelCounts.keys(): #还没添加到字典里的类型
  12. labelCounts[currentLabel] = 0;
  13. labelCounts[currentLabel] += 1;
  14. shannonEnt = 0.0
  15. for key in labelCounts: #求出每种类型的熵
  16. prob = float(labelCounts[key])/numEntries #每种类型个数占所有的比值
  17. shannonEnt -= prob * log(prob, 2)
  18. return shannonEnt; #返回熵
  19. #按照给定的特征划分数据集
  20. def splitDataSet(dataSet, axis, value):
  21. retDataSet = []
  22. for featVec in dataSet: #按dataSet矩阵中的第axis列的值等于value的分数据集
  23. if featVec[axis] == value: #值等于value的,每一行为新的列表(去除第axis个数据)
  24. reducedFeatVec = featVec[:axis]
  25. reducedFeatVec.extend(featVec[axis+1:])
  26. retDataSet.append(reducedFeatVec)
  27. return retDataSet #返回分类后的新矩阵
  28. #选择最好的数据集划分方式
  29. def chooseBestFeatureToSplit(dataSet):
  30. numFeatures = len(dataSet[0])-1 #求属性的个数
  31. baseEntropy = calcShannonEnt(dataSet)
  32. bestInfoGain = 0.0; bestFeature = -1
  33. for i in range(numFeatures): #求所有属性的信息增益
  34. featList = [example[i] for example in dataSet]
  35. uniqueVals = set(featList) #第i列属性的取值(不同值)数集合
  36. newEntropy = 0.0
  37. splitInfo = 0.0;
  38. for value in uniqueVals: #求第i列属性每个不同值的熵*他们的概率
  39. subDataSet = splitDataSet(dataSet, i , value)
  40. prob = len(subDataSet)/float(len(dataSet)) #求出该值在i列属性中的概率
  41. newEntropy += prob * calcShannonEnt(subDataSet) #求i列属性各值对于的熵求和
  42. splitInfo -= prob * log(prob, 2);
  43. infoGain = (baseEntropy - newEntropy) / splitInfo; #求出第i列属性的信息增益率
  44. print infoGain;
  45. if(infoGain > bestInfoGain): #保存信息增益率最大的信息增益率值以及所在的下表(列值i)
  46. bestInfoGain = infoGain
  47. bestFeature = i
  48. return bestFeature
  49. #找出出现次数最多的分类名称
  50. def majorityCnt(classList):
  51. classCount = {}
  52. for vote in classList:
  53. if vote not in classCount.keys(): classCount[vote] = 0
  54. classCount[vote] += 1
  55. sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True)
  56. return sortedClassCount[0][0]
  57. #创建树
  58. def createTree(dataSet, labels):
  59. classList = [example[-1] for example in dataSet]; #创建需要创建树的训练数据的结果列表(例如最外层的列表是[N, N, Y, Y, Y, N, Y])
  60. if classList.count(classList[0]) == len(classList): #如果所有的训练数据都是属于一个类别,则返回该类别
  61. return classList[0];
  62. if (len(dataSet[0]) == 1): #训练数据只给出类别数据(没给任何属性值数据),返回出现次数最多的分类名称
  63. return majorityCnt(classList);
  64. bestFeat = chooseBestFeatureToSplit(dataSet); #选择信息增益最大的属性进行分(返回值是属性类型列表的下标)
  65. bestFeatLabel = labels[bestFeat] #根据下表找属性名称当树的根节点
  66. myTree = {bestFeatLabel:{}} #以bestFeatLabel为根节点建一个空树
  67. del(labels[bestFeat]) #从属性列表中删掉已经被选出来当根节点的属性
  68. featValues = [example[bestFeat] for example in dataSet] #找出该属性所有训练数据的值(创建列表)
  69. uniqueVals = set(featValues) #求出该属性的所有值得集合(集合的元素不能重复)
  70. for value in uniqueVals: #根据该属性的值求树的各个分支
  71. subLabels = labels[:]
  72. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) #根据各个分支递归创建树
  73. return myTree #生成的树
  74. #实用决策树进行分类
  75. def classify(inputTree, featLabels, testVec):
  76. firstStr = inputTree.keys()[0]
  77. secondDict = inputTree[firstStr]
  78. featIndex = featLabels.index(firstStr)
  79. for key in secondDict.keys():
  80. if testVec[featIndex] == key:
  81. if type(secondDict[key]).__name__ == 'dict':
  82. classLabel = classify(secondDict[key], featLabels, testVec)
  83. else: classLabel = secondDict[key]
  84. return classLabel
  85. #读取数据文档中的训练数据(生成二维列表)
  86. def createTrainData():
  87. lines_set = open('../data/ID3/Dataset.txt').readlines()
  88. labelLine = lines_set[2];
  89. labels = labelLine.strip().split()
  90. lines_set = lines_set[4:11]
  91. dataSet = [];
  92. for line in lines_set:
  93. data = line.split();
  94. dataSet.append(data);
  95. return dataSet, labels
  96. #读取数据文档中的测试数据(生成二维列表)
  97. def createTestData():
  98. lines_set = open('../data/ID3/Dataset.txt').readlines()
  99. lines_set = lines_set[15:22]
  100. dataSet = [];
  101. for line in lines_set:
  102. data = line.strip().split();
  103. dataSet.append(data);
  104. return dataSet
  105. myDat, labels = createTrainData()
  106. myTree = createTree(myDat,labels)
  107. print myTree
  108. bootList = ['outlook','temperature', 'humidity', 'windy'];
  109. testList = createTestData();
  110. for testData in testList:
  111. dic = classify(myTree, bootList, testData)
  112. print dic

运行结果如下:

Dataset.txt 训练集和测试集:

  1. 训练集:
  2. outlook temperature humidity windy
  3. ---------------------------------------------------------
  4. sunny hot high false N
  5. sunny hot high true N
  6. overcast hot high false Y
  7. rain mild high false Y
  8. rain cool normal false Y
  9. rain cool normal true N
  10. overcast cool normal true Y
  11. 测试集
  12. outlook temperature humidity windy
  13. ---------------------------------------------------------
  14. sunny mild high false
  15. sunny cool normal false
  16. rain mild normal false
  17. sunny mild normal true
  18. overcast mild high true
  19. overcast hot normal false
  20. rain mild high true

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/611988
推荐阅读
相关标签
  

闽ICP备14008679号