当前位置:   article > 正文

机器学习笔记(四)决策树剪枝_决策树剪枝的定义

决策树剪枝的定义

一、什么是剪枝?

顾名思义,剪枝就是指将决策树的某些内部节点下面的节点都删掉,留下来的内部决策节点作为叶子节点。

二、为什么要剪枝?

决策树是充分考虑了所有的数据点而生成的复杂树,它在学习的过程中为了尽可能的正确的分类训练样本,不停地对结点进行划分,因此这会导致整棵树的分支过多,造成决策树很庞大。决策树过于庞大,有可能出现过拟合的情况,决策树越复杂,过拟合的程度会越高。

所以,为了避免过拟合,咱们需要对决策树进行剪枝。

一般情况下,有两种剪枝策略,分别是预剪枝后剪枝

下面还是通过西瓜这个例子来讲解。

 

首先,先按照信息增益对这10个训练样本构造决策树,方法还是和上面的ID3算法提到的一样。

先计算最开始的训练样本的熵。好瓜有5个,坏瓜有5个,则信息熵为

再计算按照各个属性划分后的信息熵:

 

同理可得,

 

所以,选择脐部作为根节点。按照同样的思路,可以得到一颗为未剪枝前的决策树。

 

三、预剪枝

预剪枝就是在构造决策树的过程中,先对每个结点在划分前进行估计,如果当前结点的划分不能带来决策树模型泛化性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。

如下图所示,在构造的时候就考虑到剪枝操作。

1) 首先,是否要按照“脐部”划分。在划分前,只有一个根节点,也是叶子节点,标记为“好瓜”。精度提高,所以按照“脐部”进行划分。

2) 当按照脐部进行划分后,会对结点 (2) 进行划分,再次使用信息增益挑选出值最大的那个特征,信息增益值最大的那个特征是“色泽”,则使用“色泽”划分后决策树为。但是,使用“色泽”划分后,编号为{5}的样本会从“好瓜”被分类为“坏瓜”,只有{4,8,11,12}被正确分类,精确度为47×100%=57.1%。所以,预剪枝操作会不再被这个节点进行划分。

3) 对于节点(3),最优属性为“根蒂”。但是,这么划分后精确度仍然是 71.4% ,所以也不会对这个节点进行操作。

预剪枝得到的决策树如下图所示。

 

优点:

降低过拟合风险

显著减少训练时间和测试时间开销

缺点:

欠拟合风险:有些分支的当前划分虽然不能提升泛化性能,但在其基础上进行的后续划分却有可能显著提高性能。预剪枝基于“贪心”本质禁止这些分支展开,带来了欠拟合风险。

四、后剪枝 

后剪枝就是先把整颗决策树构造完毕,然后自底向上的对非叶结点进行考察,若将该结点对应的子树换为叶结点能够带来泛化性能的提升,则把该子树替换为叶结点。

 

 

 

基于后剪枝策略得到的最终决策树如图所示

 

优点: 

后剪枝比预剪枝保留了更多的分支, 欠拟合风险小 泛化性能往往优于预剪枝决策树

缺点:

 

训练时间开销大 :后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶结点逐一计算

五、代码实现

