当前位置:   article > 正文

PYTHON机器学习实战——决策树DT_决策树 dt算法python

决策树 dt算法python

  决策树也是有监督机器学习方法。

决策树算法是找到一个优化的决策路径(决策树),使得每次分类尽可能过滤更多的数据,或者说问的问题尽量少。
决策树算法可以用来优化一些知识系统,帮助用户快速找到答案。

基本概念

  • 属性(Feature): 训练数据中每列都是一个属性。
  • 标签(Label):训练数据中的分类结果。

如何构造决策树

这里,要解决的问题是采用哪些数据属性作为分类条件,最佳次序是什么?

  • 方法一:采用二分法,或者按照训练数据中的属性依次构造。
  • 方法二:使用香农熵计算公式。这是书中使用的方法。
  • 方法三:使用基尼不纯度2 (Gini impurity)。
  • 流行的算法: C4.5和CART

代码详解

  1. # -*- coding:utf-8 -*-
  2. #!/usr/bin/python
  3. # 测试 import DecTree as DT DT.test_dt()
  4. from math import log # 对数
  5. import operator # 操作符
  6. import copy # 列表复制,不改变原来的列表
  7. # 画树
  8. import plot_deci_tree as pt
  9. ## 自定义数据集 来进行测试
  10. def createDataSet():
  11. dataSet = [[1, 1, 'yes'],
  12. [1, 1, 'yes'],
  13. [1, 0, 'no'],
  14. [0, 1, 'no'],
  15. [0, 1, 'no']]
  16. labels = ['no surfacing','flippers'] # 属性值列表
  17. #change to discrete values
  18. return dataSet, labels
  19. ## 计算给定数据集的熵(混乱度) sum(概率 * (- log2(概率)))
  20. def calcShannonEnt(dataSet):
  21. numEntries = len(dataSet) # 数据集 样本总数
  22. labelCounts = {} # 标签集合 字典 标签 对应 出现的次数
  23. # 计算 类标签 各自出现的次数
  24. for featVec in dataSet: # 每个样本
  25. currentLabel = featVec[-1] # 每个样本的最后一列为标签
  26. # 若当前标签值不存在,则扩展字典,加入当前关键字
  27. if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
  28. labelCounts[currentLabel] += 1 # 对应标签 计数 + 1
  29. # 计算信息熵
  30. shannonEnt = 0.0
  31. for key in labelCounts:
  32. prob = float(labelCounts[key])/numEntries # 计算每类出现的概率
  33. shannonEnt -= prob * log(prob,2) # 计算信息熵
  34. return shannonEnt
  35. ## 按照给定特征划分数据集 (取 划分属性值列 的剩余 部分)
  36. def splitDataSet(dataSet, axis, value):
  37. # dataSet 数据集 axis划分的属性(哪一列) 对应属性(axis)值value 的那一行要去掉
  38. retDataSet = [] # 划分好的数据集
  39. for featVec in dataSet: # 每一行 即一个 样本
  40. if featVec[axis] == value: # 选取 符合 对应属性值 的列
  41. reducedFeatVec = featVec[:axis] # 对应属性值 之前的其他属性值
  42. reducedFeatVec.extend(featVec[axis+1:]) # 加入 对应属性值 之前后的其他属性值 成一个 列表
  43. retDataSet.append(reducedFeatVec) # 将其他部分 加入新 的列表里
  44. return retDataSet
  45. ## 选取最好的 划分属性值
  46. def chooseBestFeatureToSplit(dataSet):
  47. numFeatures = len(dataSet[0]) - 1 # 总特征维度, 最后一列是标签
  48. baseEntropy = calcShannonEnt(dataSet) # 计算原来数据集的 信息熵
  49. bestInfoGain = 0.0; bestFeature = -1 # 信息增益 和 最优划分属性初始化
  50. for i in range(numFeatures): # 对于所有的特征(每一列特征对应一个 属性,即对已每一个属性)
  51. featList = [example[i] for example in dataSet] # 列表推导 所有 对应 特征属性
  52. uniqueVals = set(featList) # 从列表创建集合 得到每(列)个属性值的集合 用于划分集合
  53. newEntropy = 0.0
  54. for value in uniqueVals: # 对于该属性 的每个 属性值
  55. subDataSet = splitDataSet(dataSet, i, value) # 选取对应属性对应属性值 的新集合
  56. prob = len(subDataSet)/float(len(dataSet)) # 计算 该属性下该属性值的样本所占总样本数 的比例
  57. newEntropy += prob * calcShannonEnt(subDataSet) # 比例 * 对于子集合的信息熵,求和得到总信息熵
  58. infoGain = baseEntropy - newEntropy # 原始集合信息熵 - 新划分子集信息熵和 得到信息增益
  59. if (infoGain > bestInfoGain): # 信息熵 比划分前 减小了吗? 减小的话 (信息增益增大)
  60. bestInfoGain = infoGain # 更新最优 信息熵
  61. bestFeature = i # 记录当前最优 的划分属性
  62. return bestFeature # 返回全局最有的划分属性
  63. # 统计样本集合 类出现的次数 返回出现最多的 分类名称
  64. def majorityCnt(classList):
  65. classCount={}
  66. for vote in classList: # 每一个样本
  67. if vote not in classCount.keys(): classCount[vote] = 0 # 增加类标签到字典中
  68. classCount[vote] += 1 # 统计次数
  69. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)# 按类出现次数 从大到小排序
  70. return sortedClassCount[0][0] # 返回出现次数最多的
  71. # 输入数据集 和 属性标签 生成 决策树
  72. def createTree(dataSet,labels):
  73. #copy_labels = labels
  74. classList = [example[-1] for example in dataSet] # 每个样本的分类标签
  75. # 终止条件1 所有类标签完全相同
  76. if classList.count(classList[0]) == len(classList):
  77. return classList[0] # 返回该类标签(分类属性)
  78. # 终止条件2 遍历完所有特征
  79. if len(dataSet[0]) == 1:
  80. return majorityCnt(classList) # 返回出现概率最大的类标签
  81. # 选择最好的划分属性(划分特征)
  82. bestFeat = chooseBestFeatureToSplit(dataSet)
  83. # 对于特征值(属性值)的特征(属性)
  84. bestFeatLabel = labels[bestFeat]
  85. # 初始化树
  86. myTree = {bestFeatLabel:{}} # 树的形状 分类属性:子树
  87. del(labels[bestFeat]) # 有问题 改变了 原来属性序列
  88. # 根据最优的划分属性 的 值列表 创建子树
  89. featValues = [example[bestFeat] for example in dataSet]
  90. uniqueVals = set(featValues)
  91. for value in uniqueVals: # 最优的划分属性 的 值列表
  92. subLabels = labels[:] # 每个子树的 子属性标签
  93. # 递归调用 生成决策树
  94. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
  95. return myTree
  96. # 使用训练好的决策树做识别分类
  97. def classify(inputTree,featLabels,testVec):
  98. firstStr = inputTree.keys()[0] # 起始分类属性 节点
  99. secondDict = inputTree[firstStr] # 子树
  100. featIndex = featLabels.index(firstStr) # 属性标签索引
  101. # key = testVec[featIndex] # 测试属性值向量 起始分类属性 对应 的属性
  102. # valueOfFeat = secondDict[key] # 对应的子树 或 叶子节点
  103. # if isinstance(valueOfFeat, dict): # 子树还是 树
  104. # classLabel = classify(valueOfFeat, featLabels, testVec) #递归调用
  105. # else: classLabel = valueOfFeat # 将叶子节点的标签赋予 类标签 输出
  106. for key in secondDict.keys():
  107. if testVec[featIndex] == key:
  108. if type(secondDict[key]).__name__ =='dict': # 子树还是 树
  109. classLabel = classify(secondDict[key], featLabels, testVec) #递归调用
  110. else: classLabel = secondDict[key] # 将叶子节点的标签赋予 类标签 输出
  111. return classLabel
  112. # 使用 pickle 对象 存储决策数
  113. def storeTree(inputTree,filename):
  114. import pickle
  115. fw = open(filename,'w')
  116. pickle.dump(inputTree,fw)
  117. fw.close()
  118. # 使用 pickle 对象 载入决策树
  119. def grabTree(filename):
  120. import pickle
  121. fr = open(filename)
  122. return pickle.load(fr)
  123. # 使用佩戴眼镜数据测试
  124. def test_dt():
  125. print '载入数据 lenses.txt ...'
  126. fr = open('lenses.txt')
  127. lenses = [inst.strip().split('\t') for inst in fr.readlines()]
  128. lenses_lab = ['age','prescript','astigmatic','teatRate']
  129. print '创建 lenses 决策数...'
  130. lenses_tree = createTree(lenses,lenses_lab)
  131. print lenses_tree
  132. print '保存 lenses 决策数...'
  133. storeTree(lenses_tree,'lenses_tree.txt')
  134. print '可视化 lenses 决策数...'
  135. pt.createPlot(lenses_tree)
