当前位置:   article > 正文

机器学习算法系列之决策树_id3 年龄有工作有房子信用类别

id3 年龄有工作有房子信用类别

本系列机器学习的文章打算从机器学习算法的一些理论知识、python实现该算法和调一些该算法的相应包来实现。

目录

一、决策树原理

1、决策树的模型与学习

二、决策树基于python实现

三、基于sklearn实现


一、决策树原理

1、决策树的模型与学习

决策树简介:决策树是一种典型的分类方法。首先对数据进行处理,利用归纳算法生成可读的规则和决策树,然后使用决策树对新数据进行分析。本质上决策树是通过一系列的规则对数据进行分类的过程。

决策树的优点:1、推理过程容易理解,决策过程可以表示为if-then形式;2、推理过程完全依赖于属性变量的取值特点;3、可自动忽略目标变量没有贡献的属性变量,也为判断属性变量的重要性,减少变量的数目提供参考。

1、决策树算法

于决策树算法相关的算法包括:CLS,ID3,C4.5,CART。

决策树的基本组成部分:决策结点、分支和叶子。

决策树中最上面的结点称为根节点,是整个决策树的开始。每个分支是一个新的决策结点,或者是树的叶子。每个决策结点代表一个问题或者决策,通常对应待分类的属性。每个叶子结点代表一种可能分类的结果。

一颗决策是否购买电脑的决策树

在沿着决策树从上到下遍历的过程中,在每个结点都有一个测试。对每个结点上问题的不同测试输出导致 不同的分支,最后会到达一个叶子结点。这一过程就是利用决策树进行分类的过程,利用若干个变量来判断属性的类别。

决策树学习的本质上是从训练数据集中归纳出一组分类规则,于训练数据集不矛盾的决策树, 而能对训练数据集能够正确分类的决策树可能有多个,也可能是一个都没有,我们需要的是一个于训练数据矛盾较小的决策树,同时具有很好的泛化能力。我们最终选择的条件概率模型不仅对训练数据能够有很好的预测,而去对未知数据也能有很好的预测。

2、信息增益

熵(entropy): 信息量大小的度量, 即表示随机变量不确定性的度量。

熵的通俗解释:事件ai的信息量I( ai )可如下度量:其中p(ai)表示事件ai发生的概率。

假设有n个互不相容的事件a1,a2,a3,….,an,它们中有且仅有一个发生, 则其平均的信息量(熵)可如下度量:

熵的理论解释:设X是一个取有限个值的离散随机变量, 其概率分布为:则随机变量X的熵定义为:

熵越大,随机变量的不确定性越大。举例:当为0,1分布时,如果pi=1,那么H(x)=0,可以明显看出熵值为0,因为随机变量x已经确定了为1。

条件熵H(Y|X): 表示在己知随机变量X的条件下随机变量Y的不确定性, 定义为X给定条件下Y的条件概率分布的熵对X的数学期望:

当熵和条件熵中的概率由数据估计(特别是极大似然估计)得到时, 所对应的熵与条件熵分别称为经验熵(empirical entropy)和经验条件熵(empirical conditional entropy )

 

信息增益定义:特征A对训练数据集D的信息增益,g(D,A), 定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差, 即:g(D,A)=H(D)-H(D|A)

表示得知特征X的信息而使的类Y的信息的不确定性减少的程度。—般地, 熵H(Y)与条件熵H(Y|X)之差称为互信息(mutual information)
决策树学习中的信息增益等价于训练数据集中类与特征的互信息。

信息增益算法流程:

输入: 训练数据集D和特征A;
输出: 特征A对训练数据集D的信息增益g(D,A)
1、 计算数据集D的经验熵H(D)
2、 计算特征A对数据集D的经验条件熵H(D|A)
3、 计算信息增益

这里为什么要引入信息增益的概念呢? 因为后面我们在用决策树算法构建一颗决策树时,需要利用信息增益来确定我们分裂结点优先选用哪个特征来进行分类,因为分类特征选取的顺序不同也会导致决策树的划分不同。

3、决策树的生成

