当前位置:   article > 正文

机器学习--决策树_机器学习决策树

机器学习决策树

一、决策树简介

决策树(DecisionTree),又称为判定树,是另一种特殊的根树,它最初是运筹学中的常用工具之一;之后应用范围不断扩展,目前是人工智能中常见的机器学习方法之一。决策树是一种基于树结构来进行决策的分类算法,我们希望从给定的训练数据集学得一个模型(即决策树),用该模型对新样本分类。决策树可以非常直观展现分类的过程和结果,决策树模型构建成功后,对样本的分类效率也非常高。

二、决策树的优缺点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。

三、决策树的一般流程

(1)收集数据:可以使用如何方法。
(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4)训练算法:构造树的数据结构。
(5)测试算法:使用经验树计算错误率。
(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在          含义。

四、信息增益

样本有多个属性,该先选哪个样本来划分数据集呢?在划分数据集之前之后信息发生的变化称为信息增益,获得信息增益最高的属性就是最好的选择。

1.信息熵

样本集合D中第k类样本所占的比例p_k(k=1,2,…,|K|),|K|为样本分类的个数,则D的信息熵为:

Ent(D)的值越小,则D的纯度越高。换句话说,信息熵越小,信息增益越大。

2.信息增益

使用属性a对样本集D进行划分所获得的“信息增益”的计算方法是,用样本集的总信息熵

减去属性a的每个分支的信息熵与权重(该分支的样本数除以总样本数)的乘积。

通常,信息增益越大,意味着用属性a进行划分所获得的“纯度提升”越大。我们的目标就是寻找使信息增益最大的属性作为划分的依据。

五、决策树的具体实现

1.收集数据

我收集的数据依旧是集美大学计算机工程学院acm比赛校选的数据,其中每列的属性分别是成绩、用时、年级、奖项。

  

2.准备数据

由于我所用的数据很明显是连续型的,我们需要将数据离散化。这里我随机挑选15个数据进行离散化。 其中将成绩、用时、年级分为4个等级,数字越大,分别代表成绩越高,用时越长,年级越高。奖项分为3个等级,1等奖,2等奖,3等奖。

3.导入数据

用pandas模块的read_csv()函数读取数据文本,分别求出数据集data,标签集labels,所有情况集labels_full

  1. #导入数据
  2. def import_data():
  3. data = pd.read_csv('data1.txt')
  4. data.head(10)
  5. data=np.array(data).tolist()
  6. # 属性值列表
  7. labels = ['得分', '用时', '年级', '奖项']
  8. # 特征对应的所有可能的情况
  9. labels_full = {}
  10. for i in range(len(labels)):
  11. labelList = [example[i] for example in data] #获取每一行的第一个数
  12. uniqueLabel = set(labelList)#去重
  13. labels_full[labels[i]] = uniqueLabel#每一个属性所对应的种类
  14. return data,labels,labels_full
  15. data,labels,labels_full=import_data()

4.计算信息熵

编写计算信息熵的算法,为后面的计算信息增益打下基础

  1. #计算信息熵
  2. def calcShannonEnt(dataSet):
  3. numEntries = len(dataSet)#计算数据集总数
  4. labelCounts = collections.defaultdict(int)#用来统计标签
  5. for featVec in dataSet:
  6. currentLabel = featVec[-1]#得到数据的分类标签
  7. if currentLabel not in labelCounts.keys():#若当前的标签不在标签集中则创建一个
  8. labelCounts[currentLabel] = 0
  9. labelCounts[currentLabel] += 1 #标签集中对应标签数目加一,统计每个类别
  10. shannonEnt = 0.0 #信息熵初始值
  11. for key in labelCounts:
  12. prob = float(labelCounts[key]) / numEntries #pk
  13. shannonEnt -= prob * math.log2(prob)
  14. return shannonEnt
  15. #print("当前数据的总信息熵",calcShannonEnt(data))

计算出该数据集的总信息熵

 

5.划分数据集

 我们对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方法。

  1. #划分数据集
  2. def splitDataSet(dataSet, axis, value):# 待划分的数据集 划分数据集的特征 需要返回的特征的值
  3. retDataSet = [] #创建一个新的列表
  4. for featVec in dataSet:
  5. if featVec[axis]==value:#如果给定的特征值是等于想要的特征值
  6. #将该特征值前面的内容保存起来
  7. reducedFeatVec = featVec[:axis]
  8. #将该特征值后面的内容保存起来
  9. reducedFeatVec.extend(featVec[axis + 1:])
  10. #表示去掉在axis中特征值为value的样本后而得到的数据集
  11. retDataSet.append(reducedFeatVec)
  12. return retDataSet

通过计算每个特征的信息增益,我们的目标是找出信息增益最大的的数据集划分方式,这个划分方式就是最好的数据集划分方式。

  1. #选择最好的数据集划分方式
  2. def chooseBestFeatureToSplit(dataSet, labels):
  3. #得到数据的特征值总数
  4. numFeatures = len(dataSet[0]) - 1
  5. #计算出总信息熵
  6. baseEntropy = calcShannonEnt(dataSet)
  7. #基础信息增益为0.0
  8. bestInfoGain = 0.0
  9. #最好的特征值
  10. bestFeature = -1
  11. #对每个特征值进行求信息熵
  12. for i in range(numFeatures):
  13. #得到数据集中所有的当前特征值列表
  14. featList = [example[i] for example in dataSet]
  15. #去掉重复的
  16. uniqueVals = set(featList)
  17. #新的熵,代表当前特征值的熵
  18. newEntropy = 0.0
  19. #遍历现在有的特征的可能性
  20. for value in uniqueVals:
  21. subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)#在全部数据集的当前特征位置上,找到该特征值等于当前值的集合
  22. prob = len(subDataSet) / float(len(dataSet))#计算权重
  23. newEntropy += prob * calcShannonEnt(subDataSet)#计算当前特征值的熵
  24. infoGain = baseEntropy - newEntropy#计算信息增益
  25. print('当前特征值为:' + labels[i] + ',对应的信息增益值为:' + str(infoGain)+"i等于"+str(i))
  26. #选出最大的信息增益
  27. if infoGain > bestInfoGain:
  28. bestInfoGain = infoGain
  29. bestFeature = i #新的最好的用来划分的特征值
  30. print('信息增益最大的特征为:' + labels[bestFeature])
  31. return bestFeature
  32. #print(chooseBestFeatureToSplit(data,labels))

