当前位置:   article > 正文

机器学习算法:基于企鹅数据集的决策树分类预测_决策树鸟类分类数据集

决策树鸟类分类数据集

1 逻决策树的介绍和应用

1.1 决策树的介绍

决策树是一种常见的分类模型,在金融风控、医疗辅助诊断等诸多行业具有较为广泛的应用。决策树的核心思想是基于树结构对数据进行划分,这种思想是人类处理问题时的本能方法。例如在婚恋市场中,女方通常会先询问男方是否有房产,如果有房产再了解是否有车产,如果有车产再看是否有稳定工作……最后得出是否要深入了解的判断。

主要应用:

由于决策树模型中自变量与因变量的非线性关系以及决策树简单的计算方法,使得它成为集成学习中最为广泛使用的基模型。梯度提升树(GBDT),XGBoost以及LightGBM等先进的集成模型都采用了决策树作为基模型,在广告计算、CTR预估、金融风控等领域大放异彩,成为当今与神经网络相提并论的复杂模型,更是数据挖掘比赛中的常客。在新的研究中,南京大学周志华教授提出一种多粒度级联森林模型,创造了一种全新的基于决策树的深度集成方法,为我们提供了决策树发展的另一种可能。

同时决策树在一些明确需要可解释性或者提取分类规则的场景中被广泛应用,而其他机器学习模型在这一点很难做到。例如在医疗辅助系统中,为了方便专业人员发现错误,常常将决策树算法用于辅助病症检测。例如在一个预测哮喘患者的模型中,医生发现测试的许多高级模型的效果非常差。在他们运行了一个决策树模型后发现,算法认为剧烈咳嗽的病人患哮喘的风险很小。但医生非常清楚剧烈咳嗽一般都会被立刻检查治疗,这意味着患有剧烈咳嗽的哮喘病人都会马上得到收治。用于建模的数据认为这类病人风险很小,是因为所有这类病人都得到了及时治疗,所以极少有人在此之后患病或死亡。

1.2 相关流程

  • 了解 决策树 的理论知识
  • 掌握 决策树 的 sklearn 函数调用并将其运用在企鹅数据集的预测中

Part1 Demo实践

  • Step1:库函数导入
  • Step2:模型训练
  • Step3:数据和模型可视化
  • Step4:模型预测

Part2 基于企鹅(penguins)数据集的决策树分类实践

  • Step1:库函数导入
  • Step2:数据读取/载入
  • Step3:数据信息简单查看
  • Step4:可视化描述
  • Step5:利用 决策树模型 在二分类上 进行训练和预测
  • Step6:利用 决策树模型 在三分类(多分类)上 进行训练和预测

3 算法实战

3.1Demo实践

Step1: 库函数导入

  1. ## 基础函数库
  2. import numpy as np
  3. ## 导入画图库
  4. import matplotlib.pyplot as plt
  5. import seaborn as sns
  6. ## 导入决策树模型函数
  7. from sklearn.tree import DecisionTreeClassifier
  8. from sklearn import tree

Step2: 训练模型

  1. ##Demo演示LogisticRegression分类
  2. ## 构造数据集
  3. x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]])
  4. y_label = np.array([0, 1, 0, 1, 0, 1])
  5. ## 调用决策树回归模型
  6. tree_clf = DecisionTreeClassifier()
  7. ## 调用决策树模型拟合构造的数据集
  8. tree_clf = tree_clf.fit(x_fearures, y_label)

Step3: 数据和模型可视化(需要用到graphviz可视化库)

  1. ## 可视化构造的数据样本点
  2. plt.figure()
  3. plt.scatter(x_fearures[:,0],x_fearures[:,1], c=y_label, s=50, cmap='viridis')
  4. plt.title('Dataset')
  5. plt.show()

  1. ## 可视化决策树
  2. import graphviz
  3. dot_data = tree.export_graphviz(tree_clf, out_file=None)
  4. graph = graphviz.Source(dot_data)
  5. graph.render("pengunis")
'pengunis.pdf'

