当前位置:   article > 正文

决策树对鸢尾花数据集分类_构建一个决策树对鸢尾花数据集(iris)进行分类,描述主要过程并粘贴实现代码

构建一个决策树对鸢尾花数据集(iris)进行分类,描述主要过程并粘贴实现代码

一、决策树

GiNi系数和熵的纯度的评价标准,基尼指数是信息熵中﹣logP 在P = 1处一阶泰勒展开后的结果。所以两者都可以用来度量数据集的纯度,用于描述决策树节点的纯度
相关增益越大,分类越好,但对于多叉树,如果不限制分裂多少支,一次分裂就可以将信息熵降为0
因此需要平衡分裂情况与信息增益,信息增益率:信息增益 除以 类别 本身的熵作为惩罚措施

二、相关代码

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import matplotlib as mpl

iris = load_iris()
data = pd.DataFrame(iris.data)
data.columns = iris.feature_names
data['Species'] = load_iris().target
print(data)

x = data.iloc[:, 2:4]  # 花瓣长度和宽度
y = data.iloc[:, -1]

x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.75, random_state=42)    # 分割训练集和测试集

tree_clf = DecisionTreeClassifier(max_depth=8, criterion='gini')      # 决策树的训练
tree_clf.fit(x_train, y_train)
y_test_hat = tree_clf.predict(x_test)
print("acc score:", accuracy_score(y_test, y_test_hat))

print(tree_clf.feature_importances_)

export_graphviz(
    tree_clf,
    out_file="./iris_tree.dot",
    feature_names=iris.feature_names[2:4],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)
print(tree_clf.predict_proba([[5, 1.5]]))
print(tree_clf.predict([[5, 1.5]]))

depth = np.arange(1, 15)
err_list = []
for d in depth:      #  不同深度树的容错率
    print(d)
    clf = DecisionTreeClassifier(criterion='gini', max_depth=d)
    clf.fit(x_train, y_train)
    y_test_hat = clf.predict(x_test)
    result = (y_test_hat == y_test)
    if d == 1:
        print(result)
    err = 1 - np.mean(result)
    print(100 * err)
    err_list.append(err)
    print(d, ' 错误率:%.2f%%' % (100 * err))

mpl.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(facecolor='w')
plt.plot(depth, err_list, 'ro-', lw=2)
plt.xlabel('决策树深度', fontsize=15)
plt.ylabel('错误率', fontsize=15)
plt.title('决策树深度和过拟合', fontsize=18)
plt.grid(True)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

为了得到最合适的深度,不同深度的错误率统计:
在这里插入图片描述

三、Graphviz生成决策树

调用命令,借助dot文件生成如下图决策树:
./dot -Tpng ~/PycharmProjects/mlstudy/bjsxt/iris_tree.dot -o ~/PycharmProjects/mlstudy/bjsxt/iris_tree.png

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/701837
推荐阅读
相关标签
  

闽ICP备14008679号