当前位置:   article > 正文

【机器学习15】决策树模型详解

决策树模型


手动反爬虫: 原博地址

 知识梳理不易,请尊重劳动成果,文章仅发布在CSDN网站上,在其他网站看到该博文均属于未经作者授权的恶意爬取信息
  • 1

如若转载,请标明出处,谢谢!

前言

随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学习机器学习,本文就介绍了机器学习的决策树的详细内容。

一、决策树的概述

决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶子节点代表一种类别。

树的组成:(如下图示,来源百度,只做结构演示说明)
根节点:第一个选择点
非叶子结点与分支:中间过程
叶子节点:最终的决策结果
在这里插入图片描述

两大特征:从根节点开始一步步走到叶子节点(决策的过程)
所以的数据最终都会落到叶子节点,既可以做分类也可以做回归

实现决策树的流程,下面以一个简单的一家人是否爱玩游戏进行划分,如下
在这里插入图片描述
将一家人看做为是一份数据,输入到决策树中,首先会进行年龄的判断是否大于15岁(人为主观认定的数值),如果大于15,就判断为有较小可能性玩游戏,小于或等于15岁则认为有较大的可能性玩游戏,然后再进行下一步细分,判断性别,如果是男生则认为有较大的可能性玩游戏,如果为女生则为较小的可能性喜欢玩游戏。

这就是一个简单的决策树的过程,有点像最初python学习时的进行if-else成绩好坏等级分类的程序。但是,决策树这里的分类的先后顺序通常是不可以调换的,比如这里为什么要把年龄的判断放在性别判断前边,就是因为希望第一次决策判断就能实现大部分数据的筛选,尽可能的都做对了,然后再进行下一步,进而实现对上一步存在偏差数据的微调,因此根节点的重要性可想而知,其要实现对数据样本大致的判断,筛选出较为精确的数据。可以对比篮球比赛,当然是先上首发阵容,其次在考虑替补,针对于短板的地方进行补充。

那么问题就来了?(根节点如何选择)—— 凭什么先按照年龄进行划分,或者说凭什么认为他们是首发吗?

判断的依据是啥???接下来呢?次根节点又如何进行切分呢???

二、熵的作用

目标:通过一种衡量标准,来计算通过不同特征进行分支选择后的分类情况,找出来最好的那个当成根节点,以此类推。

衡量标准-熵:熵是表示随机变量不确定性的度量(解释:说人话就是物体内部的混乱程度,比如义务杂货市场里面什么都有那肯定混乱呀,专卖店里面只卖一个牌子的那就稳定多啦)

举个例子:如下,A中决策分类完成后一侧是有三个三角二个圆,另一侧是两个三角一个圆,而B中决策分类后是一侧是三角一侧是圆,显然是B方案的决策判断更靠谱一些,用熵进行解释就是熵值越小(混乱程度越低),决策的效果越好

在这里插入图片描述
有些时候是可以凭借的肉眼进行观察的,但是大部分的决策结果并不能仅仅通过人为的评判,而需要一个量化的评判标准,于是就有了判断公式: H ( X ) = − ∑ p i ∗ l o g ( p i ) , i = 1 , 2 , . . . , n H(X)=- ∑ p_{i}* log(p_{i}), i=1,2, ... , n H(X)=pilog(pi),i=1,2,...,n

这里以上面的例子进行公式解读,都只单看左侧的分类结果,对于B中的,只有三角,也就是一个分类结果, p i p_{i} pi即为取值概率,这里就为100%,再结合一下log函数,其值在[0,1]之间是递增的,那么前面加上一个负号就是递减,因此这个B右侧分类结果带入计算公式值就是0,而0又是这个公式值中的最小值。再看A中左侧的分类结果,由于存在着两种情况,因此公式中就出现了累加,分别计算两种结果的熵值情况,最后汇总,其值必然是大于0的,故A中的类别较多,熵值也就大了不少,B中的类别较为稳定(那么在分类任务重我们希望通过节点分支后数据类别的熵值大还是小呢?)

