赞
踩
决策树是一种常见的分类模型,在金融风控、医疗辅助诊断等诸多行业具有较为广泛的应用。决策树的核心思想是基于树结构对数据进行划分,这种思想是人类处理问题时的本能方法。例如在婚恋市场中,女方通常会先询问男方是否有房产,如果有房产再了解是否有车产,如果有车产再看是否有稳定工作……最后得出是否要深入了解的判断。
决策树的主要优点:
决策树的主要缺点:
由于决策树模型中自变量与因变量的非线性关系以及决策树简单的计算方法,使得它成为集成学习中最为广泛使用的基模型。梯度提升树,XGBoost以及LightGBM等先进的集成模型都采用了决策树作为基模型,在广告计算、CTR预估、金融风控等领域大放异彩 ,同时决策树在一些明确需要可解释性或者提取分类规则的场景中被广泛应用,而其他机器学习模型在这一点很难做到。例如在医疗辅助系统中,为了方便专业人员发现错误,常常将决策树算法用于辅助病症检测。
通过sklearn实现决策树分类
- import numpy as np
- import matplotlib.pyplot as plt
-
- from sklearn import datasets
-
- iris = datasets.load_iris()
- X = iris.data[:,2:]
- y = iris.target
-
- plt.scatter(X[y==0,0],X[y==0,1])
- plt.scatter(X[y==1,0],X[y==1,1])
- plt.scatter(X[y==2,0],X[y==2,1])
-
- plt.show()
- from sklearn.tree import DecisionTreeClassifier
-
- tree = DecisionTreeClassifier(max_depth=2,criterion="entropy")
- tree.fit(X,y)
依据模型绘制决策树的决策边界
- def plot_decision_boundary(model,axis):
- x0,x1 = np.meshgrid(
- np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
- np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
- )
- X_new = np.c_[x0.ravel(),x1.ravel()]
- y_predict = model.predict(X_new)
- zz = y_predict.reshape(x0.shape)
-
- from matplotlib.colors import ListedColormap
- custom_map = ListedColormap(["#EF9A9A","#FFF59D","#90CAF9"])
-
- plt.contourf(x0,x1,zz,linewidth=5,cmap=custom_map)
-
- plot_decision_boundary(tree,axis=[0.5,7.5,0,3])
- plt.scatter(X[y==0,0],X[y==0,1])
- plt.scatter(X[y==1,0],X[y==1,1])
- plt.scatter(X[y==2,0],X[y==2,1])
- plt.show()
实战:
Step: 库函数导入
- import numpy as np
-
- ## 导入画图库
- import matplotlib.pyplot as plt
- import seaborn as sns
-
- ## 导入决策树模型函数
- from sklearn.tree import DecisionTreeClassifier
- from sklearn import tree
Step: 训练模型
- ## 构造数据集
- x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]])
- y_label = np.array([0, 1, 0, 1, 0, 1])
-
- ## 调用决策树回归模型
- tree_clf = DecisionTreeClassifier()
-
- ## 调用决策树模型拟合构造的数据集
- tree_clf = tree_clf.fit(x_fearures, y_label)
Step: 数据和模型可视化
- plt.figure()
- plt.scatter(x_fearures[:,0],x_fearures[:,1], c=y_label, s=50, cmap='viridis')
- plt.title('Dataset')
- plt.show()
-
- import graphviz
- dot_data = tree.export_graphviz(tree_clf, out_file=None)
- graph = graphviz.Source(dot_data)
- graph.render("pengunis")
Step:模型预测
- x_fearures_new1 = np.array([[0, -1]])
- x_fearures_new2 = np.array([[2, 1]])
-
- ## 在训练集和测试集上分布利用训练好的模型进行预测
- y_label_new1_predict = tree_clf.predict(x_fearures_new1)
- y_label_new2_predict = tree_clf.predict(x_fearures_new2)
-
- print('The New point 1 predict class:\n',y_label_new1_predict)
- print('The New point 2 predict class:\n',y_label_new2_predict)
ID3 树是基于信息增益构建的决策树
- import numpy as np
- import matplotlib.pyplot as plt
-
- def entropy(p):
- return -p*np.log(p)-(1-p)*np.log(1-p)
-
- x = np.linspace(0.01,0.99,200)
- plt.plot(x,entropy(x))
- plt.show()
信息增益
信息熵是一种衡量数据混乱程度的指标,信息熵越小,则数据的“纯度”越高
ID3算法步骤
信息增益率计算公式
如果某个特征的特征值种类较多,则其内在信息值就越大。特征值种类越多,除以的系数就越大。
如果某个特征的特征值种类较小,则其内在信息值就越小
C4.5算法优缺点
Cart模型是一种决策树模型,它即可以用于分类,也可以用于回归
(1)决策树生成:用训练数据生成决策树,生成树尽可能大
(2)决策树剪枝:基于损失函数最小化的剪枝,用验证数据对生成的数据进行剪枝。
分类和回归树模型采用不同的最优化策略。Cart回归树使用平方误差最小化策略,Cart分类生成树采用的基尼指数最小化策略。
Criterion这个参数正是用来决定模型特征选择的计算方法的。sklearn提供了两种选择:
输入”entropy“,使用信息熵(Entropy)
输入”gini“,使用基尼系数(Gini Impurity)
基尼指数:
信息增益(ID3)、信息增益率值越大(C4.5),则说明优先选择该特征。
基尼指数值越小(cart),则说明优先选择该特征。
剪枝是决策树学习算法对付过拟合的主要手段。
在决策树学习中,为了尽可能正确分类训练样本,结点划分过程将不断重复,有时会造成决策树分支过多,这时就可能因训练样本学得"太好"了,以致于把训练集自身的一些特点当作所有数据都具有的一般性质而导致过拟合
决策树的构建过程是一个递归的过层,所以必须确定停止条件,否则过程将不会停止,树会不停生长。
先剪枝和后剪枝
先剪枝就是提前结束决策树的增长。
后剪枝是在决策树生长完成之后再进行剪枝的过程。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。