当前位置:   article > 正文

python实现机器学习之决策树_python 是否购买电脑 决策树

python 是否购买电脑 决策树
这几天在看决策树算法,发现这算法在实际的应用中使用挺多的。所以想总结一下:
这里给出一些我觉得比较好的博客链接:
http://blog.jobbole.com/86443/ 通俗易懂,同时也讲了一些决策树算法:如ID3、C4.5之类的。以及建立完完整的决策树之后,为了防止过拟合而进行的剪枝的操作。
决策树算法介绍及应用:http://blog.jobbole.com/89072/ 这篇博文写的非常的好,主要是将数据挖掘导论这本数的第四章给总结了一番。主要讲解了决策树的建立,决策树的评估以及最后使用R语言和SPASS来进行决策树的建模。
数据挖掘(6):决策树分类算法:http://blog.jobbole.com/90165/ 这篇博客将分类写的很详细。并且列举了一个不同人群,不同信贷购买电脑的例子。可惜没有实现代码。


在这里我在读完这3篇博文之后,给出了用python语言和sklearn包的完整实现的例子。


    决策树(decision tree)是一个树结构(可以是二叉树或非二叉树)。其每个非叶节点表示一个特征属性上的测试,每个分支代表这个特征属性在某个值域上的输出,而每个叶节点存放一个类别。使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
决策树的一般流程:
1、收集数据:可以使用任何的方法
2、准备数据:树构造算法只适用标称型数据,因此数值型的数据必须离散化。
3、分析数据:
4、训练算法:构造树的数据结构
5、测试算法:使用经验树计算错误率
6、使用算法




这里给的例子如下:(注意:这里使用的是ID3)

记录ID 年龄 输入层次 学生 信用等级 是否购买电脑
1 青少年  否   一般    否
2 青少年  否   良好    否
3 中年  否 一般  是
4 老年  否   一般    是
5 老年  是 一般    是
6 老年  是 良好  否
7 中年  是 良好    是
8 青少年  否 一般  否
9 青少年  是 一般  是
10 老年  是 一般    是
11 青少年  是 良好    是
12 中年  否 良好    是
13 中年  是 一般    是
14 老年  否 良好    否

这里我将其转换为特征向量的形式如下:


  1. dataSet = [[1, 3, 0, 1, 'no'],
  2. [1, 3, 0, 2, 'no'],
  3. [2, 3, 0, 1, 'yes'],
  4. [3, 2, 0, 1, 'yes'],
  5. [3, 1, 1, 1, 'yes'],
  6. [3, 1, 1, 2, 'no'],
  7. [2, 1, 1, 2, 'yes'],
  8. [1, 2, 0, 1, 'no'],
  9. [1, 1, 1, 1, 'yes'],
  10. [3, 2, 1, 1, 'yes'],
  11. [1, 2, 1, 2, 'yes'],
  12. [2, 2, 0, 2, 'yes'],
  13. [2, 3, 0, 1, 'yes'],
  14. [3, 2, 0, 2, 'no'],
  15. ]
  16. labels = ['age','salary','isStudent','credit']