1.创建数据集

  1. import math
  2. import numpy as np
  3. # 创建西瓜书数据集2.0
  4. def createDataXG20():
  5. data = np.array([['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
  6. , ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
  7. , ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
  8. , ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
  9. , ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
  10. , ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
  11. , ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘']
  12. , ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑']
  13. , ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑']
  14. , ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘']
  15. , ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑']
  16. , ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘']
  17. , ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑']
  18. , ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑']
  19. , ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
  20. , ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑']
  21. , ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑']])
  22. label = np.array(['是', '是', '是', '是', '是', '是', '是', '是', '否', '否', '否', '否', '否', '否', '否', '否', '否'])
  23. name = np.array(['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
  24. return data, label, name
  25. #划分测试集与训练集
  26. def splitXgData20(xgData, xgLabel):
  27. xgDataTrain = xgData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16],:]
  28. xgDataTest = xgData[[3, 4, 7, 8, 10, 11, 12],:]
  29. xgLabelTrain = xgLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16]]
  30. xgLabelTest = xgLabel[[3, 4, 7, 8, 10, 11, 12]]
  31. return xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest

2.生成决策树

  1. # 定义一个常用函数 用来求numpy array中数值等于某值的元素数量
  2. equalNums = lambda x,y: 0 if x is None else x[x==y].size
  3. # 定义计算信息熵的函数
  4. def singleEntropy(x):
  5. """计算一个输入序列的信息熵"""
  6. # 转换为 numpy 矩阵
  7. x = np.asarray(x)
  8. # 取所有不同值
  9. xValues = set(x)
  10. # 计算熵值
  11. entropy = 0
  12. for xValue in xValues:
  13. p = equalNums(x, xValue) / x.size
  14. entropy -= p * math.log(p, 2)
  15. return entropy
  16. # 定义计算条件信息熵的函数
  17. def conditionnalEntropy(feature, y):
  18. """计算 某特征feature 条件下y的信息熵"""
  19. # 转换为numpy
  20. feature = np.asarray(feature)
  21. y = np.asarray(y)
  22. # 取特征的不同值
  23. featureValues = set(feature)
  24. # 计算熵值
  25. entropy = 0
  26. for feat in featureValues:
  27. # 解释:feature == feat 是得到取feature中所有元素值等于feat的元素的索引(类似这样理解)
  28. # y[feature == feat] 是取y中 feature元素值等于feat的元素索引的 y的元素的子集
  29. p = equalNums(feature, feat) / feature.size
  30. entropy += p * singleEntropy(y[feature == feat])
  31. return entropy
  32. # 定义信息增益
  33. def infoGain(feature, y):
  34. return singleEntropy(y) - conditionnalEntropy(feature, y)
  35. # 定义信息增益率
  36. def infoGainRatio(feature, y):
  37. return 0 if singleEntropy(feature) == 0 else infoGain(feature, y) / singleEntropy(feature)
  38. # 特征选取
  39. def bestFeature(data, labels, method = 'id3'):
  40. assert method in ['id3', 'c45'], "method 须为id3或c45"
  41. data = np.asarray(data)
  42. labels = np.asarray(labels)
  43. # 根据输入的method选取 评估特征的方法:id3 -> 信息增益; c45 -> 信息增益率
  44. def calcEnt(feature, labels):
  45. if method == 'id3':
  46. return infoGain(feature, labels)
  47. elif method == 'c45' :
  48. return infoGainRatio(feature, labels)
  49. # 特征数量 即 data 的列数量
  50. featureNum = data.shape[1]
  51. # 计算最佳特征
  52. bestEnt = 0
  53. bestFeat = -1
  54. for feature in range(featureNum):
  55. ent = calcEnt(data[:, feature], labels)
  56. if ent >= bestEnt:
  57. bestEnt = ent
  58. bestFeat = feature
  59. # print("feature " + str(feature + 1) + " ent: " + str(ent)+ "\t bestEnt: " + str(bestEnt))
  60. return bestFeat, bestEnt
  61. # 根据特征及特征值分割原数据集 删除data中的feature列,并根据feature列中的值分割 data和label
  62. def splitFeatureData(data, labels, feature):
  63. """feature 为特征列的索引"""
  64. # 取特征列
  65. features = np.asarray(data)[:,feature]
  66. # 数据集中删除特征列
  67. data = np.delete(np.asarray(data), feature, axis = 1)
  68. # 标签
  69. labels = np.asarray(labels)
  70. uniqFeatures = set(features)
  71. dataSet = {}
  72. labelSet = {}
  73. for feat in uniqFeatures:
  74. dataSet[feat] = data[features == feat]
  75. labelSet[feat] = labels[features == feat]
  76. return dataSet, labelSet
  77. # 多数投票
  78. def voteLabel(labels):
  79. uniqLabels = list(set(labels))
  80. labels = np.asarray(labels)
  81. finalLabel = 0
  82. labelNum = []
  83. for label in uniqLabels:
  84. # 统计每个标签值得数量
  85. labelNum.append(equalNums(labels, label))
  86. # 返回数量最大的标签
  87. return uniqLabels[labelNum.index(max(labelNum))]
  88. # 创建决策树
  89. def createTree(data, labels, names, method = 'id3'):
  90. data = np.asarray(data)
  91. labels = np.asarray(labels)
  92. names = np.asarray(names)
  93. # 如果结果为单一结果
  94. if len(set(labels)) == 1:
  95. return labels[0]
  96. # 如果没有待分类特征
  97. elif data.size == 0:
  98. return voteLabel(labels)
  99. # 其他情况则选取特征
  100. bestFeat, bestEnt = bestFeature(data, labels, method = method)
  101. # 取特征名称
  102. bestFeatName = names[bestFeat]
  103. # 从特征名称列表删除已取得特征名称
  104. names = np.delete(names, [bestFeat])
  105. # 根据选取的特征名称创建树节点
  106. decisionTree = {bestFeatName: {}}
  107. # 根据最优特征进行分割
  108. dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
  109. # 对最优特征的每个特征值所分的数据子集进行计算
  110. for featValue in dataSet.keys():
  111. decisionTree[bestFeatName][featValue] = createTree(dataSet.get(featValue), labelSet.get(featValue), names, method)
  112. return decisionTree
  113. # 树信息统计 叶子节点数量 和 树深度
  114. def getTreeSize(decisionTree):
  115. nodeName = list(decisionTree.keys())[0]
  116. nodeValue = decisionTree[nodeName]
  117. leafNum = 0
  118. treeDepth = 0
  119. leafDepth = 0
  120. for val in nodeValue.keys():
  121. if type(nodeValue[val]) == dict:
  122. leafNum += getTreeSize(nodeValue[val])[0]
  123. leafDepth = 1 + getTreeSize(nodeValue[val])[1]
  124. else :
  125. leafNum += 1
  126. leafDepth = 1
  127. treeDepth = max(treeDepth, leafDepth)
  128. return leafNum, treeDepth
  129. # 使用模型对其他数据分类
  130. def dtClassify(decisionTree, rowData, names):
  131. names = list(names)
  132. # 获取特征
  133. feature = list(decisionTree.keys())[0]
  134. # 决策树对于该特征的值的判断字段
  135. featDict = decisionTree[feature]
  136. # 获取特征的列
  137. feat = names.index(feature)
  138. # 获取数据该特征的值
  139. featVal = rowData[feat]
  140. # 根据特征值查找结果,如果结果是字典说明是子树,调用本函数递归
  141. if featVal in featDict.keys():
  142. if type(featDict[featVal]) == dict:
  143. classLabel = dtClassify(featDict[featVal], rowData, names)
  144. else:
  145. classLabel = featDict[featVal]
  146. return classLabel

3.可视化

  1. # 可视化 主要源自《机器学习实战》
  2. import matplotlib.pyplot as plt
  3. plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
  4. plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
  5. plt.rcParams['image.interpolation'] = 'nearest'
  6. plt.rcParams['image.cmap'] = 'gray'
  7. decisionNodeStyle = dict(boxstyle = "sawtooth", fc = "0.8")
  8. leafNodeStyle = {"boxstyle": "round4", "fc": "0.8"}
  9. arrowArgs = {"arrowstyle": "<-"}
  10. # 画节点
  11. def plotNode(nodeText, centerPt, parentPt, nodeStyle):
  12. createPlot.ax1.annotate(nodeText, xy = parentPt, xycoords = "axes fraction", xytext = centerPt
  13. , textcoords = "axes fraction", va = "center", ha="center", bbox = nodeStyle, arrowprops = arrowArgs)
  14. # 添加箭头上的标注文字
  15. def plotMidText(centerPt, parentPt, lineText):
  16. xMid = (centerPt[0] + parentPt[0]) / 2.0
  17. yMid = (centerPt[1] + parentPt[1]) / 2.0
  18. createPlot.ax1.text(xMid, yMid, lineText)
  19. # 画树
  20. def plotTree(decisionTree, parentPt, parentValue):
  21. # 计算宽与高
  22. leafNum, treeDepth = getTreeSize(decisionTree)
  23. # 在 1 * 1 的范围内画图,因此分母为 1
  24. # 每个叶节点之间的偏移量
  25. plotTree.xOff = plotTree.figSize / (plotTree.totalLeaf - 1)
  26. # 每一层的高度偏移量
  27. plotTree.yOff = plotTree.figSize / plotTree.totalDepth
  28. # 节点名称
  29. nodeName = list(decisionTree.keys())[0]
  30. # 根节点的起止点相同,可避免画线;如果是中间节点,则从当前叶节点的位置开始,
  31. # 然后加上本次子树的宽度的一半,则为决策节点的横向位置
  32. centerPt = (plotTree.x + (leafNum - 1) * plotTree.xOff / 2.0, plotTree.y)
  33. # 画出该决策节点
  34. plotNode(nodeName, centerPt, parentPt, decisionNodeStyle)
  35. # 标记本节点对应父节点的属性值
  36. plotMidText(centerPt, parentPt, parentValue)
  37. # 取本节点的属性值
  38. treeValue = decisionTree[nodeName]
  39. # 下一层各节点的高度
  40. plotTree.y = plotTree.y - plotTree.yOff
  41. # 绘制下一层
  42. for val in treeValue.keys():
  43. # 如果属性值对应的是字典,说明是子树,进行递归调用; 否则则为叶子节点
  44. if type(treeValue[val]) == dict:
  45. plotTree(treeValue[val], centerPt, str(val))
  46. else:
  47. plotNode(treeValue[val], (plotTree.x, plotTree.y), centerPt, leafNodeStyle)
  48. plotMidText((plotTree.x, plotTree.y), centerPt, str(val))
  49. # 移到下一个叶子节点
  50. plotTree.x = plotTree.x + plotTree.xOff
  51. # 递归完成后返回上一层
  52. plotTree.y = plotTree.y + plotTree.yOff
  53. # 画出决策树
  54. def createPlot(decisionTree):
  55. fig = plt.figure(1, facecolor = "white")
  56. fig.clf()
  57. axprops = {"xticks": [], "yticks": []}
  58. createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
  59. # 定义画图的图形尺寸
  60. plotTree.figSize = 1.5
  61. # 初始化树的总大小
  62. plotTree.totalLeaf, plotTree.totalDepth = getTreeSize(decisionTree)
  63. # 叶子节点的初始位置x 和 根节点的初始层高度y
  64. plotTree.x = 0
  65. plotTree.y = plotTree.figSize
  66. plotTree(decisionTree, (plotTree.figSize / 2.0, plotTree.y), "")
  67. plt.show()

4.预剪枝

  1. # 创建预剪枝决策树
  2. def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method = 'id3'):
  3. trainData = np.asarray(dataTrain)
  4. labelTrain = np.asarray(labelTrain)
  5. testData = np.asarray(dataTest)
  6. labelTest = np.asarray(labelTest)
  7. names = np.asarray(names)
  8. # 如果结果为单一结果
  9. if len(set(labelTrain)) == 1:
  10. return labelTrain[0]
  11. # 如果没有待分类特征
  12. elif trainData.size == 0:
  13. return voteLabel(labelTrain)
  14. # 其他情况则选取特征
  15. bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method = method)
  16. # 取特征名称
  17. bestFeatName = names[bestFeat]
  18. # 从特征名称列表删除已取得特征名称
  19. names = np.delete(names, [bestFeat])
  20. # 根据最优特征进行分割
  21. dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)
  22. # 预剪枝评估
  23. # 划分前的分类标签
  24. labelTrainLabelPre = voteLabel(labelTrain)
  25. labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
  26. # 划分后的精度计算
  27. if dataTest is not None:
  28. dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
  29. # 划分前的测试标签正确比例
  30. labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
  31. # 划分后 每个特征值的分类标签正确的数量
  32. labelTrainEqNumPost = 0
  33. for val in labelTrainSet.keys():
  34. labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
  35. # 划分后 正确的比例
  36. labelTestRatioPost = labelTrainEqNumPost / labelTest.size
  37. # 如果没有评估数据 但划分前的精度等于最小值0.5 则继续划分
  38. if dataTest is None and labelTrainRatioPre == 0.5:
  39. decisionTree = {bestFeatName: {}}
  40. for featValue in dataTrainSet.keys():
  41. decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
  42. , None, None, names, method)
  43. elif dataTest is None:
  44. return labelTrainLabelPre
  45. # 如果划分后的精度相比划分前的精度下降, 则直接作为叶子节点返回
  46. elif labelTestRatioPost < labelTestRatioPre:
  47. return labelTrainLabelPre
  48. else :
  49. # 根据选取的特征名称创建树节点
  50. decisionTree = {bestFeatName: {}}
  51. # 对最优特征的每个特征值所分的数据子集进行计算
  52. for featValue in dataTrainSet.keys():
  53. decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
  54. , dataTestSet.get(featValue), labelTestSet.get(featValue)
  55. , names, method)
  56. return decisionTree
  1. # 将西瓜数据2.0分割为测试集和训练集
  2. xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest = splitXgData20(xgData, xgLabel)
  3. # 生成不剪枝的树
  4. xgTreeTrain = createTree(xgDataTrain, xgLabelTrain, xgName, method = 'id3')
  5. # 生成预剪枝的树
  6. xgTreePrePruning = createTreePrePruning(xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest, xgName, method = 'id3')
  7. # 画剪枝前的树
  8. print("剪枝前的树")
  9. createPlot(xgTreeTrain)
  10. # 画剪枝后的树
  11. print("剪枝后的树")
  12. createPlot(xgTreePrePruning)

 

