赞
踩
(1)收集数据:可以使用如何方法。
(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4)训练算法:构造树的数据结构。
(5)测试算法:使用经验树计算错误率。
(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在 含义。
- #导入数据
- def import_data():
- data = pd.read_csv('data1.txt')
- data.head(10)
- data=np.array(data).tolist()
- # 属性值列表
- labels = ['得分', '用时', '年级', '奖项']
-
- # 特征对应的所有可能的情况
- labels_full = {}
-
- for i in range(len(labels)):
- labelList = [example[i] for example in data] #获取每一行的第一个数
- uniqueLabel = set(labelList)#去重
- labels_full[labels[i]] = uniqueLabel#每一个属性所对应的种类
- return data,labels,labels_full
-
- data,labels,labels_full=import_data()

- #计算信息熵
- def calcShannonEnt(dataSet):
- numEntries = len(dataSet)#计算数据集总数
- labelCounts = collections.defaultdict(int)#用来统计标签
- for featVec in dataSet:
- currentLabel = featVec[-1]#得到数据的分类标签
- if currentLabel not in labelCounts.keys():#若当前的标签不在标签集中则创建一个
- labelCounts[currentLabel] = 0
- labelCounts[currentLabel] += 1 #标签集中对应标签数目加一,统计每个类别
-
- shannonEnt = 0.0 #信息熵初始值
- for key in labelCounts:
- prob = float(labelCounts[key]) / numEntries #pk
- shannonEnt -= prob * math.log2(prob)
- return shannonEnt
-
- #print("当前数据的总信息熵",calcShannonEnt(data))

- #划分数据集
- def splitDataSet(dataSet, axis, value):# 待划分的数据集 划分数据集的特征 需要返回的特征的值
- retDataSet = [] #创建一个新的列表
- for featVec in dataSet:
- if featVec[axis]==value:#如果给定的特征值是等于想要的特征值
- #将该特征值前面的内容保存起来
- reducedFeatVec = featVec[:axis]
- #将该特征值后面的内容保存起来
- reducedFeatVec.extend(featVec[axis + 1:])
- #表示去掉在axis中特征值为value的样本后而得到的数据集
- retDataSet.append(reducedFeatVec)
- return retDataSet
- #选择最好的数据集划分方式
- def chooseBestFeatureToSplit(dataSet, labels):
- #得到数据的特征值总数
- numFeatures = len(dataSet[0]) - 1
- #计算出总信息熵
- baseEntropy = calcShannonEnt(dataSet)
- #基础信息增益为0.0
- bestInfoGain = 0.0
- #最好的特征值
- bestFeature = -1
- #对每个特征值进行求信息熵
- for i in range(numFeatures):
- #得到数据集中所有的当前特征值列表
- featList = [example[i] for example in dataSet]
- #去掉重复的
- uniqueVals = set(featList)
- #新的熵,代表当前特征值的熵
- newEntropy = 0.0
- #遍历现在有的特征的可能性
- for value in uniqueVals:
- subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)#在全部数据集的当前特征位置上,找到该特征值等于当前值的集合
- prob = len(subDataSet) / float(len(dataSet))#计算权重
- newEntropy += prob * calcShannonEnt(subDataSet)#计算当前特征值的熵
- infoGain = baseEntropy - newEntropy#计算信息增益
- print('当前特征值为:' + labels[i] + ',对应的信息增益值为:' + str(infoGain)+"i等于"+str(i))
- #选出最大的信息增益
- if infoGain > bestInfoGain:
- bestInfoGain = infoGain
- bestFeature = i #新的最好的用来划分的特征值
- print('信息增益最大的特征为:' + labels[bestFeature])
- return bestFeature
-
- #print(chooseBestFeatureToSplit(data,labels))