完整的实现决策树的代码:

  1. #-*- coding:utf-8 -*-
  2. from math import log
  3. import operator
  4. def createDataSet():
  5. dataSet = [[1, 3, 0, 1, 'no'],
  6. [1, 3, 0, 2, 'no'],
  7. [2, 3, 0, 1, 'yes'],
  8. [3, 2, 0, 1, 'yes'],
  9. [3, 1, 1, 1, 'yes'],
  10. [3, 1, 1, 2, 'no'],
  11. [2, 1, 1, 2, 'yes'],
  12. [1, 2, 0, 1, 'no'],
  13. [1, 1, 1, 1, 'yes'],
  14. [3, 2, 1, 1, 'yes'],
  15. [1, 2, 1, 2, 'yes'],
  16. [2, 2, 0, 2, 'yes'],
  17. [2, 3, 0, 1, 'yes'],
  18. [3, 2, 0, 2, 'no'],
  19. ]
  20. labels = ['age','salary','isStudent','credit']
  21. #change to discrete values
  22. return dataSet, labels
  23. ############计算香农熵###############
  24. def calcShannonEnt(dataSet):
  25. numEntries = len(dataSet)#计算实例的总数
  26. labelCounts = {}#创建一个数据字典,它的key是最后把一列的数值(即标签),value记录当前类型(即标签)出现的次数
  27. for featVec in dataSet: #遍历整个训练集
  28. currentLabel = featVec[-1]
  29. if currentLabel not in labelCounts.keys():
  30. labelCounts[currentLabel] = 0
  31. labelCounts[currentLabel] += 1
  32. shannonEnt = 0.0#初始化香农熵
  33. for key in labelCounts:
  34. prob = float(labelCounts[key])/numEntries
  35. shannonEnt -= prob * log(prob,2) #计算香农熵
  36. return shannonEnt
  37. #########按给定的特征划分数据#########
  38. def splitDataSet(dataSet, axis, value): #axis表示特征的索引  value是返回的特征值
  39. retDataSet = []
  40. for featVec in dataSet:
  41. if featVec[axis] == value:
  42. reducedFeatVec = featVec[:axis] #抽取除axis特征外的所有的记录的内容
  43. reducedFeatVec.extend(featVec[axis+1:])
  44. retDataSet.append(reducedFeatVec)
  45. return retDataSet
  46. #######遍历整个数据集,选择最好的数据集划分方式########
  47. def chooseBestFeatureToSplit(dataSet):
  48. numFeatures = len(dataSet[0]) - 1 #获取当前实例的特征个数,一般最后一列是标签。the last column is used for the labels
  49. baseEntropy = calcShannonEnt(dataSet)#计算当前实例的香农熵
  50. bestInfoGain = 0.0; bestFeature = -1#这里初始化最佳的信息增益和最佳的特征
  51. for i in range(numFeatures): #遍历每一个特征 iterate over all the features
  52. featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
  53. uniqueVals = set(featList) #创建唯一的分类标签
  54. newEntropy = 0.0
  55. for value in uniqueVals:#计算每种划分方式的信息熵
  56. subDataSet = splitDataSet(dataSet, i, value)
  57. prob = len(subDataSet)/float(len(dataSet))
  58. newEntropy += prob * calcShannonEnt(subDataSet)
  59. infoGain = baseEntropy - newEntropy #计算信息增益
  60. if (infoGain > bestInfoGain): #比较每个特征的信息增益,只要最好的信息增益
  61. bestInfoGain = infoGain #if better than current best, set to best
  62. bestFeature = i
  63. return bestFeature,bestInfoGain #返回最佳划分的特征索引和信息增益
  64. '''该函数使用分类名称的列表,然后创建键值为classList中唯一值的数据字典。字典
  65. 对象的存储了classList中每个类标签出现的频率。最后利用operator操作键值排序字典,
  66. 并返回出现次数最多的分类名称
  67. '''
  68. def majorityCnt(classList):
  69. classCount={}
  70. for vote in classList:
  71. if vote not in classCount.keys():
  72. classCount[vote] = 0
  73. classCount[vote] += 1
  74. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  75. return sortedClassCount[0][0]
  76. def createTree(dataSet,labels):
  77. classList = [example[-1] for example in dataSet]#返回所有的标签
  78. if classList.count(classList[0]) == len(classList): #当类别完全相同时则停止继续划分,直接返回该类的标签
  79. return classList[0]#stop splitting when all of the classes are equal
  80. if len(dataSet[0]) == 1: #遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组
  81. return majorityCnt(classList)#由于无法简单的返回唯一的类标签,这里就返回出现次数最多的类别作为返回值
  82. bestFeat,bestInfogain= chooseBestFeatureToSplit(dataSet)#获取最好的分类特征索引
  83. bestFeatLabel = labels[bestFeat]#获取该特征的名称
  84. #这里直接使用字典变量来存储树信息,这对于绘制树形图很重要。
  85. myTree = {bestFeatLabel:{}}#当前数据集选取最好的特征存储在bestFeat中
  86. del(labels[bestFeat])#删除已经在选取的特征
  87. featValues = [example[bestFeat] for example in dataSet]
  88. uniqueVals = set(featValues)
  89. for value in uniqueVals:
  90. subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels
  91. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
  92. return myTree
  93. def classify(inputTree,featLabels,testVec):
  94. firstStr = inputTree.keys()[0]
  95. secondDict = inputTree[firstStr]
  96. featIndex = featLabels.index(firstStr)
  97. key = testVec[featIndex]
  98. valueOfFeat = secondDict[key]
  99. if isinstance(valueOfFeat, dict):
  100. classLabel = classify(valueOfFeat, featLabels, testVec)
  101. else: classLabel = valueOfFeat
  102. return classLabel
  103. def storeTree(inputTree,filename):
  104. import pickle
  105. fw = open(filename,'w')
  106. pickle.dump(inputTree,fw)
  107. fw.close()
  108. def grabTree(filename):
  109. import pickle
  110. fr = open(filename)
  111. return pickle.load(fr)
  112. import matplotlib.pyplot as plt
  113. decisionNode = dict(boxstyle="sawtooth", fc="0.8") #定义文本框与箭头的格式
  114. leafNode = dict(boxstyle="round4", fc="0.8")
  115. arrow_args = dict(arrowstyle="<-")
  116. def getNumLeafs(myTree):#获取树节点的数目
  117. numLeafs = 0
  118. firstStr = myTree.keys()[0]
  119. secondDict = myTree[firstStr]
  120. for key in secondDict.keys():#测试节点的数据类型是不是字典,如果是则就需要递归的调用getNumLeafs()函数
  121. if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
  122. numLeafs += getNumLeafs(secondDict[key])
  123. else: numLeafs +=1
  124. return numLeafs
  125. def getTreeDepth(myTree):#获取树节点的树的层数
  126. maxDepth = 0
  127. firstStr = myTree.keys()[0]
  128. secondDict = myTree[firstStr]
  129. for key in secondDict.keys():
  130. if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
  131. thisDepth = 1 + getTreeDepth(secondDict[key])
  132. else: thisDepth = 1
  133. if thisDepth > maxDepth: maxDepth = thisDepth
  134. return maxDepth
  135. def plotNode(nodeTxt, centerPt, parentPt, nodeType): #绘制带箭头的注释
  136. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',#createPlot.ax1会提供一个绘图区
  137. xytext=centerPt, textcoords='axes fraction',
  138. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
  139. def plotMidText(cntrPt, parentPt, txtString):#计算父节点和子节点的中间位置,在父节点间填充文本的信息
  140. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  141. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  142. createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
  143. def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
  144. numLeafs = getNumLeafs(myTree) #首先计算树的宽和高
  145. depth = getTreeDepth(myTree)
  146. firstStr = myTree.keys()[0] #the text label for this node should be this
  147. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  148. plotMidText(cntrPt, parentPt, nodeTxt)
  149. plotNode(firstStr, cntrPt, parentPt, decisionNode)#标记子节点的属性值
  150. secondDict = myTree[firstStr]
  151. plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  152. for key in secondDict.keys():
  153. if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
  154. plotTree(secondDict[key],cntrPt,str(key)) #recursion
  155. else: #it's a leaf node print the leaf node
  156. plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
  157. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  158. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  159. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
  160. #if you do get a dictonary you know it's a tree, and the first element will be another dict
  161. #
  162. def createPlot(inTree):
  163. fig = plt.figure(1, facecolor='white')
  164. fig.clf()
  165. axprops = dict(xticks=[], yticks=[])
  166. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
  167. #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
  168. plotTree.totalW = float(getNumLeafs(inTree))#c存储树的宽度
  169. plotTree.totalD = float(getTreeDepth(inTree))#存储树的深度。我们使用这两个变量计算树节点的摆放位置
  170. plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
  171. plotTree(inTree, (0.5,1.0), '')
  172. plt.show()
  173. # def createPlot():
  174. # fig = plt.figure(1, facecolor='white')
  175. # fig.clf()
  176. # createPlot.ax1 = plt.subplot(111, frameon=False) #创建一个新图形,并清空绘图区
  177. # plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)#然后在绘图区上绘制两个代表不同类型的树节点
  178. # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
  179. # plt.show()

