
Decision Tree Pruning

Contents

I. Why Prune
II. Pruning Strategies
    1. Pre-pruning
    2. Post-pruning
III. Code Implementation
    1. Collecting and preparing the data
    2. Analyzing the data
    3. Pre-pruning and testing
    4. Post-pruning and testing
IV. Summary


I. Why Prune

The purpose of pruning is to avoid overfitting of the decision tree model. During learning, the algorithm keeps splitting nodes in order to classify the training samples as correctly as possible; this produces a tree with too many branches, and those extra branches are what cause overfitting.

Pruning mitigates, to a certain extent, the overfitting that arises when an over-branched tree treats peculiarities of the training set as general properties shared by all data.

II. Pruning Strategies

There are two basic pruning strategies for decision trees: pre-pruning and post-pruning.

1. Pre-pruning

Pre-pruning evaluates each node before it is split, while the tree is still being built: if splitting the current node does not improve the generalization performance of the model (typically estimated on a held-out validation set), the node is not split and is instead marked as a leaf.
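As a rough illustration of the decision rule (a minimal sketch, not the article's implementation; all names below are illustrative), a node is expanded only when the split raises accuracy on the validation samples that reach it:

import numpy as np

def prePruneKeepSplit(validLabels, leafPrediction, splitPredictions):
    # accuracy if the node stays a leaf predicting the training majority class
    validLabels = np.asarray(validLabels)
    accLeaf = np.mean(validLabels == leafPrediction)
    # accuracy if the node is split and each validation sample follows its branch
    accSplit = np.mean(validLabels == np.asarray(splitPredictions))
    return accSplit > accLeaf

Whether a tie keeps the split is a design choice; the implementation in section III splits on ties and only prunes when splitting strictly lowers validation accuracy.
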
Dataset: [figure omitted]

Pre-pruning result: [figure omitted]

Pros and cons of pre-pruning
• Pros
        – Lower risk of overfitting.
        – Significantly reduced training and testing time.
• Cons
        – Risk of underfitting: a split that does not improve generalization by itself may enable later splits that improve it significantly. Because pre-pruning is greedy and forbids such branches from ever being expanded, it brings a risk of underfitting.

2. Post-pruning

Post-pruning first builds the complete decision tree and then examines the non-leaf nodes from the bottom up: if replacing a node's subtree with a leaf improves generalization performance, the subtree is replaced with that leaf.

 

Post-pruning process: [figures omitted]

Pros and cons of post-pruning
• Pros
        – Post-pruning keeps more branches than pre-pruning, so its risk of underfitting is small and its generalization performance is usually better than that of a pre-pruned tree.
• Cons
        – High training cost: post-pruning runs only after the complete tree has been built, and it must evaluate every non-leaf node from the bottom up.

III. Code Implementation

1. Collecting and preparing the data:

We use the watermelon dataset 2.0 shown above:

import math
import numpy as np

def createMyData():
    # Watermelon dataset 2.0: 17 samples, six discrete features each
    data = np.array([['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
                     , ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
                     , ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
                     , ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
                     , ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
                     , ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
                     , ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘']
                     , ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑']
                     , ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑']
                     , ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘']
                     , ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑']
                     , ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘']
                     , ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑']
                     , ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑']
                     , ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
                     , ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑']
                     , ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑']])
    label = np.array(['是', '是', '是', '是', '是', '是', '是', '是', '否', '否', '否', '否', '否', '否', '否', '否', '否'])
    name = np.array(['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
    return data, label, name

def splitMyData20(myData, myLabel):
    # Split the 17 samples into a 10-sample training set and a 7-sample validation (test) set
    myDataTrain = myData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16], :]
    myDataTest = myData[[3, 4, 7, 8, 10, 11, 12], :]
    myLabelTrain = myLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16]]
    myLabelTest = myLabel[[3, 4, 7, 8, 10, 11, 12]]
    return myDataTrain, myLabelTrain, myDataTest, myLabelTest
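
The later snippets assume that myData, myLabel and myName already exist, so a quick way to set them up and check the split (a minimal sketch using only the functions above):

myData, myLabel, myName = createMyData()
myDataTrain, myLabelTrain, myDataTest, myLabelTest = splitMyData20(myData, myLabel)
print(myDataTrain.shape, myDataTest.shape)   # (10, 6) (7, 6): 10 training and 7 test samples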

2. Analyzing the data:

equalNums = lambda x, y: 0 if x is None else x[x == y].size

# Entropy of a label (or feature) vector
def singleEntropy(x):
    x = np.asarray(x)
    xValues = set(x)
    entropy = 0
    for xValue in xValues:
        p = equalNums(x, xValue) / x.size
        entropy -= p * math.log(p, 2)
    return entropy

# Conditional entropy of y given a feature
def conditionnalEntropy(feature, y):
    feature = np.asarray(feature)
    y = np.asarray(y)
    featureValues = set(feature)
    entropy = 0
    for feat in featureValues:
        p = equalNums(feature, feat) / feature.size
        entropy += p * singleEntropy(y[feature == feat])
    return entropy

# Information gain (ID3 criterion)
def infoGain(feature, y):
    return singleEntropy(y) - conditionnalEntropy(feature, y)

# Information gain ratio (C4.5 criterion)
def infoGainRatio(feature, y):
    return 0 if singleEntropy(feature) == 0 else infoGain(feature, y) / singleEntropy(feature)
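
As a sanity check (a hedged sketch reusing myData and myLabel from the snippet above), the entropy of the 17 labels (8 positive, 9 negative) should come out at roughly 0.998 bits:

print(round(singleEntropy(myLabel), 3))           # entropy of the full label set, about 0.998
print(round(infoGain(myData[:, 0], myLabel), 3))  # information gain of the first feature (色泽), about 0.11
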
# Feature selection
def bestFeature(data, labels, method='id3'):
    assert method in ['id3', 'c45'], "method must be 'id3' or 'c45'"
    data = np.asarray(data)
    labels = np.asarray(labels)
    # Scoring function chosen by method: id3 -> information gain; c45 -> gain ratio
    def calcEnt(feature, labels):
        if method == 'id3':
            return infoGain(feature, labels)
        elif method == 'c45':
            return infoGainRatio(feature, labels)
    featureNum = data.shape[1]
    bestEnt = 0
    bestFeat = -1
    for feature in range(featureNum):
        ent = calcEnt(data[:, feature], labels)
        if ent >= bestEnt:
            bestEnt = ent
            bestFeat = feature
    return bestFeat, bestEnt

# Split the data and labels by the values of one feature (the feature column is removed)
def splitFeatureData(data, labels, feature):
    features = np.asarray(data)[:, feature]
    data = np.delete(np.asarray(data), feature, axis=1)
    labels = np.asarray(labels)
    uniqFeatures = set(features)
    dataSet = {}
    labelSet = {}
    for feat in uniqFeatures:
        dataSet[feat] = data[features == feat]
        labelSet[feat] = labels[features == feat]
    return dataSet, labelSet

# Majority vote over a label vector
def voteLabel(labels):
    uniqLabels = list(set(labels))
    labels = np.asarray(labels)
    labelNum = []
    for label in uniqLabels:
        labelNum.append(equalNums(labels, label))
    return uniqLabels[labelNum.index(max(labelNum))]

# Build the (unpruned) decision tree recursively
def createTree(data, labels, names, method='id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    if len(set(labels)) == 1:      # all samples share one label: return it as a leaf
        return labels[0]
    elif data.size == 0:           # no features left: majority vote
        return voteLabel(labels)
    bestFeat, bestEnt = bestFeature(data, labels, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    decisionTree = {bestFeatName: {}}
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTree(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree

# Count leaf nodes and tree depth
def getTreeSize(decisionTree):
    nodeName = list(decisionTree.keys())[0]
    nodeValue = decisionTree[nodeName]
    leafNum = 0
    treeDepth = 0
    leafDepth = 0
    for val in nodeValue.keys():
        if type(nodeValue[val]) == dict:
            leafNum += getTreeSize(nodeValue[val])[0]
            leafDepth = 1 + getTreeSize(nodeValue[val])[1]
        else:
            leafNum += 1
            leafDepth = 1
        treeDepth = max(treeDepth, leafDepth)
    return leafNum, treeDepth

# Classify one sample with a trained tree
def dtClassify(decisionTree, rowData, names):
    names = list(names)
    classLabel = None        # stays None if the sample carries a feature value the tree has never seen
    feature = list(decisionTree.keys())[0]
    featDict = decisionTree[feature]
    feat = names.index(feature)
    featVal = rowData[feat]
    if featVal in featDict.keys():
        if type(featDict[featVal]) == dict:
            classLabel = dtClassify(featDict[featVal], rowData, names)
        else:
            classLabel = featDict[featVal]
    return classLabel
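
With these pieces in place, an unpruned tree can be grown on the full dataset and inspected (a minimal sketch; myTreeFull is an illustrative name):

myTreeFull = createTree(myData, myLabel, myName, method='id3')
print(getTreeSize(myTreeFull))                    # (number of leaves, depth)
print(dtClassify(myTreeFull, myData[0], myName))  # classify the first sample; its true label is '是'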

Plotting the tree with Matplotlib annotations:

import matplotlib.pyplot as plt
# Use a CJK font so the Chinese feature names render (assumes SimHei is installed; adjust or remove as needed)
plt.rcParams['font.sans-serif'] = ['SimHei']

# Box styles for decision nodes and leaves, and the arrow style
decisionNode = dict(boxstyle="sawtooth", fc='0.8')
leafNode = dict(boxstyle="round4", fc='0.8')
arrow_args = dict(arrowstyle="<-")

# Draw one annotated node with an arrow from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Text on the edge between a parent and a child node
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(mytree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(mytree)
    depth = getTreeDepth(mytree)
    firstStr = list(mytree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = mytree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
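
With the tree from the earlier sketch (myTreeFull is the illustrative name used there), the whole tree can be drawn in one call:

createPlot(myTreeFull)   # opens a Matplotlib window with the unpruned tree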

3. Pre-pruning and testing:

# Build a decision tree with pre-pruning
def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method='id3'):
    trainData = np.asarray(dataTrain)
    labelTrain = np.asarray(labelTrain)
    testData = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    # stop if all training labels agree or no features are left
    if len(set(labelTrain)) == 1:
        return labelTrain[0]
    elif trainData.size == 0:
        return voteLabel(labelTrain)
    bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)
    # majority label at this node and its frequency in the training data
    labelTrainLabelPre = voteLabel(labelTrain)
    labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
    if dataTest is not None:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
        # validation accuracy if this node stays a leaf
        labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
        # validation accuracy if this node is split (each child predicts its own majority label)
        labelTrainEqNumPost = 0
        for val in labelTrainSet.keys():
            labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
        labelTestRatioPost = labelTrainEqNumPost / labelTest.size
    if dataTest is None and labelTrainRatioPre == 0.5:
        # no validation data and the node is evenly split: keep growing the subtree
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue),
                                                                         None, None, names, method)
    elif dataTest is None:
        return labelTrainLabelPre
    elif labelTestRatioPost < labelTestRatioPre:
        # splitting lowers validation accuracy: prune, keep this node as a leaf
        return labelTrainLabelPre
    else:
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue),
                                                                         dataTestSet.get(featValue), labelTestSet.get(featValue),
                                                                         names, method)
    return decisionTree

myDataTrain, myLabelTrain, myDataTest, myLabelTest = splitMyData20(myData, myLabel)
myTreeTrain = createTree(myDataTrain, myLabelTrain, myName, method='id3')
myTreePrePruning = createTreePrePruning(myDataTrain, myLabelTrain, myDataTest, myLabelTest, myName, method='id3')
# plot the tree before pruning
print("Tree before pre-pruning")
createPlot(myTreeTrain)
# plot the tree after pruning
print("Tree after pre-pruning")
createPlot(myTreePrePruning)
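
The article compares the two trees by plotting them; for a numeric comparison, a small scoring helper can be added (a hedged sketch, treeAccuracy is not part of the original code; it also handles the case where pruning collapses the whole tree to a single label):

def treeAccuracy(tree, data, labels, names):
    # a tree pruned all the way back is just a single class label
    if not isinstance(tree, dict):
        return float(np.mean(np.asarray(labels) == tree))
    preds = [dtClassify(tree, row, names) for row in data]
    return float(np.mean([p == t for p, t in zip(preds, labels)]))

print("unpruned test accuracy:  ", treeAccuracy(myTreeTrain, myDataTest, myLabelTest, myName))
print("pre-pruned test accuracy:", treeAccuracy(myTreePrePruning, myDataTest, myLabelTest, myName))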

Output: the unpruned tree and the pre-pruned tree [figures omitted]

4. Post-pruning and testing:

# Build a decision tree that also stores each internal node's majority label
# under the key "_vpdl" (the "pre-division label" that post-pruning will need)
def createTreeWithLabel(data, labels, names, method='id3'):
    data = np.asarray(data)
    labels = np.asarray(labels)
    names = np.asarray(names)
    votedLabel = voteLabel(labels)        # majority label of the current node
    if len(set(labels)) == 1:
        return votedLabel
    elif data.size == 0:
        return votedLabel
    bestFeat, bestEnt = bestFeature(data, labels, method=method)
    bestFeatName = names[bestFeat]
    names = np.delete(names, [bestFeat])
    decisionTree = {bestFeatName: {"_vpdl": votedLabel}}
    dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
    for featValue in dataSet.keys():
        decisionTree[bestFeatName][featValue] = createTreeWithLabel(dataSet.get(featValue), labelSet.get(featValue), names, method)
    return decisionTree

# Strip the "_vpdl" entries so the labeled tree looks like a normal tree again
def convertTree(labeledTree):
    labeledTreeNew = labeledTree.copy()
    nodeName = list(labeledTree.keys())[0]
    labeledTreeNew[nodeName] = labeledTree[nodeName].copy()
    for val in list(labeledTree[nodeName].keys()):
        if val == "_vpdl":
            labeledTreeNew[nodeName].pop(val)
        elif type(labeledTree[nodeName][val]) == dict:
            labeledTreeNew[nodeName][val] = convertTree(labeledTree[nodeName][val])
    return labeledTreeNew

# Post-pruning: after the full tree is built, try replacing subtrees with leaves, bottom-up
def treePostPruning(labeledTree, dataTest, labelTest, names):
    newTree = labeledTree.copy()
    dataTest = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    featName = list(labeledTree.keys())[0]
    featCol = np.argwhere(names == featName)[0][0]
    names = np.delete(names, [featCol])
    newTree[featName] = labeledTree[featName].copy()
    featValueDict = newTree[featName]
    featPreLabel = featValueDict.pop("_vpdl")     # majority label stored at this node
    subTreeFlag = 0
    # does any validation data reach this node?
    dataFlag = 1 if sum(dataTest.shape) > 0 else 0
    if dataFlag == 1:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
    for featValue in featValueDict.keys():
        if dataFlag == 1 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1
            # prune the child subtree first (bottom-up order)
            newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue), labelTestSet.get(featValue), names)
            # if the child collapsed to a leaf, this node becomes a pruning candidate again
            if type(featValueDict[featValue]) != dict:
                subTreeFlag = 0
        if dataFlag == 0 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1
            # no validation data reaches this branch: just strip the "_vpdl" labels
            newTree[featName][featValue] = convertTree(featValueDict[featValue])
    # when all children are leaves, compare validation accuracy before and after the split
    # (only if validation samples actually reach this node, which also guards against dividing by zero)
    if subTreeFlag == 0 and dataFlag == 1:
        ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
        equalNum = 0
        for val in labelTestSet.keys():
            equalNum += equalNums(labelTestSet[val], featValueDict[val])
        ratioAfterDivision = equalNum / labelTest.size
        if ratioAfterDivision < ratioPreDivision:
            # the split hurts validation accuracy: replace the whole node with its majority leaf
            newTree = featPreLabel
    return newTree

myTreeTrain1 = createTreeWithLabel(myDataTrain, myLabelTrain, myName, method='id3')
createPlot(myTreeTrain1)   # note: the "_vpdl" entries show up as extra leaves in this plot
print(myTreeTrain1)

xgTreeBeforePostPruning = {"脐部": {"_vpdl": "是"
    , '凹陷': {'色泽': {"_vpdl": "是", '青绿': '是', '乌黑': '是', '浅白': '否'}}
    , '稍凹': {'根蒂': {"_vpdl": "是"
        , '稍蜷': {'色泽': {"_vpdl": "是"
            , '青绿': '是'
            , '乌黑': {'纹理': {"_vpdl": "是"
                , '稍糊': '是', '清晰': '否', '模糊': '是'}}
            , '浅白': '是'}}
        , '蜷缩': '否'
        , '硬挺': '是'}}
    , '平坦': '否'}}
# the original snippet referenced xgDataTest / xgLabelTest / xgName; with the data defined
# earlier in this article they correspond to myDataTest, myLabelTest and myName
xgTreePostPruning = treePostPruning(xgTreeBeforePostPruning, myDataTest, myLabelTest, myName)
createPlot(convertTree(xgTreeBeforePostPruning))
createPlot(xgTreePostPruning)
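
To quantify the effect (a hedged sketch that reuses the illustrative treeAccuracy helper from the pre-pruning section), the full tree and the post-pruned tree can be scored on the same test split:

xgTreeFull = convertTree(xgTreeBeforePostPruning)
print("before post-pruning:", treeAccuracy(xgTreeFull, myDataTest, myLabelTest, myName))
print("after post-pruning: ", treeAccuracy(xgTreePostPruning, myDataTest, myLabelTest, myName))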

Result: the tree before and after post-pruning [figures omitted]

IV. Summary

Comparing the trees produced by pre-pruning and post-pruning, we can see that post-pruning usually keeps more branches than pre-pruning. Its risk of underfitting is therefore small, and its generalization performance is generally better than that of the pre-pruned tree. However, because post-pruning works bottom-up after the complete tree has been grown, its training cost is larger than that of pre-pruning.