ID3算法:ID3算法是一种经典的决策树学习算法, 由Quinlan于1979年提出。ID3算法主要针对属性选择问题。 是决策树学习方法中最具影响和最为典型的算法。该方法使用信息增益度选择测试属性。当获取信息时, 将不确定的内容转为确定的内容, 因此信息伴着不确定性。从直觉上讲, 小概率事件比大概率事件包含的信息量大。 如果某件事情是“百年一见” 则肯定比“习以为常” 的事件包含的信息量大。
在决策树分类中, 假设S是训练样本集合, |S|是训练样本数, 样本划分为n个不同的类C1,C2,….Cn, 这些类的大小分别标记为|C1|, |C2|, …..,|Cn|。 

 

ID3算法计算流程:

1 决定分类属性;
2 对目前的数据表, 建立一个节点N
3 如果数据库中的数据都属于同一个类, N就是树叶, 在树叶上标出所属的类
4 如果数据表中没有其他属性可以考虑, 则N也是树叶, 按照少数服从多数的原则在树叶上标出所属类别
5 否则, 根据平均信息期望值E或GAIN值选出一个最佳属性作为节点N的测试属性
6 节点属性选定后, 对于该属性中的每个值:
从N生成一个分支, 并将数据表中与该分支有关的数据收集形
成分支节点的数据表, 在表中删除节点属性那一栏如果分支数
据表非空, 则运用以上算法从该节点建立子树。

假设我们得到一批数据:

我们用这些数据来讲解构建一个决策树的流程:

第1步:计算决策属性的熵决策属性“买计算机”/结果:买/不买)

统计可以得到 |C1|(买)=641,|C2|(不买)=383    |D|=|C1|+|C2|=1024   可以计算出某类占总数的概率:p1=641/1021=0.6260  p2=0.3740。 这时候可以计带入经验熵计算公式:算出数据集的经验熵:

 

第二步:计算条件属性的熵

从上表中可以看出条件属性一共有四个:年龄、收入、学生、信誉。分别计算不同属性的信息增益。

条件熵计算公式

这里只举一个年龄属性的计算方法(其他属性计算方法一样。)

年龄共分三组:青年、中年、老年。

青年买与不买比例为128/256  |D11|(青年人买)=128,|D12|(青年不买)=256  |D1|=384, p1=128/384  p2=256/384

带入公式可以计算出:H(D1)=0.9138。

中年买与不买比例为256/0  |D21|(中年人买)=256,|D22|(中年不买)=0  |D2|=256, p1=256/256  p2=0/256

同理带入熵计算公式得出:H(D2)=0

老年买与不买比例为125/127  |D31|(老年人买)=125,|D32|(老年不买)=127  |D3|=252, p1=125/252  p2=127/252

同理带入熵计算公式得出:H(D3)= 0.9157

各类所占比例:青年:384/1024=0.375;中年:256/1024=0.25;老年:384/1024=0.375

计算年龄的平均信息期望:E(年龄)=0.375*0.9138+0.25*0+0.375*0.9157=0.6877;  G(年龄信息增益)=0.9537-0.6877=0.2660

同样的方法计算出其他属性的增益信息为:

收入信息增益=0.9537-0.9361=0.0176 
学生信息增益=0.9537-0.7811=0.1726 
信誉信息增益=0.9537-0.9048=0.0453 

可以看到年龄的信息增益值最大,所以年龄这个分类属性首先就被归类为分裂结点。

在接下来的分裂过程中同样使用这样的办法,以此来确定最终的决策树。

总结:ID3算法的基本思想是, 以信息熵为度量, 用于决策树节点的属性选择, 每次优先选取信息量最多的属性, 亦即能使熵值变为最小的属性, 以构造一颗熵值下降最快的决策树, 到叶子节点处的熵值为0。 此时, 每个叶子节点对应的实例集中的实例属于同一类。

4、决策树的剪枝

决策树剪枝通过极小化决策树整体函数和损失函数来实现。

设树T的叶结点个数为|T|,t是树T的叶结点, 该叶结点有Nt个样本点, 其中k类的样本点有Ntk个, k=1,2..K,Ht(T)为叶结点t上的经验熵, α≥0为参数, 损失函数:

最终:

树的剪枝算法:

输入:生成算法产生的整个树T,参数alpha;

输出:修建后的子树T(α)

(1)计算每个结点的经验熵