其实都不是,在分类之前数据有一个熵值,在采用决策树之后也会有熵值, 还拿A举例,最初的状态五个三角三个圆(对应一个熵值1),经过决策之后形成左侧三个三角两个圆(对应熵值2)和右侧的两个三角一个圆(对应熵值3),如果最后的熵值2+熵值3 < 熵值1,那么就可以判定这次分类较好,比原来有进步,也就是通过对比熵值(不确定性)减少的程度判断此次决策判断的好坏,不是只看分类后熵值的大小,而是要看决策前后熵值变化的情况。

为了方便记忆,于是有了 信息增益 :表示特征X使得类Y的不确定性减少的程度。(分类后的专一性,希望分类后的结果是同类在一起,比如上面希望把三角形分在一块,圆形分在一块)

三、决策树构造实例

这里使用官网提供的示例数据进行讲解,数据为14天打球的情况(实际的情况);特征为4种环境变化( x i x_{i} xi);最后的目标是希望构建决策树实现最后是否打球的预测(yes|no),数据如下

x 1 : x_{1}: x1: outlook
x 2 : x_{2}: x2: temperature
x 3 : x_{3}: x3: humidity
x 4 : x_{4}: x4: windy

p l a y : play: play: yes|no

在这里插入图片描述
由数据可知共有4种特征,因此在进行决策树构建的时候根节点的选择就有4种情况,如下。那么就回到最初的问题上面了,到底哪个作为根节点呢?是否4种划分方式均可以呢?因此 信息增益 就要正式的出场露面了
在这里插入图片描述
由于是要判断决策前后的熵的变化,首先确定一下在历史数据中(14天)有9天打球,5天不打球,所以此时的熵应为(一般log函数的底取2,要求计算的时候统一底数即可): − 9 14 ∗ l o g 2 9 14 − 5 14 ∗ l o g 2 5 14 = 0.940 - \frac{9}{14}*log_{2}\frac{9}{14} - \frac{5}{14}*log_{2} \frac{5}{14} = 0.940 149log2149145log2145=0.940

先从第一个特征下手,计算决策树分类后的熵值的变化,还是使用公式进行计算

