The purpose of pruning is to keep a decision tree model from overfitting. During learning, the algorithm keeps splitting nodes in order to classify the training samples as correctly as possible, so the tree grows far too many branches, and that is exactly what overfitting looks like.
There are two basic pruning strategies: pre-pruning and post-pruning.
Pre-pruning: while the tree is being grown, each node is evaluated before it is split; if splitting the current node would not improve the tree's generalization performance (estimated on a held-out validation set), the split is abandoned and the node is marked as a leaf.
Post-pruning: first grow the complete decision tree, then examine its non-leaf nodes bottom-up; if replacing a node's subtree with a leaf would improve generalization performance, that subtree is replaced with a leaf.
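Both strategies apply the same test; only the timing differs. A minimal sketch of that test (the helper name is illustrative, not part of the code below):

# Keep a split only if it does not reduce accuracy on a held-out validation set.
def shouldKeepSplit(accWithSplit, accAsLeaf):
    # Pre-pruning runs this check before expanding a node; post-pruning runs
    # it bottom-up on each non-leaf node after the full tree has been grown.
    return accWithSplit >= accAsLeaf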
Implementation: here we use the watermelon 2.0 data set from above. First the shared helpers for building an unpruned ID3/C4.5 tree:
import math
import numpy as np

# Watermelon data set 2.0: 17 samples, 6 categorical features.
def createMyData():
    data = np.array([['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑'],
                     ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑'],
                     ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑'],
                     ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑'],
                     ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑'],
                     ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘'],
                     ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘'],
                     ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑'],
                     ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑'],
                     ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘'],
                     ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑'],
                     ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘'],
                     ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑'],
                     ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑'],
                     ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘'],
                     ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑'],
                     ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑']])
    label = np.array(['是', '是', '是', '是', '是', '是', '是', '是',
                      '否', '否', '否', '否', '否', '否', '否', '否', '否'])
    name = np.array(['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
    return data, label, name

# Fixed train/test split used in the watermelon book's pruning example.
def splitMyData20(myData, myLabel):
    myDataTrain = myData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16], :]
    myDataTest = myData[[3, 4, 7, 8, 10, 11, 12], :]
    myLabelTrain = myLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16]]
    myLabelTest = myLabel[[3, 4, 7, 8, 10, 11, 12]]
    return myDataTrain, myLabelTrain, myDataTest, myLabelTest
# Count how many entries of x equal y (0 if x is None).
equalNums = lambda x, y: 0 if x is None else x[x == y].size

# Information entropy of a label vector.
def singleEntropy(x):
    x = np.asarray(x)
    xValues = set(x)
    entropy = 0
    for xValue in xValues:
        p = equalNums(x, xValue) / x.size
        entropy -= p * math.log(p, 2)
    return entropy

# Conditional entropy of y given one feature column.
def conditionalEntropy(feature, y):
    feature = np.asarray(feature)
    y = np.asarray(y)
    featureValues = set(feature)
    entropy = 0
    for feat in featureValues:
        p = equalNums(feature, feat) / feature.size
        entropy += p * singleEntropy(y[feature == feat])
    return entropy

# Information gain (used by ID3).
def infoGain(feature, y):
    return singleEntropy(y) - conditionalEntropy(feature, y)

# Information gain ratio (used by C4.5).
def infoGainRatio(feature, y):
    return 0 if singleEntropy(feature) == 0 else infoGain(feature, y) / singleEntropy(feature)
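As a quick sanity check of these helpers, a toy example (the values are invented for illustration, not taken from the watermelon data):

labelsToy = np.array(['是', '是', '否', '否'])
featureToy = np.array(['A', 'A', 'A', 'B'])
print(singleEntropy(labelsToy))                   # 1.0 for a 50/50 label split
print(conditionalEntropy(featureToy, labelsToy))  # 3/4 * 0.918 + 1/4 * 0 ≈ 0.689
print(infoGain(featureToy, labelsToy))            # 1.0 - 0.689 ≈ 0.311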
# Feature selection: pick the feature with the best score.
def bestFeature(data, labels, method='id3'):
    assert method in ['id3', 'c45'], "method must be 'id3' or 'c45'"
    data = np.asarray(data)
    labels = np.asarray(labels)
    # Scoring criterion: id3 -> information gain; c45 -> information gain ratio.
    def calcEnt(feature, labels):
        if method == 'id3':
            return infoGain(feature, labels)
        elif method == 'c45':
            return infoGainRatio(feature, labels)
    featureNum = data.shape[1]
    bestEnt = 0
    bestFeat = -1
    for feature in range(featureNum):
        ent = calcEnt(data[:, feature], labels)
        if ent >= bestEnt:
            bestEnt = ent
            bestFeat = feature
    return bestFeat, bestEnt
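For example, on the full 17-sample data set, '纹理' has the largest information gain (about 0.381 in the watermelon book), so it should be the feature selected here:

myData, myLabel, myName = createMyData()
feat, gain = bestFeature(myData, myLabel, method='id3')
print(myName[feat], gain)  # expected: 纹理 and a gain of roughly 0.38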
# Split the data and labels by the values of one feature column;
# the chosen column is removed from the returned subsets.
def splitFeatureData(data, labels, feature):
    features = np.asarray(data)[:, feature]
    data = np.delete(np.asarray(data), feature, axis=1)
    labels = np.asarray(labels)
    uniqFeatures = set(features)
    dataSet = {}
    labelSet = {}
    for feat in uniqFeatures:
        dataSet[feat] = data[features == feat]
        labelSet[feat] = labels[features == feat]
    return dataSet, labelSet
# Majority vote over a label vector.
def voteLabel(labels):
    uniqLabels = list(set(labels))
    labels = np.asarray(labels)
    labelNum = []
    for label in uniqLabels:
        labelNum.append(equalNums(labels, label))
    return uniqLabels[labelNum.index(max(labelNum))]
# Build a decision tree as nested dicts: {featureName: {featureValue: subtree-or-label}}.
def createTree(data, labels, names, method='id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    # All samples share one label: return it as a leaf.
    if len(set(labels)) == 1:
        return labels[0]
    # No features left to split on: return the majority label.
    elif data.size == 0:
        return voteLabel(labels)
    bestFeat, bestEnt = bestFeature(data, labels, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    decisionTree = {bestFeatName: {}}
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTree(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree
# Count the leaves and the depth of a tree (one recursive call per subtree).
def getTreeSize(decisionTree):
    nodeName = list(decisionTree.keys())[0]
    nodeValue = decisionTree[nodeName]
    leafNum = 0
    treeDepth = 0
    for val in nodeValue.keys():
        if type(nodeValue[val]) == dict:
            subLeafNum, subTreeDepth = getTreeSize(nodeValue[val])
            leafNum += subLeafNum
            leafDepth = 1 + subTreeDepth
        else:
            leafNum += 1
            leafDepth = 1
        treeDepth = max(treeDepth, leafDepth)
    return leafNum, treeDepth
# Classify one row of data with the trained tree.
def dtClassify(decisionTree, rowData, names):
    names = list(names)
    feature = list(decisionTree.keys())[0]
    featDict = decisionTree[feature]
    feat = names.index(feature)
    featVal = rowData[feat]
    classLabel = None  # returned when the feature value was never seen in training
    if featVal in featDict.keys():
        if type(featDict[featVal]) == dict:
            classLabel = dtClassify(featDict[featVal], rowData, names)
        else:
            classLabel = featDict[featVal]
    return classLabel
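With that, a minimal end-to-end run on the full data set (a sanity check; the exact tree shape can depend on tie-breaking in bestFeature):

myData, myLabel, myName = createMyData()
fullTree = createTree(myData, myLabel, myName, method='id3')
print(getTreeSize(fullTree))                    # (number of leaves, depth)
print(dtClassify(fullTree, myData[0], myName))  # a training row should come back as '是'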
Drawing the tree with Matplotlib annotations:
import matplotlib.pyplot as plt

# Needed so the Chinese feature names render; requires a CJK font such as
# SimHei to be installed (adjust for your system).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Box and arrow styles for decision nodes and leaves.
decisionNode = dict(boxstyle="sawtooth", fc='0.8')
leafNode = dict(boxstyle="round4", fc='0.8')
arrow_args = dict(arrowstyle="<-")

# Draw an annotated node with an arrow from its parent.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Count the leaves of a tree.
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Depth of a tree.
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Label the edge between a child and its parent.
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

# Recursively lay out and draw a (sub)tree.
def plotTree(mytree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(mytree)
    firstStr = list(mytree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = mytree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# Build a pre-pruned decision tree: before each split, compare test (validation)
# accuracy with and without the split, and keep the split only if accuracy
# does not drop.
def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method='id3'):
    dataTrain = np.asarray(dataTrain)
    labelTrain = np.asarray(labelTrain)
    # dataTest/labelTest may be None when no validation sample reaches this node.
    names = np.asarray(names)
    if len(set(labelTrain)) == 1:
        return labelTrain[0]
    elif dataTrain.size == 0:
        return voteLabel(labelTrain)
    bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)

    # Majority label and its training accuracy if this node becomes a leaf.
    labelTrainLabelPre = voteLabel(labelTrain)
    labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
    if dataTest is not None:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
        # Test accuracy without the split (node as a leaf) ...
        labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
        # ... and with the split (each child predicts its training majority).
        labelTrainEqNumPost = 0
        for val in labelTrainSet.keys():
            labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val)))
        labelTestRatioPost = labelTrainEqNumPost / labelTest.size

    if dataTest is None and labelTrainRatioPre == 0.5:
        # No validation data reaches this node and the vote is tied: keep splitting.
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue),
                                                                         None, None, names, method)
    elif dataTest is None:
        return labelTrainLabelPre
    elif labelTestRatioPost < labelTestRatioPre:
        # The split hurts validation accuracy: prune to a leaf.
        return labelTrainLabelPre
    else:
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue),
                                                                         dataTestSet.get(featValue), labelTestSet.get(featValue),
                                                                         names, method)
    return decisionTree
# Build and compare the unpruned and pre-pruned trees on the fixed split.
myData, myLabel, myName = createMyData()
myDataTrain, myLabelTrain, myDataTest, myLabelTest = splitMyData20(myData, myLabel)
myTreeTrain = createTree(myDataTrain, myLabelTrain, myName, method='id3')
myTreePrePruning = createTreePrePruning(myDataTrain, myLabelTrain, myDataTest, myLabelTest, myName, method='id3')
# Plot the tree before pruning.
print("Tree before pre-pruning")
createPlot(myTreeTrain)
# Plot the tree after pruning.
print("Tree after pre-pruning")
createPlot(myTreePrePruning)
Output: createPlot renders the unpruned tree and the much smaller pre-pruned tree.

For post-pruning, the tree is first rebuilt with each node's pre-division majority label cached under the key "_vpdl", plus a helper that strips those labels again for plotting:
# Build the decision tree, caching the pre-division majority ("voted") label
# of every internal node under the key "_vpdl".
def createTreeWithLabel(data, labels, names, method='id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    votedLabel = voteLabel(labels)
    # Leaf conditions: a single class, or no features left to split on.
    if len(set(labels)) == 1 or data.size == 0:
        return votedLabel
    bestFeat, bestEnt = bestFeature(data, labels, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    decisionTree = {bestFeatName: {"_vpdl": votedLabel}}
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTreeWithLabel(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree

# Strip the cached "_vpdl" labels, recovering a plain tree for plotting.
def convertTree(labeledTree):
    labeledTreeNew = labeledTree.copy()
    nodeName = list(labeledTree.keys())[0]
    labeledTreeNew[nodeName] = labeledTree[nodeName].copy()
    for val in list(labeledTree[nodeName].keys()):
        if val == "_vpdl":
            labeledTreeNew[nodeName].pop(val)
        elif type(labeledTree[nodeName][val]) == dict:
            labeledTreeNew[nodeName][val] = convertTree(labeledTree[nodeName][val])
    return labeledTreeNew
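convertTree does nothing but drop the cached "_vpdl" entries, e.g.:

labeledToy = {'纹理': {'_vpdl': '是', '清晰': '是', '模糊': '否'}}
print(convertTree(labeledToy))  # {'纹理': {'清晰': '是', '模糊': '否'}}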
# Post-pruning: after the full tree is built, walk it bottom-up and replace a
# node's subtree with its cached majority label when that improves test accuracy.
def treePostPruning(labeledTree, dataTest, labelTest, names):
    newTree = labeledTree.copy()
    dataTest = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    featName = list(labeledTree.keys())[0]
    featCol = np.argwhere(names == featName)[0][0]
    names = np.delete(names, [featCol])
    newTree[featName] = labeledTree[featName].copy()
    featValueDict = newTree[featName]
    featPreLabel = featValueDict.pop("_vpdl")
    subTreeFlag = 0
    # dataFlag marks whether any test sample reaches this node.
    dataFlag = 1 if sum(dataTest.shape) > 0 else 0
    if dataFlag == 1:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
    for featValue in featValueDict.keys():
        if dataFlag == 1 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1
            # Prune the subtree first; it may come back as a single leaf.
            newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue), labelTestSet.get(featValue), names)
            if type(featValueDict[featValue]) != dict:
                subTreeFlag = 0
        if dataFlag == 0 and type(featValueDict[featValue]) == dict:
            # No test data reaches this node: just strip the cached labels.
            subTreeFlag = 1
            newTree[featName][featValue] = convertTree(featValueDict[featValue])
    # All children are leaves and test data is available: decide whether to
    # replace this node itself with its pre-division majority label.
    if subTreeFlag == 0 and dataFlag == 1:
        ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
        equalNum = 0
        for val in labelTestSet.keys():
            equalNum += equalNums(labelTestSet[val], featValueDict.get(val))
        ratioAfterDivision = equalNum / labelTest.size
        if ratioAfterDivision < ratioPreDivision:
            newTree = featPreLabel
    return newTree
# Build the labeled tree on the training split and plot it (stripping the
# cached labels so they do not show up as extra leaves).
myTreeTrain1 = createTreeWithLabel(myDataTrain, myLabelTrain, myName, method='id3')
createPlot(convertTree(myTreeTrain1))
print(myTreeTrain1)
# The same tree written out by hand, with the "_vpdl" majority labels at every node.
xgTreeBeforePostPruning = {"脐部": {"_vpdl": "是"
                                    , '凹陷': {'色泽': {"_vpdl": "是", '青绿': '是', '乌黑': '是', '浅白': '否'}}
                                    , '稍凹': {'根蒂': {"_vpdl": "是"
                                                        , '稍蜷': {'色泽': {"_vpdl": "是"
                                                                           , '青绿': '是'
                                                                           , '乌黑': {'纹理': {"_vpdl": "是"
                                                                                              , '稍糊': '是', '清晰': '否', '模糊': '是'}}
                                                                           , '浅白': '是'}}
                                                        , '蜷缩': '否'
                                                        , '硬挺': '是'}}
                                    , '平坦': '否'}}

# Evaluate post-pruning on the same held-out test split as before.
xgTreePostPruning = treePostPruning(xgTreeBeforePostPruning, myDataTest, myLabelTest, myName)
createPlot(convertTree(xgTreeBeforePostPruning))
createPlot(xgTreePostPruning)
Result:
Comparing the trees produced by pre-pruning and post-pruning, post-pruning usually keeps more branches, so its risk of underfitting is small and its generalization performance is often better than that of the pre-pruned tree. The price is that post-pruning only starts after the complete tree has been grown and then examines every non-leaf node bottom-up, so its training time cost is larger than that of pre-pruning.