赞
踩
“独木难成林”
本文采用编译器:jupyter
决策树 (decision tree) 是一类常见的机器学习方法,顾名思义,决策树是基于树结构来进行决策的,这恰是人类在面临决 策问题时一种很自然的处理机制。
例如,我们要对“是否录用他作为机器学习算法工程师?”这样的问题进行决策时,通常会进行一系列的判断或“子决策”:我们先看“他是否发表过顶会论文?”如果是“没有”,则再看“是否是研究生?”如果是“是研究生”,再判断“他的项目是否和机器学习相关?”......最终我们得出决策。过程如图:
从上图可以看出,决策树也具有数据结构里 树 的概念的所有元素:边、根节点、叶子节点、深度等等。
上图所示的决策树深度为2,最多经过两次判断就可以走到叶子节点
很明显,决策树是一种非参数学习算法,天然的可以解决多分类问题(也可以解决回归问题),并且得到的结果具有非常好的可解释性。
- import numpy as np
- import matplotlib.pyplot as plt
-
- from sklearn import datasets
-
- iris = datasets.load_iris()
- # 由于要可视化展示,所以这里只取后两个特征
- X = iris.data[:,2:]
- y = iris.target
-
- # 第0类
- plt.scatter(X[y==0,0], X[y==0,1])
-
- # 第1类
- plt.scatter(X[y==1,0], X[y==1,1])
-
- # 第2类
- plt.scatter(X[y==2,0], X[y==2,1])
- plt.show()
- from sklearn.tree import DecisionTreeClassifier
-
- dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy")
- dt_clf.fit(X, y)
- """
- Out[4]:
- DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=2,
- max_features=None, max_leaf_nodes=None,
- min_impurity_decrease=0.0, min_impurity_split=None,
- min_samples_leaf=1, min_samples_split=2,
- min_weight_fraction_leaf=0.0, presort=False, random_state=None,
- splitter='best')
- """
-
- def plot_decision_boundary(model, axis):
-
- x0, x1 = np.meshgrid(
- np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1,1),
- np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1,1)
- )
- X_new = np.c_[x0.ravel(), x1.ravel()]
-
- y_predict = model.predict(X_new)
- zz = y_predict.reshape(x0.shape)
-
- from matplotlib.colors import ListedColormap
- custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D','#90CAF9'])
-
- plt.contourf(x0, x1, zz, linewidth=5, cmap=custom_cmap)
-
- plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
- plt.scatter(X[y==0,0], X[y==0,1])
- plt.scatter(X[y==1,0], X[y==1,1])
- plt.scatter(X[y==2,0], X[y==2,1])
- plt.show()
我们的到的决策树如下:
经过简单的实践时候我们应该好奇构造一颗决策树的方法,即每个节点在哪个维度上做划分?以及某个维度在哪个值上做划分?
决策树中最常用的标准之一是信息熵,信息熵表示的是随机变量不确定度的度量。熵越大,数据的不确定性越高;熵越小,数据的不确定性越低。表达式如下:
对于二分类任务:
所以对于上面提出的两个问题,我们的目的是使得划分后的信息熵降低
- import numpy as np
- import matplotlib.pyplot as plt
-
- def entropy(p):
- return -p * np.log(p) - (1-p) * np.log(1-p)
-
- x = np.linspace(0.01, 0.99, 200)
-
- plt.plot(x, entropy(x))
- plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。