Outlook = sunny时,熵值为0.971 ( − 2 5 ∗ l o g 2 2 5 − 3 5 ∗ l o g 2 3 5 = 0.971 - \frac{2}{5}*log_{2}\frac{2}{5} - \frac{3}{5}*log_{2} \frac{3}{5} = 0.971 52log25253log253=0.971
Outlook = overcast时,熵值为0
Outlook = rainy时,熵值为0.971( − 3 5 ∗ l o g 2 3 5 − 2 5 ∗ l o g 2 2 5 = 0.971 - \frac{3}{5}*log_{2}\frac{3}{5} - \frac{2}{5}*log_{2} \frac{2}{5} = 0.971 53log25352log252=0.971

注意:直接将计算得到的结果和上面计算出初始的结果相比较吗? (当然不是,outlook取到sunny,overcast,rainy是有不同的概率的,因此最后的计算结果要考虑这个情况)

最终的熵值计算就为: 5 14 ∗ 0.971 + 0 + 5 14 ∗ 0.971 = 0.693 \frac{5}{14}*0.971 + 0 + \frac{5}{14}*0.971 = 0.693 1450.971+0+1450.971=0.693信息增益:系统的熵值就由原始的0.940下降到了0.693,增益为
g a i n ( o u t l o o k ) = 0.247 gain(outlook) = 0.247 gain(outlook)=0.247依次类推,可以分别求出剩下三种特征分类的信息增益如下:
g a i n ( t e m p e r a t u r e ) = 0.029 , g a i n ( h u m i d i t y ) = 0.152 , g a i n ( w i n d y ) = 0.048 gain(temperature) = 0.029, gain(humidity) = 0.152, gain(windy) = 0.048 gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048最后我们选择最大的那个就可以啦,相当于是遍历了一遍特征,找出来了根节点(老大),然后在其余的中继续通过信息增益找子节点(老二)…,最终整个决策树就构建完成了!

四、信息增益率和gini系数

之前使用信息增益进行判断根节点有没有什么问题,或者是这种方法是不是存在bug,有些问题是解决不了的???答案是当然有的,比如还是使用上面的14个人打球的数据,这里添加一个特征为打球的次数ID,分别为1,2,3,…,12,13,14,那么如果按照此特征进行决策判断,如下
在这里插入图片描述
由此特征进行决策判断后的结果可以发现均为单个的分支,计算熵值的结果也就为0,这样分类的结果信息增益是最大的,说明这个特征是非常有用的,如果还是按照信息增益来进行评判,树模型就势必会按照ID进行根节点的选择,而实际上按照这个方式进行决策判断并不可行,只看每次打球的ID并不能说这一天是不是会打球。

从上面的示例中可以发现信息增益无法解决这种特征分类(类似ID)后结果特别特别多的情况,故就发展了另外的决策树算法叫做 信息增益率gini系数

这里介绍一下构建决策树中使用的算法(至于前面的英文称呼,知道是一种指代关系就可以了,比如说的信息增益也可以使用ID3进行表示):

英文称呼中文称呼
ID3信息增益 (本身是存在着问题的)
C4.5信息增益率 (解决了ID3问题,考虑了自身熵)
CART使用gini系数作为衡量标准,计算公式: G i n i ( p ) = ∑ k = 1 K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(p) = \sum_{k=1}^{K}p_{k}(1-p_{k})=1- \sum_{k=1}^{K}p_{k}^{2} Gini(p)=k=1Kpk(1pk)=1k=1Kpk2

还是以14个人打球的数据,讲解一下信息增益率,这里说考虑了自身熵,解决了ID3的问题是如何解决的呢?假设按照每次打球次数ID进行决策判断,结果还是分为了14类,计算后的信息增益为Q,从数值的大小上看一般是一个较小数值(0.940-0=0.940,绝对数值),但是对于其他数据特征分类的结果来看这个数值又是很大(0.940相较于其他的gain数值,对比数值),这时候的信息增益率就为 Q ( − 1 14 l o g 2 1 14 ) ∗ 14 \frac{Q}{(-\frac{1}{14}log_{2}\frac{1}{14})*14} (141log2141)14Q,从对比数值来看,这个Q是很大,但是考虑到自身的熵值,参考一下log函数的图像,分母的值就是更大了,由此这个公式计算的数值(信息增益率)就会很小,也就解决了信息增益中无法处理分类后数据类别特别多的情况

gini系数计算公式和熵的衡量标准类似,只是计算方式不相同,这里值越小代表这决策分类的效果越好,比如当p的累计值取1了,那么最后结果就为0,当p取值较小时,经过平方后就更小了,由此计算的结果也就趋近1了

信息增益率 是对根据熵值进行判定方式的改进,而 gini系数 则是另起炉灶,有着自己的计算方式

五、剪枝方法

首先明确一下为啥会有剪枝的操作,对比一下日常生活中的种植园工修理花草树木,如果不管,任其生长,最后结果很可能是杂草灌木丛生,决策树过拟合风险很大,理论上是可以完全分得开全部的数据,也就是树会野蛮生长,每个叶子节点都会有一个数据,然后就把所有的数据全部分类完成

过拟合通俗的讲就是:你在日常的测试做题或者考核的时候很好,但是一到大型的考试就不行了,这时候就存在边看答案边做题的现象,因此就可以把每道题都做对,就像决策树一样,可以无限的细分下去,满足所有的分类结果,但是在最后预测的时候却表现不好,故需要避免这种情况,也就需要进行剪枝,控制模型的复杂程度
在这里插入图片描述

剪枝策略:预剪枝(边建立决策树边进行剪枝的操作,比较实用)、后剪枝(当建立完成后再进行剪枝操作)

预剪枝方式:限制深度(比如指定到某一具体数值后不再进行分裂)、叶子节点个数、叶子节点样本数、信息增益量等
后剪枝方式:通过一定的衡量标准, C α ( T ) = C ( T ) + α ∗ ∣ T l e a f ∣ C_{\alpha}(T) = C(T) + \alpha*|T_{leaf}| Cα(T)=C(T)+αTleaf,叶子节点越多,损失越大

后剪枝的工作流程,比如选择如下的节点,进行判断其不分裂行不行? 不分裂的条件就是分裂之后的结果比分裂之前效果还要差劲。

在这里插入图片描述

按照上面的计算公式,分裂之前 C α ( T ) = 0.444 ∗ 6 + 1 ∗ α C_{\alpha}(T) =0.444*6 + 1* \alpha Cα(T)=0.4446+1α分裂之后就是两个叶子节点之和相加 C α ( T ) = 3 ∗ 0 + 1 ∗ α + 0.444 ∗ 3 + 1 ∗ α = 0.444 ∗ 3 + 2 ∗ α C_{\alpha}(T) =3*0+1* \alpha + 0.444*3+1* \alpha=0.444*3 + 2* \alpha Cα(T)=30+1α+0.4443+1α=0.4443+2α最后就变成了比较这两次取得的数值,值越大代表着损失越大,也就越不好,取值的大小是取决于我们给定的 α \alpha α值, α \alpha α值给出的越大,模型越会控制过拟合,值较小的时候是我们希望模型取得较好的结果,过不过拟合看的不是很重要

六、分类、回归任务

树模型做分类任务,某个叶子节点中的类型是由什么所决定的?还使用最初的图示为例,树模型是属于有监督的算法,数据在输入之前就已经有标签的,比如下面“-”代表不玩游戏。“+”代表玩游戏,那么右侧红框的分类结果中有三个“-”的数据,得到数据的众数都是分类为“-”,所以之后如果有数据再分到此类别中,就都会被标记为“-”,故分类任务是有叶子节点中数据的众数决定的,少数服从多数,加入某叶子节点中有10个“-”,2个“+”,则认定该叶子节点分类为“-”

回归任务和分类任务的做法几乎是一样的,但是评估的方式是不一样的,回归是采用方差进行衡量的,比如将上面的五个人按照年龄进行判断,是否为老年人,根究A方案,显然其方差要原小于B方案中的方差,也就是人为A方案中划分方式更为合理。那么既然是回归问题也就避免不了数据取值,叶子节点中的数值计算的方式就为各个数据的平均数,那么接下来使用树模型进行预测的时候,如果数据落入该叶子节点中数值即为此前计算的平均值,假如是A方案中的右侧叶子节点,预测的结果数值就为(30+15+12) /3 =19
在这里插入图片描述

七、树模型的可视化展示

前期工作:下载可视乎的包:graphviz,并配置环境变量,检验前期工作完成的标识,在cmd命令行中输入dot -version可以正常弹出如下信息

在这里插入图片描述

进入代码书写,首先导入相关的库,书写环境为jupyter notebook

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import os
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

第二步就是进行树模型的创建

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris() #加载鸢尾花数据集
X = iris.data[:,2:] #这里先选择两个特征petal length and width(花瓣的长度和宽度)
y = iris.target #设置标签

tree_clf = DecisionTreeClassifier(max_depth=2) #初始化树模型并设置最大的深度为2
tree_clf.fit(X,y) #训练树模型
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

最后一步就是可视化展示

from sklearn.tree import export_graphviz #这个就是刚刚下载的软件

export_graphviz(
    tree_clf, #第一个就是刚刚训练好的树模型
    out_file="iris_tree.dot", #这里指定输出的文件路径和文件名称
    feature_names=iris.feature_names[2:], #画图中需要用到的特征名称,上面选择两个,这里也是跟着一样
    class_names=iris.target_names, #标签设置
    rounded=True, #最后两个默认即可
    filled=True
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

执行上面代码后会在指定的文件路径下生成相应的文件,然后,你可以使用graphviz包中的dot命令行工具将此.dot文件转换为各种格式,如PDF或PNG。下面这条命令行将.dot文件转换为.png图像文件:dot -Tpng iris_tree.dot -o iris_tree.png

执行之后,就会在同路径下生成指定格式的图片
在这里插入图片描述
图片打开后如下,至此树模型的可视化的展示也就完成了,其中白色分块中都有五行代码,第一行是指分类的条件,第二行是gini系数值,第三行是样本数据量,第四行是原本三种类别的数量,最后一行是此次认定的数据分类结果
在这里插入图片描述
如果还想要将生成的图片在jupyter notebook中进行内嵌输出,可以使用如下代码

from IPython.display import Image
Image(filename='iris_tree.png',width=400,height=400)
  • 1
  • 2

八、决策边界展示分析

概率计算:这里举个例子,加入输入的数据为花瓣长5厘米,宽1.5厘米的花。 相应的叶节点是深度为2的左节点,因此决策树应输出以下概率(对照着上面的可视化展示的树模型图)

Iris-Setosa 为 0%(0/54),
Iris-Versicolor 为 90.7%(49/54),
Iris-Virginica 为 9.3%(5/54)。

使用代码验证一下

tree_clf.predict_proba([[5,1.5]])
  • 1

输出的结果为:

array([[0.        , 0.90740741, 0.09259259]])
  • 1

如果是直接预测类别

tree_clf.predict([[5,1.5]])
  • 1

输出的结果为:

array([1]) #这里的1就代表是第二个类别,也就是Iris-Versicolor 
  • 1

将所有的数据结果进行可视化展现

from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf, X, y, axes=[0, 7.5, 0, 3], iris=True, legend=False, plot_training=True):
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    if plot_training:
        plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris-Setosa")
        plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris-Versicolor")
        plt.plot(X[:, 0][y==2], X[:, 1][y==2], "g^", label="Iris-Virginica")
        plt.axis(axes)
    if iris:
        plt.xlabel("Petal length", fontsize=14)
        plt.ylabel("Petal width", fontsize=14)
    else:
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
    if legend:
        plt.legend(loc="lower right", fontsize=14)

plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf, X, y)

#下面这一部分是根据实际的决策树的分类结果进行分类线的绘制
plt.plot([2.45, 2.45], [0, 3], "k-", linewidth=2)
plt.plot([2.45, 7.5], [1.75, 1.75], "k--", linewidth=2)
plt.plot([4.95, 4.95], [0, 1.75], "k:", linewidth=2)
plt.plot([4.85, 4.85], [1.75, 3], "k:", linewidth=2)
plt.text(1.40, 1.0, "Depth=0", fontsize=15)
plt.text(3.2, 1.80, "Depth=1", fontsize=13)
plt.text(4.05, 0.5, "(Depth=2)", fontsize=11)
plt.title('Decision Tree decision boundaries')

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

输出的结果为:(关于网格和刻度方向的显示问题可以通过修改matplotlib库的配置文件修改)
在这里插入图片描述

九、决策树预剪枝常用参数

DecisionTreeClassifier类(sklearn的版本为0.23.1)还有一些其他参数类似地限制了决策树的形状:
min_samples_split(节点在分割之前必须具有的最小样本数,默认为2),
min_samples_leaf(叶子节点必须具有的最小样本数,默认为1),
max_leaf_nodes(叶子节点的最大数量),
max_features(在每个节点处评估用于拆分的最大特征数,一般不限制)。
max_depth(树最大的深度)

比如进行两个树模型不同参数的对比

from sklearn.datasets import make_moons
X,y = make_moons(n_samples=100,noise=0.25,random_state=53)

tree_clf1 = DecisionTreeClassifier(random_state=42)
tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4,random_state=42) #就只有一个参数不同,进行对比
tree_clf1.fit(X,y)
tree_clf2.fit(X,y)