我们发现信息增益的最大值是得分,其次是用时,最后是年级。

6.递归构建决策树

在构建决策树,可能会出现这一种情况,如果数据集已经处理了所有的属性,但是类标签依然不是唯一的。在这种情况下,我们通常会采用多数表决的方法决定叶子节点的分类。

  1. #投票分类
  2. def majorityCnt(classList):
  3. classCount={}
  4. for vote in classList:
  5. if vote not in classCount.keys():classCount[vote]=0
  6. classCount[vote]+=1
  7. sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
  8. print(sortedClassCount)
  9. return sortedClassCount[0][0] #返回出现次数最多的分类

对树进行创建

  1. #创建树
  2. def createTree(dataSet,labels):
  3. #拿到所有数据的分类标签
  4. classList=[example[-1] for example in dataSet]
  5. #类别完全相同则停止继续划分
  6. if classList.count(classList[0])==len(classList):
  7. return classList[0]
  8. #遍历完所有特征时返回出现次数最多的类别
  9. if len(dataSet[0])==1:
  10. return majorityCnt(classList)
  11. bestFeat=chooseBestFeatureToSplit(dataSet,labels)#选择最好的划分特征,得到该特征的下标
  12. print(bestFeat)
  13. bestFeatLabel=labels[bestFeat]#得到最好特征的名称
  14. print(bestFeatLabel)
  15. #使用一个字典来存储树结构,分叉处为划分的特征名称
  16. myTree={bestFeatLabel:{}}
  17. del(labels[bestFeat])#删除本次划分的特征值
  18. featValues=[example[bestFeat] for example in dataSet ]
  19. uniqueVals=set(featValues)
  20. for value in uniqueVals:
  21. #得到剩下的特征值
  22. subLabels=labels[:]
  23. #递归调用
  24. myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  25. return myTree
  26. #print(createTree(data,labels))