测试代码:

  1. #########测试代码#############
  2. myDat,labels=createDataSet()
  3. print calcShannonEnt(myDat)
  4. print myDat
  5. bestfeature,bestInfogain=chooseBestFeatureToSplit(myDat)
  6. print bestfeature,bestInfogain
  7. myTree=createTree(myDat, labels)
  8. print myTree
  9. print getNumLeafs(myTree)
  10. print getTreeDepth(myTree)
  11. createPlot(myTree)
  12. ##########测试结束#############
结果显示:
0.940285958671
[[1, 3, 0, 1, 'no'], [1, 3, 0, 2, 'no'], [2, 3, 0, 1, 'yes'], [3, 2, 0, 1, 'yes'], [3, 1, 1, 1, 'yes'], [3, 1, 1, 2, 'no'], [2, 1, 1, 2, 'yes'], [1, 2, 0, 1, 'no'], [1, 1, 1, 1, 'yes'], [3, 2, 1, 1, 'yes'], [1, 2, 1, 2, 'yes'], [2, 2, 0, 2, 'yes'], [2, 3, 0, 1, 'yes'], [3, 2, 0, 2, 'no']]
0 0.246749819774
{'age': {1: {'isStudent': {0: 'no', 1: 'yes'}}, 2: 'yes', 3: {'credit': {1: 'yes', 2: 'no'}}}}
5
2

