当前位置:   article > 正文

机器学习--决策树(ID3,C4.5,CART)的原理_决策树原理图

决策树原理图

      决策树是机器学习过程中常见的一种算法。本文将浅显的讲解决策树在机器学习过程的一些理论。

一、什么是决策树

树:树结构是计算机领域中常见的一种数据结构,他由1个根节点,若干个中间节点和若干个叶子节点构成。如图所示,1是根节点(根节点明显特征是没有任何的输入),2和3是中间节点(中间节点明显特征是有输入且有输出),4和5及6是叶子节点(叶子节点明显特征是没有任何的输入)。

决策树:决策树顾名思义使用树结构来进行做决策。接下来我们用一件明天是否打高尔夫球的案例来讲解什么是决策树。我们知道打高尔夫球需要考虑天气,气温,湿度和是否有风等因素。我们可以根据这些因素和明天是否打高尔夫球构造出一个决策树,决策明天是否打高尔夫球。那具体构造出来长什么样子呢?

我们首先观察一张关于是否打高尔夫的数据表,如下图:

NO(编号)

outlook(天气)

temperature(气温)

humidity(湿度)

windy(是否有风)

play(是否打高尔夫)

1

sunny

hot

high

false

no

2

sunny

hot

high

true

no

3

overcast

hot

high

false

yes

4

rainy

mild

high

false

yes

5

rainy

cool

normal

false

yes

6

rainy

cool

normal

true

no

7

overcast

cool

normal

true

yes

8

sunny

mild

high

false

no

9

sunny

cool

normal

false

yes

10

rainy

mild

normal

false

yes

11

sunny

mild

normal

true

yes

12

overcast

mild

high

true

yes

13

overcast

hot

normal

false

yes

14

rainy

mild

high

true

no

我们观察到天气有sunny、overcast和rainy三个值,气温有hot、mild和cool三个值,湿度有high、normal两个值,是否刮风有true和false两个值,是否打高尔夫有yes和no两个值。我们现在根据表中的因素和取值情况来构造出一根决策树。决策树节点放上表中的因素(如天气,气温等),决策树的边我们使用取值(如sunny、overcast和rainy)情况表示。在不考虑先后次序的情况下,我们举例构造出决策树如下:

根据上图,我们可以看出,根据表中给出的数据,我们可以构造出一根决策树,用来辅助做决策。但构造过程我们有一个问题,就是谁是根节点,谁是第2层上的节点,谁是第3层的呢?当然叶子节点我们能确定是play。根据不同的确定手段,这时候就出现了ID3,C4.5和CART。下面我们以决策树分类案例来分别讲解他们。

二、ID3确定节点先后位置

一)相关概念

在ID3中有几个概念我们要先理解,分别是信息熵信息增益。我们最终以信息增益大小依次去确定节点的先后位置。

信息熵是一个值,用来计算决.策(比如是否打高尔夫球)的不确定性。他越大不确定性越大。比如根据上面的表,我们假如算出是否打高尔夫球的信息熵是0.9,而根据另外一张假设的表算出来信息熵是0.5,那么我们可以断定上面表中的不确定性更大(换个表述方式,我们可以说上表中是否打高夫球更不纯。yes是9条记录,no是5条记录)。信息熵的计算公式如下:

 其中D是当前样本集合中第k类样本所占比例为pk,比如当前在没有任何约束条件的情况下,表里有14条记录,是否打高尔夫球有2类样本,yes和no

yes有9条,no有5条,那么我们可以计算出当前的Ent(D)=-((9/14)* log2(9/14)+(5/14)* log2(5/14))=0.94

信息增益也是一个值,用来判断使用某因素(如天气,气温等)去做分类的好坏,值越大分类效果越好。那么我们就可以使用最大的那个因素作为首选节点了。信息增益和上面的信息熵有关。计算公式如下:

 Gain(D,a)表达的是再集合D下a因素的信息增益,比如在上表中14条记录下,天气因素的信息增益。v是指a因素的取值情况,就比如sunny、overcast和rainy三种取值。所以14条记录下,天气(outlook)因素的信息增益Gain(D,a)等于集合D(即上表14条记录)下是否打高尔夫的信息熵减去天气取sunny、overcast和rainy三种取值信息熵的差值。