剪枝后,减少了一些分支,降低了过拟合的风险。 

5.后剪枝

  1. # 创建决策树 带预划分标签
  2. def createTreeWithLabel(data, labels, names, method = 'id3'):
  3. data = np.asarray(data)
  4. labels = np.asarray(labels)
  5. names = np.asarray(names)
  6. # 如果不划分的标签为
  7. votedLabel = voteLabel(labels)
  8. # 如果结果为单一结果
  9. if len(set(labels)) == 1:
  10. return votedLabel
  11. # 如果没有待分类特征
  12. elif data.size == 0:
  13. return votedLabel
  14. # 其他情况则选取特征
  15. bestFeat, bestEnt = bestFeature(data, labels, method = method)
  16. # 取特征名称
  17. bestFeatName = names[bestFeat]
  18. # 从特征名称列表删除已取得特征名称
  19. names = np.delete(names, [bestFeat])
  20. # 根据选取的特征名称创建树节点 划分前的标签votedPreDivisionLabel=_vpdl
  21. decisionTree = {bestFeatName: {"_vpdl": votedLabel}}
  22. # 根据最优特征进行分割
  23. dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
  24. # 对最优特征的每个特征值所分的数据子集进行计算
  25. for featValue in dataSet.keys():
  26. decisionTree[bestFeatName][featValue] = createTreeWithLabel(dataSet.get(featValue), labelSet.get(featValue), names, method)
  27. return decisionTree
  28. # 将带预划分标签的tree转化为常规的tree
  29. # 函数中进行的copy操作,原因见有道笔记 【YL20190621】关于Python中字典存储修改的思考
  30. def convertTree(labeledTree):
  31. labeledTreeNew = labeledTree.copy()
  32. nodeName = list(labeledTree.keys())[0]
  33. labeledTreeNew[nodeName] = labeledTree[nodeName].copy()
  34. for val in list(labeledTree[nodeName].keys()):
  35. if val == "_vpdl":
  36. labeledTreeNew[nodeName].pop(val)
  37. elif type(labeledTree[nodeName][val]) == dict:
  38. labeledTreeNew[nodeName][val] = convertTree(labeledTree[nodeName][val])
  39. return labeledTreeNew
  40. # 后剪枝 训练完成后决策节点进行替换评估 这里可以直接对xgTreeTrain进行操作
  41. def treePostPruning(labeledTree, dataTest, labelTest, names):
  42. newTree = labeledTree.copy()
  43. dataTest = np.asarray(dataTest)
  44. labelTest = np.asarray(labelTest)
  45. names = np.asarray(names)
  46. # 取决策节点的名称 即特征的名称
  47. featName = list(labeledTree.keys())[0]
  48. # print("\n当前节点:" + featName)
  49. # 取特征的列
  50. featCol = np.argwhere(names==featName)[0][0]
  51. names = np.delete(names, [featCol])
  52. # print("当前节点划分的数据维度:" + str(names))
  53. # print("当前节点划分的数据:" )
  54. # print(dataTest)
  55. # print(labelTest)
  56. # 该特征下所有值的字典
  57. newTree[featName] = labeledTree[featName].copy()
  58. featValueDict = newTree[featName]
  59. featPreLabel = featValueDict.pop("_vpdl")
  60. # print("当前节点预划分标签:" + featPreLabel)
  61. # 是否为子树的标记
  62. subTreeFlag = 0
  63. # 分割测试数据 如果有数据 则进行测试或递归调用 np的array我不知道怎么判断是否None, 用is None是错的
  64. dataFlag = 1 if sum(dataTest.shape) > 0 else 0
  65. if dataFlag == 1:
  66. # print("当前节点有划分数据!")
  67. dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
  68. for featValue in featValueDict.keys():
  69. # print("当前节点属性 {0} 的子节点:{1}".format(featValue ,str(featValueDict[featValue])))
  70. if dataFlag == 1 and type(featValueDict[featValue]) == dict:
  71. subTreeFlag = 1
  72. # 如果是子树则递归
  73. newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue), labelTestSet.get(featValue), names)
  74. # 如果递归后为叶子 则后续进行评估
  75. if type(featValueDict[featValue]) != dict:
  76. subTreeFlag = 0
  77. # 如果没有数据 则转换子树
  78. if dataFlag == 0 and type(featValueDict[featValue]) == dict:
  79. subTreeFlag = 1
  80. # print("当前节点无划分数据!直接转换树:"+str(featValueDict[featValue]))
  81. newTree[featName][featValue] = convertTree(featValueDict[featValue])
  82. # print("转换结果:" + str(convertTree(featValueDict[featValue])))
  83. # 如果全为叶子节点, 评估需要划分前的标签,这里思考两种方法,
  84. # 一是,不改变原来的训练函数,评估时使用训练数据对划分前的节点标签重新打标
  85. # 二是,改进训练函数,在训练的同时为每个节点增加划分前的标签,这样可以保证评估时只使用测试数据,避免再次使用大量的训练数据
  86. # 这里考虑第二种方法 写新的函数 createTreeWithLabel,当然也可以修改createTree来添加参数实现
  87. if subTreeFlag == 0:
  88. ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
  89. equalNum = 0
  90. for val in labelTestSet.keys():
  91. equalNum += equalNums(labelTestSet[val], featValueDict[val])
  92. ratioAfterDivision = equalNum / labelTest.size
  93. # print("当前节点预划分标签的准确率:" + str(ratioPreDivision))
  94. # print("当前节点划分后的准确率:" + str(ratioAfterDivision))
  95. # 如果划分后的测试数据准确率低于划分前的,则划分无效,进行剪枝,即使节点等于预划分标签
  96. # 注意这里取的是小于,如果有需要 也可以取 小于等于
  97. if ratioAfterDivision < ratioPreDivision:
  98. newTree = featPreLabel
  99. return newTree
  1. xgTreePostPruning = treePostPruning(xgTreeTrain, xgDataTest, xgLabelTest, xgName)
  2. createPlot(xgTreePostPruning)

 

剪除多余节点后明显比预剪枝保留了更多的分支,泛化能力更强。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/330810
推荐阅读
相关标签
  

闽ICP备14008679号