plt.figure(figsize=(12,4))
plt.subplot(121)
plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('No restrictions')

plt.subplot(122)
plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('min_samples_leaf=4')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

输出的结果为:(可以发现不做任何限制的时候,模型过拟合了,而添加叶子节点最小样本数参数后,模型变得更加可靠,也可以测试一下其他的参数)
在这里插入图片描述
树模型对数据的敏感,比如将数据的旋转45度,那么模型的决策边界是会发生变化的

np.random.seed(6)
Xs = np.random.rand(100, 2) - 0.5
ys = (Xs[:, 0] > 0).astype(np.float32) * 2

angle = np.pi / 4
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)

tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs, ys)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_sr.fit(Xsr, ys)

plt.figure(figsize=(11, 4))
#旋转之前
plt.subplot(121)
plot_decision_boundary(tree_clf_s, Xs, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)
plt.title('Sensitivity to training set rotation')  

#旋转之后
plt.subplot(122)
plot_decision_boundary(tree_clf_sr, Xsr, ys, axes=[-0.7, 0.7, -0.7, 0.7], iris=False)
plt.title('Sensitivity to training set rotation')

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

输出的结果为:(切记树模型对数据是十分敏感的)
在这里插入图片描述

十、回归树模型

回归任务,就像前文中介绍的一样,回归和分类做法几乎是一样的,但是评估的方式是不一样的,分类通过可视化的树模型中可以看出是采用gini系数,而回归任务就是使用方差的平均数(mse:mean squared error)

