当前位置:   article > 正文

机器学习——决策树剪枝_决策树剪枝代码

决策树剪枝代码

目录

一、决策树剪枝策略

1.1剪枝目的

1.2剪枝策略

1.3判断决策树泛化性能是否提升的方法

二、预剪枝 (prepruning)

2.1概述 

2.2预剪枝优缺点

2.3代码实现

三、后剪枝(postpruning) 

3.1概述

3.2后剪枝优缺点 

3.3代码实现

 

  


代码部分参考决策树python源码实现(含预剪枝和后剪枝)_王路ylu的博客-CSDN博客_构建决策树代码 

一、决策树剪枝策略

1.1剪枝目的

决策树过拟合(数据在训练集上表现的很好,在测试集上表现的不好)风险很大,理论上可以完全分的开数据(想象一下,如果树足够庞大,每个叶子节点就一个数据)

“剪枝”是决策树学习算法对付“过拟合”的主要手段

可以通过“剪枝”来一定程度避免因决策分支过多,以致于把训练集自身的一些特点当做所有数据都具有的一般性质而导致的过拟合

1.2剪枝策略

预剪枝:边建立决策树边进行剪枝的操作(更实用)

后剪枝:当建立完决策树后进行剪枝操作

1.3判断决策树泛化性能是否提升的方法

留出法:预留出一部分数据用作“验证集”以进行性能评估

 

二、预剪枝 (prepruning)

2.1概述 

决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点记为叶结点,其类别标记为该结点对应训练样例数最多的类别。

策略:

  • 限制深度(例如数据集有十个特征,限制树的深度为3,即只能用其中三个特征创建树)
  • 叶子结点个数(最多只能有五个叶子结点)
  • 叶子结点样本数(一个叶子结点最少得有20个样本)
  • 信息增益

这样说可能还是有点难理解预剪枝怎么做,用西瓜书中的例子简单说明一下

根据信息增益(具体计算过程在机器学习——创建决策树_装进了牛奶箱中的博客-CSDN博客 )我们可以得到一棵未剪枝的树: 

 

(1)首先,我们先判断“脐部”,如果我们不对“脐部”进行划分,也就是说这棵决策树是这样的:

这样下来,也就是说无论你什么瓜过来我都判断它是好瓜。使用验证集进行验证,验证的精准度为: 37×100

如果进行划分(其中红色字体的表示验证集中被划分正确的编号):

 

如果只划分脐部这个属性,,我们可以通过其来划分好瓜和坏瓜,通过验证机去测试,我们可以得到划分后的精确度为:57×100 ,71.4%>42.9%所以选择划分

 (2)再看“脐部=凹陷“这个分支

如果不划分,验证集精度为71.4%

如果划分(其中红色字体的表示验证集中被划分正确的编号),验证集中编号为{4,8,11,12}的样例被划分正确:

 

划分后的精确度为 47×100,57.1%<71.4%所以选择取消划分

(3)对每个结点进行剪枝判断,结点2,3都禁止划分,结点4本身为叶子结点。最终得到仅有一层划分的“决策树桩”

 

2.2预剪枝优缺点

优点降低过拟合风险,显著减少训练时间和测试时间开销。

缺点欠拟合风险:有些分支的当前划分虽然不能提神泛化性能,但在其基础上进行的后续划分却有可能显著提高性能。预剪枝基于“贪心”本质禁止这些分支展开,带来了欠拟合分险。

2.3代码实现

以福建省选调生报考条件和上述判断西瓜好坏为例,选调生数据集如下:

 