Step4:模型预测

  1. ## 创建新样本
  2. x_fearures_new1 = np.array([[0, -1]])
  3. x_fearures_new2 = np.array([[2, 1]])
  4. ## 在训练集和测试集上分布利用训练好的模型进行预测
  5. y_label_new1_predict = tree_clf.predict(x_fearures_new1)
  6. y_label_new2_predict = tree_clf.predict(x_fearures_new2)
  7. print('The New point 1 predict class:\n',y_label_new1_predict)
  8. print('The New point 2 predict class:\n',y_label_new2_predict)
  1. The New point 1 predict class:
  2. [1]
  3. The New point 2 predict class:
  4. [0]

3.2 基于penguins_raw数据集的决策树实战

在实践的最开始,我们首先需要导入一些基础的函数库包括:numpy (Python进行科学计算的基础软件包),pandas(pandas是一种快速,强大,灵活且易于使用的开源数据分析和处理工具),matplotlib和seaborn绘图。

  1. #下载需要用到的数据集
  2. !wget https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/6tree/penguins_raw.csv
  1. --2023-03-22 16:21:32-- https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/6tree/penguins_raw.csv
  2. 正在解析主机 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)... 49.7.22.39
  3. 正在连接 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)|49.7.22.39|:443... 已连接。
  4. 已发出 HTTP 请求,正在等待回应... 200 OK
  5. 长度: 53098 (52K) [text/csv]
  6. 正在保存至: “penguins_raw.csv”
  7. penguins_raw.csv 100%[===================>] 51.85K --.-KB/s in 0.04s
  8. 2023-03-22 16:21:33 (1.23 MB/s) - 已保存 “penguins_raw.csv” [53098/53098])

Step1:函数库导入

  1. ## 基础函数库
  2. import numpy as np
  3. import pandas as pd
  4. ## 绘图函数库
  5. import matplotlib.pyplot as plt
  6. import seaborn as sns

本次我们选择企鹅数据(palmerpenguins)进行方法的尝试训练,该数据集一共包含8个变量,其中7个特征变量,1个目标分类变量。共有150个样本,目标变量为 企鹅的类别 其都属于企鹅类的三个亚属,分别是(Adélie, Chinstrap and Gentoo)。包含的三种种企鹅的七个特征,分别是所在岛屿,嘴巴长度,嘴巴深度,脚蹼长度,身体体积,性别以及年龄。

变量描述
speciesa factor denoting penguin species
islanda factor denoting island in Palmer Archipelago, Antarctica
bill_length_mma number denoting bill length
bill_depth_mma number denoting bill depth
flipper_length_mman integer denoting flipper length
body_mass_gan integer denoting body mass
sexa factor denoting penguin sex
yearan integer denoting the study year

Step2:数据读取/载入

  1. ## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式
  2. data = pd.read_csv('./penguins_raw.csv')
  1. ## 为了方便我们仅选取四个简单的特征,有兴趣的同学可以研究下其他特征的含义以及使用方法
  2. data = data[['Species','Culmen Length (mm)','Culmen Depth (mm)',
  3. 'Flipper Length (mm)','Body Mass (g)']]

Step3:数据信息简单查看

  1. ## 利用.info()查看数据的整体信息
  2. data.info()
  1. <class 'pandas.core.frame.DataFrame'>
  2. RangeIndex: 344 entries, 0 to 343
  3. Data columns (total 5 columns):
  4. # Column Non-Null Count Dtype
  5. --- ------ -------------- -----
  6. 0 Species 344 non-null object
  7. 1 Culmen Length (mm) 342 non-null float64
  8. 2 Culmen Depth (mm) 342 non-null float64
  9. 3 Flipper Length (mm) 342 non-null float64
  10. 4 Body Mass (g) 342 non-null float64
  11. dtypes: float64(4), object(1)
  12. memory usage: 13.6+ KB
  1. ## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
  2. data.head()
SpeciesCulmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
0Adelie Penguin (Pygoscelis adeliae)39.118.7181.03750.0
1Adelie Penguin (Pygoscelis adeliae)39.517.4186.03800.0
2Adelie Penguin (Pygoscelis adeliae)40.318.0195.03250.0
3Adelie Penguin (Pygoscelis adeliae)NaNNaNNaNNaN
4Adelie Penguin (Pygoscelis adeliae)36.719.3193.03450.0

这里我们发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值,可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补,还有其他例如“中位数填补、平均数填补”的缺失值处理方法有兴趣的同学也可以尝试。