#设置数据
np.random.seed(42)
m=200
X=np.random.rand(m,1)
y = 4*(X-0.5)**2
y = y + np.random.randn(m,1)/10

#训练模型
from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X,y)

#可视化模型
export_graphviz(
        tree_reg,
        out_file=("regression_tree.dot"),
        feature_names=["x1"],
        rounded=True,
        filled=True
    )
    
#notebook中显示
# 你的第二个决策树长这样
from IPython.display import Image
Image(filename="regression_tree.png",width=400,height=400,)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

输出的结果为:

在这里插入图片描述
然后对比一下树的深度对模型的影响,通过可视化的图形展示

from sklearn.tree import DecisionTreeRegressor

tree_reg1 = DecisionTreeRegressor(random_state=42, max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42, max_depth=3)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)

def plot_regression_predictions(tree_reg, X, y, axes=[0, 1, -0.2, 1], ylabel="$y$"):
    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel("$x_1$", fontsize=18)
    if ylabel:
        plt.ylabel(ylabel, fontsize=18, rotation=0)
    plt.plot(X, y, "b.")
    plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")

plt.figure(figsize=(11, 4))
plt.subplot(121)

plot_regression_predictions(tree_reg1, X, y)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
plt.text(0.21, 0.65, "Depth=0", fontsize=15)
plt.text(0.01, 0.2, "Depth=1", fontsize=13)
plt.text(0.65, 0.8, "Depth=1", fontsize=13)
plt.legend(loc="upper center", fontsize=18)
plt.title("max_depth=2", fontsize=14)

