当前位置:   article > 正文

手写决策树ID3算法(python)_id3决策树算法例题手算

id3决策树算法例题手算

决策数(Decision Tree)在机器学习中也是比较常见的一种算法,属于监督学习中的一种。看字面意思应该也比较容易理解,相比其他算法比如支持向量机(SVM)或神经网络,似乎决策树感觉“亲切”许多。

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失值不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配的问题。
使用数据类型:数值型和标称型。
简单介绍完毕,让我们来通过一个例子让决策树“原形毕露”。

一天,老师问了个问题,只根据头发和声音怎么判断一位同学的性别。
为了解决这个问题,同学们马上简单的统计了7位同学的相关特征,数据如下:

机智的同学A想了想,先根据头发判断,若判断不出,再根据声音判断,于是画了一幅图,如下:

于是,一个简单、直观的决策树就这么出来了。头发长、声音粗就是男生;头发长、声音细就是女生;头发短、声音粗是男生;头发短、声音细是女生。
原来机器学习中决策树就这玩意,这也太简单了吧。。。
这时又蹦出个同学B,想先根据声音判断,然后再根据头发来判断,如是大手一挥也画了个决策树:

同学B的决策树:首先判断声音,声音细,就是女生;声音粗、头发长是男生;声音粗、头发长是女生。

那么问题来了:同学A和同学B谁的决策树好些?计算机做决策树的时候,面对多个特征,该如何选哪个特征为最佳的划分特征?

划分数据集的大原则是:将无序的数据变得更加有序。
我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。于是我们这么想,如果我们能测量数据的复杂度,对比按不同特征分类后的数据复杂度,若按某一特征分类后复杂度减少的更多,那么这个特征即为最佳分类特征。
Claude Shannon 定义了熵(entropy)和信息增益(information gain)。
用熵来表示信息的复杂度,熵越大,则信息越复杂。公式如下:

信息增益(information gain),表示两个信息熵的差值。
首先计算未分类前的熵,总共有8位同学,男生3位,女生5位。
熵(总)=-3/8log2(3/8)-5/8log2(5/8)=0.9544
接着分别计算同学A和同学B分类后信息熵。
同学A首先按头发分类,分类后的结果为:长头发中有1男3女。短头发中有2男2女。
熵(同学A长发)=-1/4log2(1/4)-3/4log2(3/4)=0.8113
熵(同学A短发)=-2/4log2(2/4)-2/4log2(2/4)=1
熵(同学A)=4/80.8113+4/81=0.9057
信息增益(同学A)=熵(总)-熵(同学A)=0.9544-0.9057=0.0487
同理,按同学B的方法,首先按声音特征来分,分类后的结果为:声音粗中有3男3女。声音细中有0男2女。
熵(同学B声音粗)=-3/6log2(3/6)-3/6log2(3/6)=1
熵(同学B声音粗)=-2/2log2(2/2)=0
熵(同学B)=6/81+2/8*0=0.75
信息增益(同学B)=熵(总)-熵(同学B)=0.9544-0.75=0.2087