这里只展示预剪枝部分的代码,完整代码可以点击百度网盘链接查看 

  1. # 创建预剪枝决策树
  2. def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method = 'id3'):
  3. trainData = np.asarray(dataTrain)
  4. labelTrain = np.asarray(labelTrain)
  5. testData = np.asarray(dataTest)
  6. labelTest = np.asarray(labelTest)
  7. names = np.asarray(names)
  8. # 如果结果为单一结果
  9. if len(set(labelTrain)) == 1:
  10. return labelTrain[0]
  11. # 如果没有待分类特征
  12. elif trainData.size == 0:
  13. return voteLabel(labelTrain)
  14. # 其他情况则选取特征
  15. bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method = method)
  16. # 取特征名称
  17. bestFeatName = names[bestFeat]
  18. # 从特征名称列表删除已取得特征名称
  19. names = np.delete(names, [bestFeat])
  20. # 根据最优特征进行分割
  21. dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)
  22. # 预剪枝评估
  23. # 划分前的分类标签
  24. labelTrainLabelPre = voteLabel(labelTrain)
  25. labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
  26. # 划分后的精度计算
  27. if dataTest is not None:
  28. dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
  29. # 划分前的测试标签正确比例
  30. labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
  31. # 划分后 每个特征值的分类标签正确的数量
  32. labelTrainEqNumPost = 0
  33. for val in labelTrainSet.keys():
  34. labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
  35. # 划分后 正确的比例
  36. labelTestRatioPost = labelTrainEqNumPost / labelTest.size
  37. # 如果没有评估数据 但划分前的精度等于最小值0.5 则继续划分
  38. if dataTest is None and labelTrainRatioPre == 0.5:
  39. decisionTree = {bestFeatName: {}}
  40. for featValue in dataTrainSet.keys():
  41. decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
  42. , None, None, names, method)
  43. elif dataTest is None:
  44. return labelTrainLabelPre
  45. # 如果划分后的精度相比划分前的精度下降, 则直接作为叶子节点返回
  46. elif labelTestRatioPost < labelTestRatioPre:
  47. return labelTrainLabelPre
  48. else :
  49. # 根据选取的特征名称创建树节点
  50. decisionTree = {bestFeatName: {}}
  51. # 对最优特征的每个特征值所分的数据子集进行计算
  52. for featValue in dataTrainSet.keys():
  53. decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue), labelTrainSet.get(featValue)
  54. , dataTestSet.get(featValue), labelTestSet.get(featValue)
  55. , names, method)
  56. return decisionTree
  1. # 将数据分割为测试集和训练集
  2. myDataTrain, myLabelTrain, myDataTest, myLabelTest = splitMyData20(myData, myLabel)
  3. # 生成不剪枝的树
  4. myTreeTrain = createTree(myDataTrain, myLabelTrain, myName, method = 'id3')
  5. # 生成预剪枝的树
  6. myTreePrePruning = createTreePrePruning(myDataTrain, myLabelTrain, myDataTest, myLabelTest, myName, method = 'id3')
  7. # 画剪枝前的树
  8. print("剪枝前的树")
  9. createPlot(myTreeTrain)
  10. # 画剪枝后的树
  11. print("剪枝后的树")
  12. createPlot(myTreePrePruning)

选调生数据集运行结果 

 

  

可能是数据集的原因导致剪枝后的树和事实有点不符 

  1. # 将西瓜数据2.0分割为测试集和训练集
  2. xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest = splitXgData20(xgData, xgLabel)
  3. # 生成不剪枝的树
  4. xgTreeTrain = createTree(xgDataTrain, xgLabelTrain, xgName, method = 'id3')
  5. # 生成预剪枝的树
  6. xgTreePrePruning = createTreePrePruning(xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest, xgName, method = 'id3')
  7. # 画剪枝前的树
  8. print("剪枝前的树")
  9. createPlot(xgTreeTrain)
  10. # 画剪枝后的树
  11. print("剪枝后的树")
  12. createPlot(xgTreePrePruning)

西瓜数据集运行结果 

 

 

由于特征选择的问题,最后得到的图像和书上的有差异 

三、后剪枝(postpruning) 

3.1概述

先从训练集生成一棵完整的决策树,然后自底向上地对非叶子结点进行分析计算,若将该结点对应的子树替换为叶结点,能带来决策树泛化性能提升,则将该子树替换为叶结点

同样的,我们使用上诉例子来简单说明一下后剪枝,先从训练集生成一棵完整的决策树

 (1)第一步先考察结点6,如果不剪枝,验证集中编号为{4,11,12}的三个样本被正确分类,因此验证集精度为37×100

如果将其替换为叶结点,根据落在其上的训练样本{7,15}将其标记为“好瓜”,进入该分支的验证集样本有{8,9},样本8被正确分类,对整个验证集编号为{4,8,11,12}的四个样本正确分类,因此验证集的精度为47×100,57.1%>42.9%,所以选择剪掉该分支。

 (2)再来考察结点5,如果不剪枝,验证集精度为57.1%

如果将其替换为叶子结点,根据落在其上的训练样本{6,7,15}将其标记为“好瓜”,进入该分支的验证集样本有{8,9},样本8被正确分类,对整个验证集编号为{4,8,11,12}的四个样本正确分类,因此验证集的精度为47×100,57.1%=57.2%,所以不剪枝

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