二)确定优先节点位置

首先,确定根节点。参考上表的数据,我们首先计算play的信息熵,然后再分别计算出outlook,temperatrue,humidity和windy的信息增益。谁的信息增益大,我们把谁最为根节点。

Gain(D,outlook)=0.247 

Gain(D,temperatrue)=0.029

Gain(D,humidity)=0.152 

Gain(D,windy)=0.048 

最后我们将outlook作为根节点。outlook有3种取值,所以第2层有三个节点。那第2层 用哪些因素呢?temperatrue,humidity还是windy?每个第2层的节点使用的因素可以不一样。需要计算信息增益确定。

 

其次,确定中间节点。

我们以1节点为例来进行讲解计算过程,其他中间节点可以按此类似计算。

在1节点这个位置,数据只剩下表中编号是1,2,8,9,11的五条记录。这五条记录组成集合D11,我们计算出在这个集合下,剩余可用的因素windy,temperatrue,humidity的信息增益。

Gain(D11,windy)=0.020

Gain(D11,temperatrue)=0.571

Gain(D11,humidity)=0.971

所以1节点这个位置,我们放上humidity因素。

最后,叶子节点的确定。如果数据从根节点往下分流的过程,遇到中间的某个节点(或根节点)分流,能把数据完全分开(例如,在天气是overcast的情况下,表中对应的记录全是yes),那么这个节点后面就是叶子节点。

本案例的完整决策树如下图。

 我们来总结下上面根节点以后的递归过程。

递归过程如下:

i.如果节点中所有记录都属于同一个类,则该节点是叶节点。

ii.如果节点中包含属于多个类的记录,则选择一个属性测试条件,将记录划分成较小的子集。对于测试条件的每个输出,创建一个子节点,并根据测试结果将父节点中的记录分布到子节点中。然后,对于每个子节点,返回第1步进行判断。

三、C4.5确定节点先后位置

一)决策原理优化

在ID3的分析过程中,我们发现一些不足。

首先信息增益这个指标,对取值数目较多的属性有所偏好,但数目太多并不利于做分类。例如上面表格中的编号,我们去计算他的信息增量是0.94,比其他因素大多了。但试想有14个分支,每个分支1条记录。这个决策就毫无意义了。

其次树的分支个数由取值种数确定,上表中只有2到3种取值,实际应用过程中,很多时候取值是非常多的值的,甚至是连续数值。C4.5针对这些不足进行了改善。

针对信息增益的缺陷。C4.5不再使用信息增益来选择因素放到节点中。而是使用信息增益率。它的计算公式如下:

 针对ID3不能处理连续数值。C4.5对数值进行分裂,就比如1到10分成为一类,11到20分为一类,等等。C4.5分裂的时候仍然使用信息增益率来分裂。具体的分裂步骤如下:

1)对于连续属性a,将 a的取值按从小到大排序。

2)a中相邻两个值的均值被看作可能的分裂点,把原数据划分为两部分;对于给定a的v个值,有 v-1个可能的分裂点。

3)对于每个可能的分裂点,计算对应的信息增益率;选择其中信息增益率最大的分裂点作为连续属性a的分裂点。

二)python具体使用

在python中有一个机器学习库scikit-learn,在这个库里面,有一个模块叫tree。在这个模块一个类DecisionTreeClassifier,这个类是用于构建决策树分类模型的。这个类是决策树的通用模版,里面有很多参数需要指定,当参数criterion的值指定为entropy时,表明构造C4.5决策树。该类的参数说明列举如下:

参数名称

说明

criterion

接收str。表示节点(特征)选择的准则,使用信息增益“entropy”的是C4.5算法;使用基尼系数“gini”的CART算法。默认为“gini”

splitter

接收str,可选参数为“best”或“random”。表示特征划分点选择标准,“best”在特征的所有划分点中找出最优的划分点;“random”在随机的部分划分点中找出局部最优划分点。默认为“best”

max_depth

接收int。表示决策树的最大深度。默认为None

min_samples_split

接收int或float。表示子数据集再切分需要的最小样本量。默认为2

min_samples_leaf