决策数可视化代码:
plot_deci_tree.py
  1. # -*- coding:utf-8 -*-
  2. #!/usr/bin/python
  3. '''
  4. 绘制决策树生成的树类型字典
  5. '''
  6. import matplotlib.pyplot as plt
  7. # 文本框类型 和 箭头类型
  8. decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策节点 文本框类型 花边矩形
  9. leafNode = dict(boxstyle="round4", fc="0.8") # 叶子节点 文本框类型 倒角矩形
  10. arrow_args = dict(arrowstyle="<-") # 箭头类型
  11. # 得到 叶子节点总树,用以确定 图的横轴长度
  12. def getNumLeafs(myTree):
  13. # myTree = {bestFeatLabel:{}} # 树的形状 分类属性:子树 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
  14. numLeafs = 0
  15. firstStr = myTree.keys()[0] # 开始 节点 分类属性
  16. secondDict = myTree[firstStr] # 对应后面的子节点(子字典)
  17. for key in secondDict.keys():
  18. if type(secondDict[key]).__name__=='dict': # 子节点 是字典 (孩子节点) 循环调用
  19. numLeafs += getNumLeafs(secondDict[key]) #
  20. else: numLeafs +=1 # 子节点的 如果不是 字典(孩子节点),就是叶子节点
  21. return numLeafs
  22. # 得到 树的深度(层数),用以确定 图的纵轴长度
  23. def getTreeDepth(myTree):
  24. maxDepth = 0
  25. firstStr = myTree.keys()[0] # 开始 节点 分类属性
  26. secondDict = myTree[firstStr] # 对应后面的子节点(子字典)
  27. for key in secondDict.keys():
  28. if type(secondDict[key]).__name__=='dict': # 子节点 是字典 (孩子节点) 循环调用
  29. thisDepth = 1 + getTreeDepth(secondDict[key]) #
  30. else: thisDepth = 1 # 深度为1
  31. if thisDepth > maxDepth: maxDepth = thisDepth
  32. return maxDepth
  33. # 绘制节点 带箭头的注释 框内文本 起点 终点 框类型
  34. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  35. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  36. xytext=centerPt, textcoords='axes fraction',
  37. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
  38. # 在 父子节点中间连线的线上 添加文本信息(属性分类值)
  39. # 起点 终点 文本信息
  40. def plotMidText(cntrPt, parentPt, txtString):
  41. # 计算中点位置(文本放置的位置)
  42. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  43. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  44. createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) # 偏转角度 初始位置为90 度
  45. # 画树
  46. def plotTree(myTree, parentPt, nodeTxt):# if the first key tells you what feat was split on
  47. numLeafs = getNumLeafs(myTree) # 树的宽度(图的横轴坐标,叶子节点的数量)
  48. depth = getTreeDepth(myTree) # 树的高度(图的纵坐标, 树的层数)
  49. firstStr = myTree.keys()[0] # 父节点,开始节点信息(分裂属性)
  50. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)# 起点
  51. plotMidText(cntrPt, parentPt, nodeTxt) # 线上注释
  52. plotNode(firstStr, cntrPt, parentPt, decisionNode) # 画分类属性节点 和 箭头
  53. secondDict = myTree[firstStr] # 后面的子树,子字典
  54. plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
  55. for key in secondDict.keys():
  56. if type(secondDict[key]).__name__=='dict': # 子字典内还有字典
  57. plotTree(secondDict[key],cntrPt,str(key)) # 递归调用画子树
  58. else: # 画叶子节点
  59. plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
  60. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  61. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  62. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
  63. #if you do get a dictonary you know it's a tree, and the first element will be another dict
  64. # 画树
  65. def createPlot(inTree):
  66. fig = plt.figure(1, facecolor='white') # 图1 背景白色
  67. fig.clf() # 清空显示
  68. axprops = dict(xticks=[], yticks=[]) # 标尺
  69. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # no ticks
  70. #createPlot.ax1 = plt.subplot(111, frameon=False) # ticks for demo puropses
  71. plotTree.totalW = float(getNumLeafs(inTree)) # 叶子节点总数(以便确定图的横轴长度)
  72. plotTree.totalD = float(getTreeDepth(inTree)) # 树的深度(层树)(以便确定 纵轴长度)
  73. plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
  74. plotTree(inTree, (0.5,1.0), '')
  75. plt.show()
  76. # 带箭头的含有文本框的 图
  77. def plot_tree_demo():
  78. fig = plt.figure(1, facecolor='white')
  79. fig.clf()
  80. createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
  81. plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
  82. plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
  83. plt.show()
  84. # 事先存储一个 树信息
  85. def retrieveTree(i):
  86. listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
  87. {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
  88. ]
  89. return listOfTrees[i]
  90. #createPlot(thisTree)





声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/847466
推荐阅读
相关标签
  

闽ICP备14008679号