(3)对结点2,如果不剪枝,验证集精度为57.1%

如果将其替换为叶子结点,根据落在其上的训练样本{1,2,3,14}将其标记为“好瓜”,进入该分支的验证集样本有{4,5,13},样本4,5被正确分类,对整个验证集编号为{4,5,8,11,12}的五个样本正确分类,因此验证集的精度为57×100,71.4%>57.1%,所以选择剪掉该分支。

 

基于 后剪枝策略 得到的最终决策树如图所示

3.2后剪枝优缺点 

优点:后剪枝比预剪枝保留了更多的分支,欠拟合分险小,泛化性能往往优于预剪枝决策树

缺点:训练时间开销大:后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶子结点逐一计算 

3.3代码实现

  1. # 创建决策树 带预划分标签
  2. def createTreeWithLabel(data, labels, names, method = 'id3'):
  3. data = np.asarray(data)
  4. labels = np.asarray(labels)
  5. names = np.asarray(names)
  6. # 如果不划分的标签为
  7. votedLabel = voteLabel(labels)
  8. # 如果结果为单一结果
  9. if len(set(labels)) == 1:
  10. return votedLabel
  11. # 如果没有待分类特征
  12. elif data.size == 0:
  13. return votedLabel
  14. # 其他情况则选取特征
  15. bestFeat, bestEnt = bestFeature(data, labels, method = method)
  16. # 取特征名称
  17. bestFeatName = names[bestFeat]
  18. # 从特征名称列表删除已取得特征名称
  19. names = np.delete(names, [bestFeat])
  20. # 根据选取的特征名称创建树节点 划分前的标签votedPreDivisionLabel=_vpdl
  21. decisionTree = {bestFeatName: {"_vpdl": votedLabel}}
  22. # 根据最优特征进行分割
  23. dataSet, labelSet = splitFeatureData(data, labels, bestFeat)
  24. # 对最优特征的每个特征值所分的数据子集进行计算
  25. for featValue in dataSet.keys():
  26. decisionTree[bestFeatName][featValue] = createTreeWithLabel(dataSet.get(featValue), labelSet.get(featValue), names, method)
  27. return decisionTree
  28. # 将带预划分标签的tree转化为常规的tree
  29. # 函数中进行的copy操作,原因见有道笔记 【YL20190621】关于Python中字典存储修改的思考
  30. def convertTree(labeledTree):
  31. labeledTreeNew = labeledTree.copy()
  32. nodeName = list(labeledTree.keys())[0]
  33. labeledTreeNew[nodeName] = labeledTree[nodeName].copy()
  34. for val in list(labeledTree[nodeName].keys()):
  35. if val == "_vpdl":
  36. labeledTreeNew[nodeName].pop(val)
  37. elif type(labeledTree[nodeName][val]) == dict:
  38. labeledTreeNew[nodeName][val] = convertTree(labeledTree[nodeName][val])
  39. return labeledTreeNew
  40. # 后剪枝 训练完成后决策节点进行替换评估 这里可以直接对xgTreeTrain进行操作
  41. def treePostPruning(labeledTree, dataTest, labelTest, names):
  42. newTree = labeledTree.copy()
  43. dataTest = np.asarray(dataTest)
  44. labelTest = np.asarray(labelTest)
  45. names = np.asarray(names)
  46. # 取决策节点的名称 即特征的名称
  47. featName = list(labeledTree.keys())[0]
  48. # print("\n当前节点:" + featName)
  49. # 取特征的列
  50. featCol = np.argwhere(names==featName)[0][0]
  51. names = np.delete(names, [featCol])
  52. # print("当前节点划分的数据维度:" + str(names))
  53. # print("当前节点划分的数据:" )
  54. # print(dataTest)
  55. # print(labelTest)
  56. # 该特征下所有值的字典
  57. newTree[featName] = labeledTree[featName].copy()
  58. featValueDict = newTree[featName]
  59. featPreLabel = featValueDict.pop("_vpdl")
  60. # print("当前节点预划分标签:" + featPreLabel)
  61. # 是否为子树的标记
  62. subTreeFlag = 0
  63. # 分割测试数据 如果有数据 则进行测试或递归调用 np的array我不知道怎么判断是否None, 用is None是错的
  64. dataFlag = 1 if sum(dataTest.shape) > 0 else 0
  65. if dataFlag == 1:
  66. # print("当前节点有划分数据!")
  67. dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
  68. for featValue in featValueDict.keys():
  69. # print("当前节点属性 {0} 的子节点:{1}".format(featValue ,str(featValueDict[featValue])))
  70. if dataFlag == 1 and type(featValueDict[featValue]) == dict:
  71. subTreeFlag = 1
  72. # 如果是子树则递归
  73. newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue), labelTestSet.get(featValue), names)
  74. # 如果递归后为叶子 则后续进行评估
  75. if type(featValueDict[featValue]) != dict:
  76. subTreeFlag = 0
  77. # 如果没有数据 则转换子树
  78. if dataFlag == 0 and type(featValueDict[featValue]) == dict:
  79. subTreeFlag = 1
  80. # print("当前节点无划分数据!直接转换树:"+str(featValueDict[featValue]))
  81. newTree[featName][featValue] = convertTree(featValueDict[featValue])
  82. # print("转换结果:" + str(convertTree(featValueDict[featValue])))
  83. # 如果全为叶子节点, 评估需要划分前的标签,这里思考两种方法,
  84. # 一是,不改变原来的训练函数,评估时使用训练数据对划分前的节点标签重新打标
  85. # 二是,改进训练函数,在训练的同时为每个节点增加划分前的标签,这样可以保证评估时只使用测试数据,避免再次使用大量的训练数据
  86. # 这里考虑第二种方法 写新的函数 createTreeWithLabel,当然也可以修改createTree来添加参数实现
  87. if subTreeFlag == 0:
  88. ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
  89. equalNum = 0
  90. for val in labelTestSet.keys():
  91. equalNum += equalNums(labelTestSet[val], featValueDict[val])
  92. ratioAfterDivision = equalNum / labelTest.size
  93. # print("当前节点预划分标签的准确率:" + str(ratioPreDivision))
  94. # print("当前节点划分后的准确率:" + str(ratioAfterDivision))
  95. # 如果划分后的测试数据准确率低于划分前的,则划分无效,进行剪枝,即使节点等于预划分标签
  96. # 注意这里取的是小于,如果有需要 也可以取 小于等于
  97. if ratioAfterDivision < ratioPreDivision:
  98. newTree = featPreLabel
  99. return newTree

 

  1. #选调生数据集的树结构
  2. myTreeBeforePostPruning ={'生源地': {'_vpdl': '否', '贫困县': '是', '山区': {'学习成绩':
  3. {'_vpdl': '否', '优秀': {'政治面貌': {'_vpdl': '否', '党员': '是', '团员': '否'}}, '良好': '是', '及格': '否'}}, '沿海':
  4. {'学习成绩': {'_vpdl': '否', '优秀': {'政治面貌': {'_vpdl': '否', '党员': '是', '团员': '否'}}, '良好': '否', '及格': '否'}}}}
  5. myTreePostPruning = treePostPruning(myTreeBeforePostPruning, myDataTest, myLabelTest, myName)
  6. createPlot(convertTree(myTreeBeforePostPruning))
  7. createPlot(myTreePostPruning)

运行结果 

 

  1. xgTreeBeforePostPruning = {"脐部": {"_vpdl": "是"
  2. , '凹陷': {'色泽':{"_vpdl": "是", '青绿': '是', '乌黑': '是', '浅白': '否'}}
  3. , '稍凹': {'根蒂':{"_vpdl": "是"
  4. , '稍蜷': {'色泽': {"_vpdl": "是"
  5. , '青绿': '是'
  6. , '乌黑': {'纹理': {"_vpdl": "是"
  7. , '稍糊': '是', '清晰': '否', '模糊': '是'}}
  8. , '浅白': '是'}}
  9. , '蜷缩': '否'
  10. , '硬挺': '是'}}
  11. , '平坦': '否'}}
  12. xgTreePostPruning = treePostPruning(xgTreeBeforePostPruning, xgDataTest, xgLabelTest, xgName)
  13. createPlot(convertTree(xgTreeBeforePostPruning))
  14. createPlot(xgTreePostPruning)

运行结果 

 

 

完整代码

链接: https://pan.baidu.com/s/1jBL03BtDhD0_LOMHSZHluw?pwd=b95u 提取码: b95u 

 

 

 

 

 

 

 

  

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/320888
推荐阅读
相关标签
  

闽ICP备14008679号