data = data.fillna(-1)
data.tail()
SpeciesCulmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
339Chinstrap penguin (Pygoscelis antarctica)55.819.8207.04000.0
340Chinstrap penguin (Pygoscelis antarctica)43.518.1202.03400.0
341Chinstrap penguin (Pygoscelis antarctica)49.618.2193.03775.0
342Chinstrap penguin (Pygoscelis antarctica)50.819.0210.04100.0
343Chinstrap penguin (Pygoscelis antarctica)50.218.7198.03775.0
  1. ## 其对应的类别标签为'Adelie Penguin', 'Gentoo penguin', 'Chinstrap penguin'三种不同企鹅的类别。
  2. data['Species'].unique()
  1. array(['Adelie Penguin (Pygoscelis adeliae)',
  2. 'Gentoo penguin (Pygoscelis papua)',
  3. 'Chinstrap penguin (Pygoscelis antarctica)'], dtype=object)
  1. ## 利用value_counts函数查看每个类别数量
  2. pd.Series(data['Species']).value_counts()
  1. Adelie Penguin (Pygoscelis adeliae) 152
  2. Gentoo penguin (Pygoscelis papua) 124
  3. Chinstrap penguin (Pygoscelis antarctica) 68
  4. Name: Species, dtype: int64
  1. ## 对于特征进行一些统计描述
  2. data.describe()
Culmen Length (mm)Culmen Depth (mm)Flipper Length (mm)Body Mass (g)
count344.000000344.000000344.000000344.000000
mean43.66075617.045640199.7412794177.319767
std6.4289572.40561420.806759861.263227
min-1.000000-1.000000-1.000000-1.000000
25%39.20000015.500000190.0000003550.000000
50%44.25000017.300000197.0000004025.000000
75%48.50000018.700000213.0000004750.000000
max59.60000021.500000231.0000006300.000000

Step4:可视化描述

  1. ## 特征与标签组合的散点可视化
  2. sns.pairplot(data=data, diag_kind='hist', hue= 'Species')
  3. plt.show()

从上图可以发现,在2D情况下不同的特征组合对于不同类别的企鹅的散点分布,以及大概的区分能力。Culmen Lenth与其他特征的组合散点的重合较少,所以对于数据集的划分能力最好。

我们发现

  1. '''为了方便我们将标签转化为数字
  2. 'Adelie Penguin (Pygoscelis adeliae)' ------0
  3. 'Gentoo penguin (Pygoscelis papua)' ------1
  4. 'Chinstrap penguin (Pygoscelis antarctica) ------2 '''
  5. def trans(x):
  6. if x == data['Species'].unique()[0]:
  7. return 0
  8. if x == data['Species'].unique()[1]:
  9. return 1
  10. if x == data['Species'].unique()[2]:
  11. return 2
  12. data['Species'] = data['Species'].apply(trans)
  1. for col in data.columns:
  2. if col != 'Species':
  3. sns.boxplot(x='Species', y=col, saturation=0.5, palette='pastel', data=data)
  4. plt.title(col)
  5. plt.show()



 

利用箱型图我们也可以得到不同类别在不同特征上的分布差异情况。

  1. # 选取其前三个特征绘制三维散点图
  2. from mpl_toolkits.mplot3d import Axes3D
  3. fig = plt.figure(figsize=(10,8))
  4. ax = fig.add_subplot(111, projection='3d')
  5. data_class0 = data[data['Species']==0].values
  6. data_class1 = data[data['Species']==1].values
  7. data_class2 = data[data['Species']==2].values
  8. # 'setosa'(0), 'versicolor'(1), 'virginica'(2)
  9. ax.scatter(data_class0[:,0], data_class0[:,1], data_class0[:,2],label=data['Species'].unique()[0])
  10. ax.scatter(data_class1[:,0], data_class1[:,1], data_class1[:,2],label=data['Species'].unique()[1])
  11. ax.scatter(data_class2[:,0], data_class2[:,1], data_class2[:,2],label=data['Species'].unique()[2])
  12. plt.legend()
  13. plt.show()