接收int或float。表示叶节点所需的最小样本数,若低于设定值,则该叶节点和其兄弟节点都会被剪枝。默认为1

min_weight_fraction_leaf

接收int、float、str或None。表示在叶节点处的所有输入样本权重总和的最小加权分数。默认为None

max_features

接收float。表示特征切分时考虑的最大特征数量,默认是对所有特征进行切分。传入int类型的值,表示具体的特征个数;浮点数表示特征个数的百分比;sqrt表示总特征数的平方根;log2表示总特征数求log2后的个数的特征。默认为None

random_state

接收int、RandomState实例或None。表示随机种子的数量,若设置了随机种子,则最后的准确率都是一样的;若接收int,则指定随机数生成器的种子;若接收RandomState,则指定随机数生成器;若为None,则指定使用默认的随机数生成器。默认为None

max_leaf_nodes

接收int或None。表示最大叶节点数。默认为None,即无限制

参数名称

说明

min_impurity_decrease

接收float。表示切分点不纯度最小减少的程度,若某节点的不纯度减少小于或等于这个值,则切分点就会被移除。默认为0.0

min_impurity_split

接收float。表示切分点最小不纯度,它用来限制数据集的继续切分(决策树的生成)。若某个节点的不纯度(分类错误率)小于这个阈值,则该点的数据将不再进行切分。无默认,但该参数将被移除,可使用min_impurity_decrease参数代替

class_weight

接收dict、dict列表、balanced或None。表示分类模型中各种类别的权重,在出现样本不平衡时,可以考虑调整class_weight系数去调整,防止算法对训练样本多的类别偏倚。默认为None

presort

接收bool。表示是否提前对特征进行排序。默认为False

四、CART确定节点先后位置

cart算法和c4.5类似也是在确定节点先后的指标(ID3中是信息增量)和连续数值上进行类改进。CART分类树采用“基尼指数”(Gini index)作为纯度度量,回归树选取Gini_ δ为评价分裂属性的指标。

一)分类树

在CART中,使用Gini(D)和Gini_index代替原来ID3种的信息熵Ent(D)和信息增量Gain(D,a)。指标具体求解公式如下:

 Gini(D) 反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此Gini(D) 越小,则数据集D的纯度越高。

CART分类树构建的是一棵二叉树,基尼指数考察每个属性的二元划分。如果取值是连续数值,按照C4.5的方式进行分裂,再按照二元划分的结果把数据集划分为两部分。再计算每个二元划分下属性a的基尼指数Gini_index(D,a)选择其中基尼指数最小的二元划分作为属性a的分裂点。

二)CART回归树

其实求解过程和分类树差不多,只不过指标进行了对应的更换。Gini(D)、Gini_index被δ(D)、Gini_δ(D,a)代替。他两者的具体计算公式如下:

其中μ表示样本集D中预测结果的均值, 表yk示第k个样本预测结果。 

对于含有n个样本的样本集D,根据属性a的二元划分,将数据集D划分成两部分,则划分成两部分之后,计算Gini_δ(D,a),选取其中的最小值作为属性a 的最优二元划分方案。

三)python实现CART

实现过程中用的类DecisionTreeClassifier我们在C4.5那一节已经说过了。下面我们看一段简单的f分类实现代码。

  1. from sklearn.tree import DecisionTreeClassifier
  2. from sklearn.datasets import load_iris
  3. from sklearn.model_selection import train_test_split
  4. iris = load_iris() # 加载数据
  5. data = iris.data # 属性列
  6. target = iris.target # 标签列
  7. # 划分训练集、测试集
  8. traindata, testdata, traintarget, testtarget =train_test_split(data, target, test_size=0.2, random_state=120)
  9. model_dtc = DecisionTreeClassifier() # 确定决策树参数
  10. model_dtc.fit(traindata, traintarget) # 拟合数据
  11. print("建立的决策树模型为:\n", model_dtc)
  12. # 预测测试集结果
  13. testtarget_pre = model_dtc.predict(testdata)
  14. print('前10条记录的预测值为:\n', testtarget_pre[:10])
  15. print('前10条记录的实际值为:\n', testtarget[:10])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/526680
推荐阅读
相关标签
  

闽ICP备14008679号