按同学B的方法,先按声音特征分类,信息增益更大,区分样本的能力更强,更具有代表性。
以上就是决策树ID3算法的核心思想。
接下来用python代码来实现ID3算法:
 

  1. #决策树ID3算法
  2. from math import log
  3. import operator
  4. def calcShannonEnt(dataSet): # 计算数据的熵(entropy)
  5. numEntries=len(dataSet) # 数据条数
  6. labelCounts={}
  7. for featVec in dataSet:
  8. currentLabel=featVec[-1] # 每行数据的最后一个字(类别)
  9. if currentLabel not in labelCounts.keys():
  10. labelCounts[currentLabel]=0
  11. labelCounts[currentLabel]+=1 # 统计有多少个类以及每个类的数量
  12. shannonEnt=0
  13. for key in labelCounts:
  14. prob=float(labelCounts[key])/numEntries # 计算单个类的熵值
  15. shannonEnt-=prob*log(prob,2) # 累加每个类的熵值
  16. return shannonEnt
  17. def createDataSet1(): # 创造示例数据
  18. dataSet = [['长', '粗', '男'],
  19. ['短', '粗', '男'],
  20. ['短', '粗', '男'],
  21. ['长', '细', '女'],
  22. ['短', '细', '女'],
  23. ['短', '粗', '女'],
  24. ['长', '粗', '女'],
  25. ['长', '粗', '女']]
  26. labels = ['头发','声音'] #两个特征
  27. return dataSet,labels
  28. def splitDataSet(dataSet,axis,value): # 按某个特征分类后的数据
  29. retDataSet=[]
  30. for featVec in dataSet:
  31. if featVec[axis]==value:
  32. reducedFeatVec =featVec[:axis]
  33. reducedFeatVec.extend(featVec[axis+1:])
  34. retDataSet.append(reducedFeatVec)
  35. return retDataSet
  36. def chooseBestFeatureToSplit(dataSet): # 选择最优的分类特征
  37. numFeatures = len(dataSet[0])-1
  38. baseEntropy = calcShannonEnt(dataSet) # 原始的熵
  39. bestInfoGain = 0
  40. bestFeature = -1
  41. for i in range(numFeatures):
  42. featList = [example[i] for example in dataSet]
  43. uniqueVals = set(featList)
  44. newEntropy = 0
  45. for value in uniqueVals:
  46. subDataSet = splitDataSet(dataSet,i,value)
  47. prob =len(subDataSet)/float(len(dataSet))
  48. newEntropy +=prob*calcShannonEnt(subDataSet) # 按特征分类后的熵
  49. infoGain = baseEntropy - newEntropy # 原始熵与按特征分类后的熵的差值
  50. if (infoGain>bestInfoGain): # 若按某特征划分后,熵值减少的最大,则次特征为最优分类特征
  51. bestInfoGain=infoGain
  52. bestFeature = i
  53. return bestFeature
  54. def majorityCnt(classList): #按分类后类别数量排序,比如:最后分类为2男1女,则判定为男;
  55. classCount={}
  56. for vote in classList:
  57. if vote not in classCount.keys():
  58. classCount[vote]=0
  59. classCount[vote]+=1
  60. sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
  61. return sortedClassCount[0][0]
  62. def createTree(dataSet,labels):
  63. classList=[example[-1] for example in dataSet] # 类别:男或女
  64. if classList.count(classList[0])==len(classList):
  65. return classList[0]
  66. if len(dataSet[0])==1:
  67. return majorityCnt(classList)
  68. bestFeat=chooseBestFeatureToSplit(dataSet) #选择最优特征
  69. bestFeatLabel=labels[bestFeat]
  70. myTree={bestFeatLabel:{}} #分类结果以字典形式保存
  71. del(labels[bestFeat])
  72. featValues=[example[bestFeat] for example in dataSet]
  73. uniqueVals=set(featValues)
  74. for value in uniqueVals:
  75. subLabels=labels[:]
  76. myTree[bestFeatLabel][value]=createTree(splitDataSet\
  77. (dataSet,bestFeat,value),subLabels)
  78. return myTree
  79. def predict(mytree, tips, list1):
  80. res = []
  81. for item in list1:
  82. tmp_tree = mytree
  83. iter = tmp_tree.__iter__()
  84. while 1:
  85. try:
  86. key = iter.__next__()
  87. if isinstance(key, str) and (key == "男" or key == "女"):
  88. res.append(key)
  89. break
  90. v = tmp_tree[key]
  91. index = tips[key]
  92. item_res = item[index]
  93. tmp_tree = v[item_res]
  94. iter = tmp_tree.__iter__()
  95. except StopIteration:
  96. break
  97. return res
  98. if __name__=='__main__':
  99. dataSet, labels=createDataSet1() # 创造示列数据
  100. mytree = createTree(dataSet, labels)
  101. print(mytree) # 输出决策树模型结果
  102. #预测
  103. tips = {"头发":0, "声音":1}
  104. res = predict(mytree, tips, [['长', '粗'], ['短', '粗']])
  105. print(res)

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/354662
推荐阅读
相关标签
  

闽ICP备14008679号