(2)递归地从树的结点向上回缩,设一组叶结点回缩到其父结点之前与之和的损失函数分
别为:如果: 则进行剪枝

(3)返回(2),直至不能继续为止,得到损失函数最小地子树。

 


二、决策树基于python实现

 

  1. # -*- coding: utf-8 -*-
  2. """
  3. 用字典存储决策树结构:
  4. {'有自己的房子':{0:{'有工作':{0:'no', 1:'yes'}}, 1:'yes'}}
  5. 年龄:0代表青年,1代表中年,2代表老年
  6. 有工作:0代表否,1代表是
  7. 有自己的房子:0代表否,1代表是
  8. 信贷情况:0代表一般,1代表好,2代表非常好
  9. 类别(是否给贷款):no代表否,yes代表是
  10. pickle包可以将决策树保存下来,方便下次直接调用
  11. """
  12. from matplotlib.font_manager import FontProperties
  13. import matplotlib.pyplot as plt
  14. from math import log
  15. import operator
  16. import pickle
  17. """
  18. 函数说明:创建测试数据集
  19. Parameters:
  20. None
  21. Returns:
  22. dataSet - 数据集
  23. labels - 分类属性
  24. """
  25. def createDataSet():
  26. # 数据集
  27. dataSet = [[0, 0, 0, 0, 'no'],
  28. [0, 0, 0, 1, 'no'],
  29. [0, 1, 0, 1, 'yes'],
  30. [0, 1, 1, 0, 'yes'],
  31. [0, 0, 0, 0, 'no'],
  32. [1, 0, 0, 0, 'no'],
  33. [1, 0, 0, 1, 'no'],
  34. [1, 1, 1, 1, 'yes'],
  35. [1, 0, 1, 2, 'yes'],
  36. [1, 0, 1, 2, 'yes'],
  37. [2, 0, 1, 2, 'yes'],
  38. [2, 0, 1, 1, 'yes'],
  39. [2, 1, 0, 1, 'yes'],
  40. [2, 1, 0, 2, 'yes'],
  41. [2, 0, 0, 0, 'no']]
  42. # 分类属性
  43. labels = ['年龄', '有工作', '有自己的房子', '信贷情况']
  44. # 返回数据集和分类属性
  45. return dataSet, labels
  46. """
  47. 函数说明:计算给定数据集的经验熵(香农熵)
  48. Ent(D) = -SUM(kp*Log2(kp))
  49. Parameters:
  50. dataSet - 数据集
  51. Returns:
  52. shannonEnt - 经验熵(香农熵)
  53. """
  54. def calcShannonEnt(dataSet):
  55. # 返回数据集的行数
  56. numEntires = len(dataSet)
  57. # 保存每个标签(Label)出现次数的“字典”
  58. labelCounts = {}
  59. # 对每组特征向量进行统计
  60. for featVec in dataSet:
  61. # 提取标签(Label)信息
  62. currentLabel = featVec[-1]
  63. # 如果标签(Label)没有放入统计次数的字典,添加进去
  64. if currentLabel not in labelCounts.keys():
  65. # 创建一个新的键值对,键为currentLabel值为0
  66. labelCounts[currentLabel] = 0
  67. # Label计数
  68. labelCounts[currentLabel] += 1
  69. # 经验熵(香农熵)
  70. shannonEnt = 0.0
  71. # 计算香农熵
  72. for key in labelCounts:
  73. # 选择该标签(Label)的概率
  74. prob = float(labelCounts[key]) / numEntires
  75. # 利用公式计算
  76. shannonEnt -= prob*log(prob, 2)
  77. # 返回经验熵(香农熵)
  78. return shannonEnt
  79. """
  80. 函数说明:按照给定特征划分数据集
  81. Parameters:
  82. dataSet - 待划分的数据集
  83. axis - 划分数据集的特征
  84. values - 需要返回的特征的值
  85. Returns:
  86. None
  87. """
  88. def splitDataSet(dataSet, axis, value):
  89. # 创建返回的数据集列表
  90. retDataSet = []
  91. # 遍历数据集的每一行
  92. for featVec in dataSet:
  93. if featVec[axis] == value:
  94. # 去掉axis特征
  95. reducedFeatVec = featVec[:axis]
  96. # 将符合条件的添加到返回的数据集
  97. # extend() 函数用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)。
  98. reducedFeatVec.extend(featVec[axis+1:])
  99. # 列表中嵌套列表
  100. retDataSet.append(reducedFeatVec)
  101. # 返回划分后的数据集
  102. return retDataSet
  103. """
  104. 函数说明:选择最优特征
  105. Gain(D,g) = Ent(D) - SUM(|Dv|/|D|)*Ent(Dv)
  106. Parameters:
  107. dataSet - 数据集
  108. Returns:
  109. bestFeature - 信息增益最大的(最优)特征的索引值
  110. """
  111. def chooseBestFeatureToSplit(dataSet):
  112. # 特征数量
  113. numFeatures = len(dataSet[0]) - 1
  114. # 计算数据集的香农熵
  115. baseEntropy = calcShannonEnt(dataSet)
  116. # 信息增益
  117. bestInfoGain = 0.0
  118. # 最优特征的索引值
  119. bestFeature = -1
  120. # 遍历所有特征
  121. for i in range(numFeatures):
  122. # 获取dataSet的第i个所有特征存在featList这个列表中(列表生成式)
  123. featList = [example[i] for example in dataSet]
  124. # 创建set集合{},元素不可重复,重复的元素均被删掉
  125. # 从列表中创建集合是python语言得到列表中唯一元素值得最快方法
  126. uniqueVals = set(featList)
  127. # 经验条件熵
  128. newEntropy = 0.0
  129. # 计算信息增益
  130. for value in uniqueVals:
  131. # subDataSet划分后的子集
  132. subDataSet = splitDataSet(dataSet, i, value)
  133. # 计算子集的概率
  134. prob = len(subDataSet) / float(len(dataSet))
  135. # 根据公式计算经验条件熵
  136. newEntropy += prob * calcShannonEnt(subDataSet)
  137. # 信息增益
  138. infoGain = baseEntropy - newEntropy
  139. # 打印每个特征的信息增益
  140. print("第%d个特征的增益为%.3f" % (i, infoGain))
  141. # 计算信息增益
  142. if(infoGain > bestInfoGain):
  143. # 更新信息增益,找到最大的信息增益
  144. bestInfoGain = infoGain
  145. # 记录信息增益最大的特征的索引值
  146. bestFeature = i
  147. # 返回信息增益最大的特征的索引值
  148. return bestFeature
  149. """
  150. 函数说明:统计classList中出现次数最多的元素(类标签)
  151. 服务于递归第两个终止条件
  152. Parameters:
  153. classList - 类标签列表
  154. Returns:
  155. sortedClassCount[0][0] - 出现次数最多的元素(类标签)
  156. """
  157. def majorityCnt(classList):
  158. classCount = {}
  159. # 统计classList中每个元素出现的次数
  160. for vote in classList:
  161. if vote not in classCount.keys():
  162. classCount[vote] = 0
  163. classCount[vote] += 1
  164. # 根据字典的值降序排序
  165. # operator.itemgetter(1)获取对象的第1列的值
  166. sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
  167. # 返回classList中出现次数最多的元素
  168. return sortedClassCount[0][0]
  169. """
  170. 函数说明:创建决策树(ID3算法)
  171. 递归有两个终止条件:1、所有的类标签完全相同,直接返回类标签
  172. 2、用完所有标签但是得不到唯一类别的分组,即特征不够用,挑选出现数量最多的类别作为返回
  173. Parameters:
  174. dataSet - 训练数据集
  175. labels - 分类属性标签
  176. featLabels - 存储选择的最优特征标签
  177. Returns:
  178. myTree - 决策树
  179. """
  180. def createTree(dataSet, labels, featLabels):
  181. # 取分类标签(是否放贷:yes or no
  182. classList = [example[-1] for example in dataSet]
  183. # 如果类别完全相同则停止继续划分
  184. if classList.count(classList[0]) == len(classList):
  185. return classList[0]
  186. # 遍历完所有特征时返回出现次数最多的类标签
  187. if len(dataSet[0]) == 1:
  188. return majorityCnt(classList)
  189. # 选择最优特征
  190. bestFeat = chooseBestFeatureToSplit(dataSet)
  191. # 最优特征的标签
  192. bestFeatLabel = labels[bestFeat]
  193. featLabels.append(bestFeatLabel)
  194. # 根据最优特征的标签生成树
  195. myTree = {bestFeatLabel:{}}
  196. # 删除已经使用的特征标签
  197. del(labels[bestFeat])
  198. # 得到训练集中所有最优解特征的属性值
  199. featValues = [example[bestFeat] for example in dataSet]
  200. # 去掉重复的属性值
  201. uniqueVals = set(featValues)
  202. # 遍历特征,创建决策树
  203. for value in uniqueVals:
  204. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
  205. return myTree
  206. """
  207. 函数说明:获取决策树叶子结点的数目
  208. Parameters:
  209. myTree - 决策树
  210. Returns:
  211. numLeafs - 决策树的叶子结点的数目
  212. """
  213. def getNumLeafs(myTree):
  214. # 初始化叶子
  215. numLeafs = 0
  216. # python3中myTree.keys()返回的是dict_keys,不是list,所以不能用
  217. # myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
  218. # next() 返回迭代器的下一个项目 next(iterator[, default])
  219. firstStr = next(iter(myTree))
  220. # 获取下一组字典
  221. secondDict = myTree[firstStr]
  222. for key in secondDict.keys():
  223. # 测试该结点是否为字典,如果不是字典,代表此节点为叶子结点
  224. if type(secondDict[key]).__name__ == 'dict':
  225. numLeafs += getNumLeafs(secondDict[key])
  226. else:
  227. numLeafs += 1
  228. return numLeafs
  229. """
  230. 函数说明:获取决策树的层数
  231. Parameters:
  232. myTree - 决策树
  233. Returns:
  234. maxDepth - 决策树的层数
  235. """
  236. def getTreeDepth(myTree):
  237. # 初始化决策树深度
  238. maxDepth = 0
  239. # python3中myTree.keys()返回的是dict_keys,不是list,所以不能用
  240. # myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
  241. # next() 返回迭代器的下一个项目 next(iterator[, default])
  242. firstStr = next(iter(myTree))
  243. # 获取下一个字典
  244. secondDict = myTree[firstStr]
  245. for key in secondDict.keys():
  246. # 测试该结点是否为字典,如果不是字典,代表此节点为叶子结点
  247. if type(secondDict[key]).__name__ == 'dict':
  248. thisDepth = 1 + getTreeDepth(secondDict[key])
  249. else:
  250. thisDepth = 1
  251. # 更新最深层数
  252. if thisDepth > maxDepth:
  253. maxDepth = thisDepth
  254. # 返回决策树的层数
  255. return maxDepth
  256. """
  257. 函数说明:绘制结点
  258. Parameters:
  259. nodeTxt - 结点名
  260. centerPt - 文本位置
  261. parentPt - 标注的箭头位置
  262. nodeType - 结点格式
  263. Returns:
  264. None
  265. """
  266. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  267. # 定义箭头格式
  268. arrow_args = dict(arrowstyle="<-")
  269. # 设置中文字体
  270. font = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=14)
  271. # 绘制结点createPlot.ax1创建绘图区
  272. # annotate是关于一个数据点的文本
  273. # nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
  274. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
  275. xytext=centerPt, textcoords='axes fraction',
  276. va='center', ha='center', bbox=nodeType,
  277. arrowprops=arrow_args, FontProperties=font)
  278. """
  279. 函数说明:标注有向边属性值
  280. Parameters:
  281. cntrPt、parentPt - 用于计算标注位置
  282. txtString - 标注内容
  283. Returns:
  284. None
  285. """
  286. def plotMidText(cntrPt, parentPt, txtString):
  287. # 计算标注位置(箭头起始位置的中点处)
  288. xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
  289. yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
  290. createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
  291. """
  292. 函数说明:绘制决策树
  293. Parameters:
  294. myTree - 决策树(字典)
  295. parentPt - 标注的内容
  296. nodeTxt - 结点名
  297. Returns:
  298. None
  299. """
  300. def plotTree(myTree, parentPt, nodeTxt):
  301. # 设置结点格式boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
  302. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  303. # 设置叶结点格式
  304. leafNode = dict(boxstyle="round4", fc="0.8")
  305. # 获取决策树叶结点数目,决定了树的宽度
  306. numLeafs = getNumLeafs(myTree)
  307. # 获取决策树层数
  308. depth = getTreeDepth(myTree)
  309. # 下个字典
  310. firstStr = next(iter(myTree))
  311. # 中心位置
  312. cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
  313. # 标注有向边属性值
  314. plotMidText(cntrPt, parentPt, nodeTxt)
  315. # 绘制结点
  316. plotNode(firstStr, cntrPt, parentPt, decisionNode)
  317. # 下一个字典,也就是继续绘制结点
  318. secondDict = myTree[firstStr]
  319. # y偏移
  320. plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
  321. for key in secondDict.keys():
  322. # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
  323. if type(secondDict[key]).__name__ == 'dict':
  324. # 不是叶结点,递归调用继续绘制
  325. plotTree(secondDict[key], cntrPt, str(key))
  326. # 如果是叶结点,绘制叶结点,并标注有向边属性值
  327. else:
  328. plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
  329. plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
  330. plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
  331. plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
  332. """
  333. 函数说明:创建绘图面板
  334. Parameters:
  335. inTree - 决策树(字典)
  336. Returns:
  337. None
  338. """
  339. def createPlot(inTree):
  340. # 创建fig
  341. fig = plt.figure(1, facecolor="white")
  342. # 清空fig
  343. fig.clf()
  344. axprops = dict(xticks=[], yticks=[])
  345. # 去掉x、y轴
  346. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  347. # 获取决策树叶结点数目
  348. plotTree.totalW = float(getNumLeafs(inTree))
  349. # 获取决策树层数
  350. plotTree.totalD = float(getTreeDepth(inTree))
  351. # x偏移
  352. plotTree.xoff = -0.5 / plotTree.totalW
  353. plotTree.yoff = 1.0
  354. # 绘制决策树
  355. plotTree(inTree, (0.5, 1.0), '')
  356. # 显示绘制结果
  357. plt.show()
  358. """
  359. 函数说明:使用决策树分类
  360. Parameters:
  361. inputTree - 已经生成的决策树
  362. featLabels - 存储选择的最优特征标签
  363. testVec - 测试数据列表,顺序对应最优特征标签
  364. Returns:
  365. classLabel - 分类结果
  366. """
  367. def classify(inputTree, featLabels, testVec):
  368. # 获取决策树结点
  369. firstStr = next(iter(inputTree))
  370. # 下一个字典
  371. secondDict = inputTree[firstStr]
  372. featIndex = featLabels.index(firstStr)
  373. for key in secondDict.keys():
  374. if testVec[featIndex] == key:
  375. if type(secondDict[key]).__name__ == 'dict':
  376. classLabel = classify(secondDict[key], featLabels, testVec)
  377. else:
  378. classLabel = secondDict[key]
  379. return classLabel
  380. """
  381. 函数说明:存储决策树
  382. Parameters:
  383. inputTree - 已经生成的决策树
  384. filename - 决策树的存储文件名
  385. Returns:
  386. None
  387. """
  388. def storeTree(inputTree, filename):
  389. with open(filename, 'wb') as fw:
  390. pickle.dump(inputTree, fw)
  391. """
  392. 函数说明:读取决策树
  393. Parameters:
  394. filename - 决策树的存储文件名
  395. Returns:
  396. pickle.load(fr) - 决策树字典
  397. """
  398. def grabTree(filename):
  399. fr = open(filename, 'rb')
  400. return pickle.load(fr)
  401. """
  402. 函数说明:main函数
  403. Parameters:
  404. None
  405. Returns:
  406. None
  407. """
  408. def main():
  409. dataSet, features = createDataSet()
  410. featLabels = []
  411. myTree = createTree(dataSet, features, featLabels)
  412. # 测试数据
  413. testVec = [0, 1, 1, 1]
  414. result = classify(myTree, featLabels, testVec)
  415. if result == 'yes':
  416. print('放贷')
  417. if result == 'no':
  418. print('不放贷')
  419. print(myTree)
  420. print("最终地决策树为:\n")
  421. createPlot(myTree)
  422. print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
  423. if __name__ == '__main__':
  424. main()