Step5:利用 决策树模型 在二分类上 进行训练和预测

  1. ## 为了正确评估模型性能,将数据划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能。
  2. from sklearn.model_selection import train_test_split
  3. ## 选择其类别为01的样本 (不包括类别为2的样本)
  4. data_target_part = data[data['Species'].isin([0,1])][['Species']]
  5. data_features_part = data[data['Species'].isin([0,1])][['Culmen Length (mm)','Culmen Depth (mm)',
  6. 'Flipper Length (mm)','Body Mass (g)']]
  7. ## 测试集大小为20%, 80%/20%分
  8. x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part, test_size = 0.2, random_state = 2020)
  1. ## 从sklearn中导入决策树模型
  2. from sklearn.tree import DecisionTreeClassifier
  3. from sklearn import tree
  4. ## 定义 决策树模型
  5. clf = DecisionTreeClassifier(criterion='entropy')
  6. # 在训练集上训练决策树模型
  7. clf.fit(x_train, y_train)
  1. DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='entropy',
  2. max_depth=None, max_features=None, max_leaf_nodes=None,
  3. min_impurity_decrease=0.0, min_impurity_split=None,
  4. min_samples_leaf=1, min_samples_split=2,
  5. min_weight_fraction_leaf=0.0, presort='deprecated',
  6. random_state=None, splitter='best')
  1. ## 可视化
  2. import graphviz
  3. dot_data = tree.export_graphviz(clf, out_file=None)
  4. graph = graphviz.Source(dot_data)
  5. graph.render("penguins")
'penguins.pdf'
  1. ## 在训练集和测试集上分布利用训练好的模型进行预测
  2. train_predict = clf.predict(x_train)
  3. test_predict = clf.predict(x_test)
  4. from sklearn import metrics
  5. ## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
  6. print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
  7. print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))
  8. ## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
  9. confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
  10. print('The confusion matrix result:\n',confusion_matrix_result)
  11. # 利用热力图对于结果进行可视化
  12. plt.figure(figsize=(8, 6))
  13. sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
  14. plt.xlabel('Predicted labels')
  15. plt.ylabel('True labels')
  16. plt.show()
  1. The accuracy of the Logistic Regression is: 0.9954545454545455
  2. The accuracy of the Logistic Regression is: 1.0
  3. The confusion matrix result:
  4. [[31 0]
  5. [ 0 25]]

我们可以发现其准确度为1,代表所有的样本都预测正确了。