plt.subplot(122)

plot_regression_predictions(tree_reg2, X, y, ylabel=None)
for split, style in ((0.1973, "k-"), (0.0917, "k--"), (0.7718, "k--")):
    plt.plot([split, split], [-0.2, 1], style, linewidth=2)
for split in (0.0458, 0.1298, 0.2873, 0.9040):
    plt.plot([split, split], [-0.2, 1], "k:", linewidth=1)
plt.text(0.3, 0.5, "Depth=2", fontsize=13)
plt.title("max_depth=3", fontsize=14)

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

输出的结果为:

在这里插入图片描述
最后对比一下最小的叶子节点的样本数量

tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)
tree_reg1.fit(X, y)
tree_reg2.fit(X, y)

x1 = np.linspace(0, 1, 500).reshape(-1, 1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)

plt.figure(figsize=(11, 4))

plt.subplot(121)
plt.plot(X, y, "b.")
plt.plot(x1, y_pred1, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", fontsize=18, rotation=0)
plt.legend(loc="upper center", fontsize=18)
plt.title("No restrictions", fontsize=14)

plt.subplot(122)
plt.plot(X, y, "b.")
plt.plot(x1, y_pred2, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([0, 1, -0.2, 1.1])
plt.xlabel("$x_1$", fontsize=18)
plt.title("min_samples_leaf={}".format(tree_reg2.min_samples_leaf), fontsize=14)

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

输出的结果为:(可以看出不约束默认情况下,树模型会尽量拟合所有的点,但是改变参数数值后,模型变得横平竖直的,也可以设置其他的常用的参数玩一玩)
在这里插入图片描述
至此关于回归树模型的介绍就完毕了,撒花✿✿ヽ(°▽°)ノ✿

总结

文章中没有对决策树的由来和历史进行过多的讲解,直接有日常的小例子入手引入决策树模型,并层层递进式的讲解决策树中涉及的知识点,然后通过代码进行可视化的展示,并以图示的方式对比不同参数对于树模型的影响,最后要注意的是决策树不仅可以做分类也可以做回归任务

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/148197
推荐阅读
相关标签
  

闽ICP备14008679号