最后构建出地决策树为:


三、基于sklearn实现

数据:数据的Labels依次是age、prescript、astigmatic、tearRate、class
年龄、症状、是否散光、眼泪数量、分类标签

young    myope    no    reduced    no lenses
young    myope    no    normal    soft
young    myope    yes    reduced    no lenses
young    myope    yes    normal    hard
young    hyper    no    reduced    no lenses
young    hyper    no    normal    soft
young    hyper    yes    reduced    no lenses
young    hyper    yes    normal    hard
pre    myope    no    reduced    no lenses
pre    myope    no    normal    soft
pre    myope    yes    reduced    no lenses
pre    myope    yes    normal    hard
pre    hyper    no    reduced    no lenses
pre    hyper    no    normal    soft
pre    hyper    yes    reduced    no lenses
pre    hyper    yes    normal    no lenses
presbyopic    myope    no    reduced    no lenses
presbyopic    myope    no    normal    no lenses
presbyopic    myope    yes    reduced    no lenses
presbyopic    myope    yes    normal    hard
presbyopic    hyper    no    reduced    no lenses
presbyopic    hyper    no    normal    soft
presbyopic    hyper    yes    reduced    no lenses
presbyopic    hyper    yes    normal    no lenses
 

代码:

  1. # -*- coding: utf-8 -*-
  2. """
  3. 数据的Labels依次是age、prescript、astigmatic、tearRate、class
  4. 年龄、症状、是否散光、眼泪数量、分类标签
  5. """
  6. from sklearn.preprocessing import LabelEncoder, OneHotEncoder
  7. import pandas as pd
  8. import numpy as np
  9. import pydotplus
  10. from sklearn.externals.six import StringIO
  11. from sklearn import tree
  12. from IPython.display import display, Image
  13. if __name__ == '__main__':
  14. # 加载文件
  15. with open('lenses.txt') as fr:
  16. # 处理文件,去掉每行两头的空白符,以\t分隔每个数据
  17. lenses = [inst.strip().split('\t') for inst in fr.readlines()]
  18. # 提取每组数据的类别,保存在列表里
  19. lenses_targt = []
  20. for each in lenses:
  21. # 存储Label到lenses_targt中
  22. lenses_targt.append([each[-1]])
  23. # 特征标签
  24. lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
  25. # 保存lenses数据的临时列表
  26. lenses_list = []
  27. # 保存lenses数据的字典,用于生成pandas
  28. lenses_dict = {}
  29. # 提取信息,生成字典
  30. for each_label in lensesLabels:
  31. for each in lenses:
  32. # index方法用于从列表中找出某个值第一个匹配项的索引位置
  33. lenses_list.append(each[lensesLabels.index(each_label)])
  34. lenses_dict[each_label] = lenses_list
  35. lenses_list = []
  36. # 打印字典信息
  37. # print(lenses_dict)
  38. # 生成pandas.DataFrame用于对象的创建
  39. lenses_pd = pd.DataFrame(lenses_dict)
  40. # 打印数据
  41. # print(lenses_pd)
  42. # 创建LabelEncoder对象
  43. le = LabelEncoder()
  44. # 为每一列序列化
  45. for col in lenses_pd.columns:
  46. # fit_transform()干了两件事:fit找到数据转换规则,并将数据标准化
  47. # transform()直接把转换规则拿来用,需要先进行fit
  48. # transform函数是一定可以替换为fit_transform函数的,fit_transform函数不能替换为transform函数
  49. lenses_pd[col] = le.fit_transform(lenses_pd[col])
  50. # 打印归一化的结果
  51. # print(lenses_pd)
  52. # 创建DecisionTreeClassifier()类
  53. clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
  54. # 使用数据构造决策树
  55. # fit(X,y):Build a decision tree classifier from the training set(X,y)
  56. # 所有的sklearn的API必须先fit
  57. clf = clf.fit(lenses_pd.values.tolist(), lenses_targt)
  58. dot_data = StringIO()
  59. # 绘制决策树
  60. tree.export_graphviz(clf, out_file=dot_data, feature_names=lenses_pd.keys(),
  61. class_names=clf.classes_, filled=True, rounded=True,
  62. special_characters=True)
  63. graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
  64. #预测
  65. print(clf.predict([[1,1,1,0]]))

最终地决策树:

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/635831
推荐阅读
相关标签
  

闽ICP备14008679号