Step6:利用 决策树模型 在三分类(多分类)上 进行训练和预测

  1. ## 测试集大小为20%, 80%/20%分
  2. x_train, x_test, y_train, y_test = train_test_split(data[['Culmen Length (mm)','Culmen Depth (mm)',
  3. 'Flipper Length (mm)','Body Mass (g)']], data[['Species']], test_size = 0.2, random_state = 2020)
  4. ## 定义 决策树模型
  5. clf = DecisionTreeClassifier()
  6. # 在训练集上训练决策树模型
  7. clf.fit(x_train, y_train)
  1. DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
  2. max_depth=None, max_features=None, max_leaf_nodes=None,
  3. min_impurity_decrease=0.0, min_impurity_split=None,
  4. min_samples_leaf=1, min_samples_split=2,
  5. min_weight_fraction_leaf=0.0, presort='deprecated',
  6. random_state=None, splitter='best')
  1. ## 在训练集和测试集上分布利用训练好的模型进行预测
  2. train_predict = clf.predict(x_train)
  3. test_predict = clf.predict(x_test)
  4. ## 由于决策树模型是概率预测模型(前文介绍的 p = p(y=1|x,\theta)),所有我们可以利用 predict_proba 函数预测其概率
  5. train_predict_proba = clf.predict_proba(x_train)
  6. test_predict_proba = clf.predict_proba(x_test)
  7. print('The test predict Probability of each class:\n',test_predict_proba)
  8. ## 其中第一列代表预测为0类的概率,第二列代表预测为1类的概率,第三列代表预测为2类的概率。
  9. ## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
  10. print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
  11. print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))
  1. The test predict Probability of each class:
  2. [[0. 0. 1.]
  3. [0. 1. 0.]
  4. [0. 1. 0.]
  5. [1. 0. 0.]
  6. [1. 0. 0.]
  7. [0. 0. 1.]
  8. [0. 0. 1.]
  9. [1. 0. 0.]
  10. [0. 1. 0.]
  11. [1. 0. 0.]
  12. [0. 1. 0.]
  13. [0. 1. 0.]
  14. [1. 0. 0.]
  15. [0. 1. 0.]
  16. [0. 1. 0.]
  17. [0. 1. 0.]
  18. [1. 0. 0.]
  19. [0. 1. 0.]
  20. [1. 0. 0.]
  21. [1. 0. 0.]
  22. [0. 0. 1.]
  23. [1. 0. 0.]
  24. [0. 0. 1.]
  25. [1. 0. 0.]
  26. [1. 0. 0.]
  27. [1. 0. 0.]
  28. [0. 1. 0.]
  29. [1. 0. 0.]
  30. [0. 1. 0.]
  31. [1. 0. 0.]
  32. [1. 0. 0.]
  33. [0. 0. 1.]
  34. [0. 0. 1.]
  35. [0. 1. 0.]
  36. [1. 0. 0.]
  37. [0. 1. 0.]
  38. [0. 1. 0.]
  39. [1. 0. 0.]
  40. [1. 0. 0.]
  41. [0. 1. 0.]
  42. [0. 0. 1.]
  43. [1. 0. 0.]
  44. [0. 1. 0.]
  45. [1. 0. 0.]
  46. [1. 0. 0.]
  47. [0. 0. 1.]
  48. [0. 0. 1.]
  49. [1. 0. 0.]
  50. [1. 0. 0.]
  51. [0. 1. 0.]
  52. [1. 0. 0.]
  53. [1. 0. 0.]
  54. [0. 1. 0.]
  55. [0. 1. 0.]
  56. [0. 0. 1.]
  57. [0. 0. 1.]
  58. [0. 1. 0.]
  59. [1. 0. 0.]
  60. [1. 0. 0.]
  61. [1. 0. 0.]
  62. [0. 1. 0.]
  63. [0. 1. 0.]
  64. [0. 0. 1.]
  65. [0. 0. 1.]
  66. [1. 0. 0.]
  67. [0. 1. 0.]
  68. [0. 0. 1.]
  69. [1. 0. 0.]
  70. [1. 0. 0.]]
  71. The accuracy of the Logistic Regression is: 0.9963636363636363
  72. The accuracy of the Logistic Regression is: 0.9565217391304348
  1. ## 查看混淆矩阵
  2. confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
  3. print('The confusion matrix result:\n',confusion_matrix_result)
  4. # 利用热力图对于结果进行可视化
  5. plt.figure(figsize=(8, 6))
  6. sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
  7. plt.xlabel('Predicted labels')
  8. plt.ylabel('True labels')
  9. plt.show()
  1. The confusion matrix result:
  2. [[30 1 0]
  3. [ 0 23 0]
  4. [ 2 0 13]]

3.3 重要知识点

3.3.1 决策树构建的伪代码

输入: 训练集D={($x_1$,$y_1$),($x_2$,$y_2$),....,($x_m$,$y_m$)};
特征集A={$a_1$,$a_2$,....,$a_d$}

输出: 以node为根节点的一颗决策树

过程:函数TreeGenerate($D$,$A$)

  1. 生成节点node
  2. $if$ $D$中样本全书属于同一类别$C$ $then$:
  3. ----将node标记为$C$类叶节点;$return$
  4. $if$ $A$ = 空集 OR D中样本在$A$上的取值相同 $then$:
  5. ----将node标记为叶节点,其类别标记为$D$中样本数最多的类;$return$
  6. 从 $A$ 中选择最优划分属性 $a_*$;
  7. $for$ $a_$ 的每一个值 $a_^v$ $do$:
  8. ----为node生成一个分支,令$D_v$表示$D$中在$a_$上取值为$a_^v$的样本子集;
  9. ----$if$ $D_v$ 为空 $then$:
  10. --------将分支节点标记为叶节点,其类别标记为$D$中样本最多的类;$then$
  11. ----$else$:
  12. --------以 TreeGenerate($D_v$,$A${$a_*$})为分支节点

决策树的构建过程是一个递归过程。函数存在三种返回状态:(1)当前节点包含的样本全部属于同一类别,无需继续划分;(2)当前属性集为空或者所有样本在某个属性上的取值相同,无法继续划分;(3)当前节点包含的样本集合为空,无法划分。

3.3.2 划分选择

从上述伪代码中我们发现,决策树的关键在于line6.从$A$中选择最优划分属性$

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/845148
推荐阅读
相关标签