- #投票分类
- def majorityCnt(classList):
- classCount={}
- for vote in classList:
- if vote not in classCount.keys():classCount[vote]=0
- classCount[vote]+=1
- sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
- print(sortedClassCount)
- return sortedClassCount[0][0] #返回出现次数最多的分类
- #创建树
- def createTree(dataSet,labels):
- #拿到所有数据的分类标签
- classList=[example[-1] for example in dataSet]
- #类别完全相同则停止继续划分
- if classList.count(classList[0])==len(classList):
- return classList[0]
- #遍历完所有特征时返回出现次数最多的类别
- if len(dataSet[0])==1:
- return majorityCnt(classList)
- bestFeat=chooseBestFeatureToSplit(dataSet,labels)#选择最好的划分特征,得到该特征的下标
- print(bestFeat)
- bestFeatLabel=labels[bestFeat]#得到最好特征的名称
- print(bestFeatLabel)
- #使用一个字典来存储树结构,分叉处为划分的特征名称
- myTree={bestFeatLabel:{}}
- del(labels[bestFeat])#删除本次划分的特征值
- featValues=[example[bestFeat] for example in dataSet ]
- uniqueVals=set(featValues)
- for value in uniqueVals:
- #得到剩下的特征值
- subLabels=labels[:]
- #递归调用
- myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
- return myTree
-
- #print(createTree(data,labels))

- import matplotlib.pyplot as plt
- import matplotlib
-
-
- # 能够显示中文
- matplotlib.rcParams['font.sans-serif'] = ['SimHei']
- matplotlib.rcParams['font.serif'] = ['SimHei']
- #定义文本框和箭头格式
- decisionNode=dict(boxstyle="sawtooth",fc='0.8') #分叉节点
- leafNode=dict(boxstyle="round4",fc='0.8') #叶子节点
- arrow_args=dict(arrowstyle="<-")
-
- def plotNode(nodeTxt,centerPt,parentPt,nodeType):
- createPlot.axl.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
- xytext=centerPt,
- textcoords='axes fraction',
- va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
-
- #获取叶节点的数目和树的层数
- def getNumLeafs(myTree):
- numLeafs=0#统计叶子节点总数
- firstStr=list(myTree.keys())[0]#得到根节点
- secondDict=myTree[firstStr]#第一个节点对应的内容
- for key in secondDict.keys(): #如果key对应的是一个字典,就递归调用
- if type(secondDict[key]).__name__=='dict':
- numLeafs+=getNumLeafs(secondDict[key])
- else:
- numLeafs+=1
- return numLeafs
-
- def getTreeDepth(myTree):
- maxDepth=0
- firstStr=list(myTree.keys())[0]
- secondDict=myTree[firstStr]
- for key in secondDict.keys():
- if type(secondDict[key]).__name__=='dict':
- thisDepth=1+getTreeDepth(secondDict[key])
- else:
- thisDepth=1
- if thisDepth>maxDepth:maxDepth=thisDepth
- return maxDepth
-
-
- #计算出父节点和子节点的中间位置,填充信息
- def plotMidText(cntrPt,parentPt,txtString):
- xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
- yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
- createPlot.axl.text(xMid,yMid,txtString)
- #绘制出树的所有节点,递归绘制
- def plotTree(mytree,parentPt,nodeTxt):
- numLeafs=getNumLeafs(mytree)
- depth=getTreeDepth(mytree)
- firstStr=list(mytree.keys())[0]
- cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
- plotMidText(cntrPt,parentPt,nodeTxt)
- plotNode(firstStr,cntrPt,parentPt,decisionNode)
- secondDict=mytree[firstStr]
- plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
- for key in secondDict.keys():
- if type(secondDict[key]).__name__=='dict':
- plotTree(secondDict[key],cntrPt,str(key))
- else:
- plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
- plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
- plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
- plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
-
- #绘制决策树
- def createPlot(inTree):
- fig=plt.figure(1,facecolor='white')
- fig.clf()
- axprops=dict(xticks=[],yticks=[])
- createPlot.axl=plt.subplot(111,frameon=False,**axprops)
- plotTree.totalW=float(getNumLeafs(inTree))
- plotTree.totalD=float(getTreeDepth(inTree))
- plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')
- plt.show()
-
-

- if __name__ == '__main__':
- mytree=createTree(data,labels)
- createPlot(mytree)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。