我们得到一个字典类型的树。
 

7.使用Matplotlib注解绘制树形图

将树的结构可视化,有助于我们理解具体的分类过程。

  1. import matplotlib.pyplot as plt
  2. import matplotlib
  3. # 能够显示中文
  4. matplotlib.rcParams['font.sans-serif'] = ['SimHei']
  5. matplotlib.rcParams['font.serif'] = ['SimHei']
  6. #定义文本框和箭头格式
  7. decisionNode=dict(boxstyle="sawtooth",fc='0.8') #分叉节点
  8. leafNode=dict(boxstyle="round4",fc='0.8') #叶子节点
  9. arrow_args=dict(arrowstyle="<-")
  10. def plotNode(nodeTxt,centerPt,parentPt,nodeType):
  11. createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
  12. xytext=centerPt,
  13. textcoords='axes fraction',
  14. va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
  15. #获取叶节点的数目和树的层数
  16. def getNumLeafs(myTree):
  17. numLeafs=0#统计叶子节点总数
  18. firstStr=list(myTree.keys())[0]#得到根节点
  19. secondDict=myTree[firstStr]#第一个节点对应的内容
  20. for key in secondDict.keys(): #如果key对应的是一个字典,就递归调用
  21. if type(secondDict[key]).__name__=='dict':
  22. numLeafs+=getNumLeafs(secondDict[key])
  23. else:
  24. numLeafs+=1
  25. return numLeafs
  26. def getTreeDepth(myTree):
  27. maxDepth=0
  28. firstStr=list(myTree.keys())[0]
  29. secondDict=myTree[firstStr]
  30. for key in secondDict.keys():
  31. if type(secondDict[key]).__name__=='dict':
  32. thisDepth=1+getTreeDepth(secondDict[key])
  33. else:
  34. thisDepth=1
  35. if thisDepth>maxDepth:maxDepth=thisDepth
  36. return maxDepth
  37. #计算出父节点和子节点的中间位置,填充信息
  38. def plotMidText(cntrPt,parentPt,txtString):
  39. xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
  40. yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
  41. createPlot.axl.text(xMid,yMid,txtString)
  42. #绘制出树的所有节点,递归绘制
  43. def plotTree(mytree,parentPt,nodeTxt):
  44. numLeafs=getNumLeafs(mytree)
  45. depth=getTreeDepth(mytree)
  46. firstStr=list(mytree.keys())[0]
  47. cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
  48. plotMidText(cntrPt,parentPt,nodeTxt)
  49. plotNode(firstStr,cntrPt,parentPt,decisionNode)
  50. secondDict=mytree[firstStr]
  51. plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
  52. for key in secondDict.keys():
  53. if type(secondDict[key]).__name__=='dict':
  54. plotTree(secondDict[key],cntrPt,str(key))
  55. else:
  56. plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
  57. plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
  58. plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
  59. plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
  60. #绘制决策树
  61. def createPlot(inTree):
  62. fig=plt.figure(1,facecolor='white')
  63. fig.clf()
  64. axprops=dict(xticks=[],yticks=[])
  65. createPlot.axl=plt.subplot(111,frameon=False,**axprops)
  66. plotTree.totalW=float(getNumLeafs(inTree))
  67. plotTree.totalD=float(getTreeDepth(inTree))
  68. plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')
  69. plt.show()

展示结果

  1. if __name__ == '__main__':
  2. mytree=createTree(data,labels)
  3. createPlot(mytree)

六.实验总结

本次实验只是对决策树的创建原理和创建算法以及展示创建的决策树进行了主要介绍。下一次实验,我们将会具体涉及到树的预剪枝、后剪枝、连续数据的离散化。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/330745
推荐阅读
相关标签
  

闽ICP备14008679号