赞
踩
决策树,顾名思义,就是用来产生决策的树,我们通过数据的属性特征来构造这棵树。相比于K近邻算法,决策树的主要优势在于数据形式非常容易理解。决策树的一个很重要的任务就是理解数据中所蕴含的知识信息,从数据中提取出一系列规则的过程,也就是决策树构建的过程、机器学习的过程。
决策树算法是随机森林等集成学习算法的基础,最基本的决策树没有反馈,没有修正,就是简单的输入训练集,得出一个可以应对其他不同数据集的分类方法。目前常用的决策树算法有ID3算法、改进的C4.5算法和CART算法。在本篇blog中,我着重介绍决策树的思想和实现过程,以及实现过程中的一些思考。代码学习自《机器学习实战》,代码很规范很高效,我感觉值得深入挖掘其思想。
想必大家都知道周志华老师的《机器学习》一书,俗称西瓜书,书中在讲决策树的时候,用的也是如何买瓜的数据,如下图所示。树中的内部节点表示某个属性,节点引出的分支表示此属性的所有可能的值,叶子节点表示最终的判断结果也就是分类后的类型。
那么在考虑一棵决策树的时候,我认为有以下两点需要注意一下:
因为决策树的生成是一个递归的过程,那么终止递归的条件需要思考:
决策树学习的关键是如何选择最优化分属性。一般而言,随着划分过程的不断进行,我们希望决策树的分支节点所包含的样本尽量属于同一类别,即结点的“纯度”越来越高。度量节点“纯度”的方法有多种,我们先介绍信息增益。
信息增益是信息熵的有效减少量,那信息熵又是什么呢?信息熵定义为信息的期望值,那信息又是怎么定义的呢?让我们一步一步来解释。
如果待分类的事务可能划分在多个分类之中,则符号
x
i
x_i
xi的信息定义为:
其中
p
(
x
i
)
p(x_i)
p(xi) 是选择该分类的概率。
信息熵(information entropy)是度量样本集合纯度最常用的一种指标,若$ x_i $ 构成样本集合D,那么,D的信息熵定义如下,Ent值越小,纯度越高。计算信息熵时约定:若$ p=0
,
则
,则
,则 plog_2p = 0$。
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
from math import log
import operator
def compute_Shang(dataSet):
num = len(dataSet) # 由于代码中多次用到该值,为提高效率,显式地声明一个变量保存实例总数
labelCounts = {} # 定义一个字典
for featVec in dataSet:
currentLabel = featVec[-1] # 该条数据的标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 在label统计数组中加一条总数值为0的记录
labelCounts[currentLabel] += 1 # 统计各个标签的总次数
Shang = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/num
Shang -= prob * log(prob, 2)
return Shang
跟每个介绍决策树的章节一样,我们需要选定某一个特征,然后对其他剩下的信息,根据label进行划分,从而计算熵,在下面这个函数中,我们只做了把该特征拿出来,返回其他的信息。
def splitDataSet(dataSet, axis, value):
"""
para:带划分的数据集、划分数据集的特征、需要返回的特征值
Note that:python在函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的
整个生命周期。为了消除这个不良影响,我们需要在函数的开始声明一个新列表对象。因为该函数代码在
同一个数据集上被调用多次,为了不修改原始数据集,创建一个新的列表对象。
"""
retDataSet = [] #
for featVec in dataSet:
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis+1:]) # 将要寻找的索引栏空出来,输出其他特征及标签
retDataSet.append(reduceFeatVec) # [1,2,3],[4,5,6] extend:[1,2,3,4,5,6] append:[1,2,3,[4,5,6]]
return retDataSet
在这个函数中,我们算出在一个层次中的最好的数据集划分方式,也就是找出最合适的特征。set(featList)
选择出所有的不同特征,循环遍历,计算在这个特征充当划分结点时,整体的信息熵,最后比较出最合适的特征并返回。
def chooseBestFeatureToSplit(dataSet): num_Features = len(dataSet[0]) - 1 # 定义特征的数量 base_Entropy = compute_Shang(dataSet) bestInfoGain = 0.0 bestFeature = -1 for i in range(num_Features): # 迭代所有的特征 featList = [example[i] for example in dataSet] # 这个特征下所有的样例 [1,1,1,1,0,0] [1,1,1,0,1,1] uniqueVals = set(featList) # set去掉重复 newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) # i + value : 00 01 10 11 # print(subDataSet) prob = len(subDataSet)/float(len(dataSet)) # 对应公式 newEntropy += prob * compute_Shang(subDataSet) # 对应公式 infoGain = base_Entropy - newEntropy if (infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i return bestFeature
在训练过程中我们有时会遇到这样的情况:如果数据集处理完了所有的属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶子节点。这是,我们采用“多数表决”的方法决定该叶子节点的分类。
def majorityCnt(classList):
classCount = {}
for vote in classList:
# classList存储了每个类标签出现的频率
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# python3中,要将iteritems变为items
return sortedClassCount[0][0]
def createTree(dataSet, labels): """ Note that:del是python内置的关键字(比如import, return等都是python的关键字),并不是python的内置函数 (内置函数有range(), sorted()等等),del的作用是删除一个对象,不仅可以删除list中的某一个元素, 也可以删除一个list,一个变量,或者类的实例 """ classList = [example[-1] for example in dataSet] # 递归停止的第一个条件:所有类标签完全相同 if classList.count(classList[0]) == len(classList): return classList[0] # 递归停止的第二个条件是:使用完了所有特征,仍然不能把所有数据集划分成仅包含唯一类别的分组 if len(dataSet[0]) == 1: return majorityCnt(classList) # 返回出现次数最多的类别 bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel: {}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) return myTree
mydata, labels = trees.createDataSet()
print(trees.createTree(mydata, labels))
结果:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
matplotlib并没有提供决策树绘制接口,所以需要自己来实现,因为决策树的主要优势是直观,而且易于理解,如果不能把它直观的显示出来,就无法发挥其优势。而且我们需要写一个通用的代码,能一直为不同决策树提供接口的代码。
首先我们先来了解一下matplotlib的注解工具annotation
,它可以在数据图形上添加文本注解。
import matplotlib.pyplot as plt
# 定义文本框和箭头格式,dict用来创建空字典
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.axl = plt.subplot(111, frameon=False)
plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
运行示例:createPlot()
结果如下:
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
我们需要知道有多少叶节点,以便确定x轴的长度。需要知道有多少层,以便确定y轴长度。
用递归思想来得到叶节点个数和树的深度,递归停止的条件是输入到函数中的参数不再是dict类型,意味着不再拥有子节点,即为叶节点。所以一旦到达叶节点,则从递归调用中返回执行else语句,num加一。同理树的深度也是这样。
def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] # py2: myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 return numLeafs def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': # test to see if the nodes are dictonaires, if not they are leaf nodes thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth
# 在父节点和子节点之间添加文本信息 def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) def plotTree(myTree, parentPt, nodeTxt): # if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) # this determines the x width of this tree depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] # the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict':# test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key], cntrPt, str(key)) # recursion else: # it's a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD # if you do get a dictonary you know it's a tree, and the first element will be another dict
其实上述分析与代码构建就是决策树源码的一部分构建过程,下面我们来构建测试集的测试代码。其实测试的过程很简单,就是将待测试的数据,沿着建好的树遍历,直到到达叶节点,返回求证。具体的代码解释见注释。
# 测试 使用特征标签特征列表
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0] # 取出当前树(json)中的第一个key
secondDict = inputTree[firstStr] # 第一个key对应的值(key-value)
featIndex = featLabels.index(firstStr) # 第一个key是第几个特征?返回index
key = testVec[featIndex] # 取出待分类行数据中的该特征的具体值
valueOfFeat = secondDict[key] # 这一层的value,也就是准备下一层的key
if isinstance(valueOfFeat, dict): # 比较testVec中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
myDat, labels = trees.createDataSet()
myTree = treePlotter.retrieveTree(0)
print(trees.classify(myTree, labels, [1,0]))
本数据集包含90个数据(训练集),分为2类,每类45个数据,每个数据4个属性:
分类种类: Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),部分数据如下:
import trees import treePlotter a = [] train_data = [] count = 0 # 切分训练集 with open(r"C:\\Users\\Administrator\\Desktop\\第一次作业 (1)\\第一次作业\\Iris.txt", "r") as f: for line in f.readlines(): count = count + 1 if 41 < count < 52: pass else: line = line.strip('\n') train_data.append(line) train_data[-1] = train_data[-1].split(",") f.close() # 总的数据集 with open(r"C:\\Users\\Administrator\\Desktop\\第一次作业 (1)\\第一次作业\\Iris.txt", "r") as f: for line in f.readlines(): line = line.strip('\n') a.append(line) a[-1] = a[-1].split(",") print("train_data:") print(train_data) # 切分测试集 test_data = a[41:51] # 定义标签 labels = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width'] myTree = trees.createTree(train_data, labels) print(labels) # 输出树的结构 print(myTree) # 树结构反馈在图表中 treePlotter.createPlot(myTree) # 测试如下 # classify(inputTree, featLabels, testVec) test_data_feat = [] test_data_label = [] for i in range(len(test_data)): test_data_feat.append(test_data[i][0:4]) test_data_label.append(test_data[i][-1]) test_labels = [] # 需要重新定义一下labels,因为前面的labels在经过createTree函数时,del了递归中的所有最优label,剩下的就是没有用到的标签,所以需要重新定义一下。 labels = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width'] for i in range(len(test_data)): test_labels.append(trees.classify(myTree, labels, test_data_feat[i])) count2 = 0 for i in range(len(test_data)): if test_labels[i] != test_data_label[i]: count2 = count2 + 1 print("Accuracy:") print((len(test_data) - count2) / len(test_data))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。