画出决策树图:
决策树1.png


接下来我打算使用python的scikit-learn机器学习算法包来实现上面的决策树:

  1. #-*- coding:utf-8 -*-
  2. from sklearn.datasets import load_iris
  3. from sklearn import tree
  4. dataSet = [[1, 3, 0, 1, 'no'],
  5. [1, 3, 0, 2, 'no'],
  6. [2, 3, 0, 1, 'yes'],
  7. [3, 2, 0, 1, 'yes'],
  8. [3, 1, 1, 1, 'yes'],
  9. [3, 1, 1, 2, 'no'],
  10. [2, 1, 1, 2, 'yes'],
  11. [1, 2, 0, 1, 'no'],
  12. [1, 1, 1, 1, 'yes'],
  13. [3, 2, 1, 1, 'yes'],
  14. [1, 2, 1, 2, 'yes'],
  15. [2, 2, 0, 2, 'yes'],
  16. [2, 3, 0, 1, 'yes'],
  17. [3, 2, 0, 2, 'no'],
  18. ]
  19. labels = ['age','salary','isStudent','credit']
  20. from sklearn.cross_validation import train_test_split #这里是引用了交叉验证
  21. FeatureSet=[]
  22. Label=[]
  23. for i in dataSet:
  24. FeatureSet.append(i[:-1])
  25. Label.append(i[-1])
  26. X_train,X_test, y_train, y_test = train_test_split(FeatureSet, Label, random_state=1)#将数据随机分成训练集和测试集
  27. print X_train
  28. print X_test
  29. print y_train
  30. print y_test
  31. #print iris
  32. clf = tree.DecisionTreeClassifier()
  33. clf = clf.fit(X_train, y_train)
  34. from sklearn.externals.six import StringIO
  35. with open("isBuy.dot", 'w') as f:
  36. f = tree.export_graphviz(clf, out_file=f)
  37. import os
  38. os.unlink('isBuy.dot')
  39. #
  40. from sklearn.externals.six import StringIO
  41. import pydot #注意要安装pydot2这个python插件。否则会报错。
  42. dot_data = StringIO()
  43. tree.export_graphviz(clf, out_file=dot_data)
  44. graph = pydot.graph_from_dot_data(dot_data.getvalue())
  45. graph.write_pdf("isBuy.pdf") #将决策树以pdf格式输出
  46. pre_labels=clf.predict(X_test)
  47. print pre_labels


结果显示如下:
[[1, 2, 1, 2], [3, 1, 1, 1], [1, 3, 0, 2], [2, 3, 0, 1], [1, 3, 0, 1], [3, 2, 0, 2], [3, 2, 1, 1], [1, 1, 1, 1], [2, 2, 0, 2], [3, 1, 1, 2]]
[[3, 2, 0, 1], [1, 2, 0, 1], [2, 1, 1, 2], [2, 3, 0, 1]]
['yes', 'yes', 'no', 'yes', 'no', 'no', 'yes', 'yes', 'yes', 'no']
['yes', 'no', 'yes', 'yes']
['yes' 'yes' 'yes' 'yes']

由最后两行我们可以发现第二个结果预测错误了。但是这并不能说明我们的决策树不够好。相反我觉得这样反而不会出现过拟合的现象。当然也不是绝对的,如果过拟合的化,就需要考虑剪枝了。这是后话了。
决策树的图片如下:
图片中的X[3]表示的是labels = ['age','salary','isStudent','credit']中的索引为3的特征。
该图将会保存在isBuy.pd 








声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/768628
推荐阅读
相关标签
  

闽ICP备14008679号