当前位置:   article > 正文

决策树分类wine数据集(数据科学导论)_wine.data数据集

wine.data数据集

本次课程作业使用决策树模型处理wine数据集(分类任务)

知识点

数据预处理

这里主要是使用了归一化。当训练数据中不同属性的数值数据大小不一致(有的都是个位数,有的几百几千)时,我们需要将数据进行归一化。

这里使用了最简单的最大——最小值归一化。

  1. # X 数据的归一化
  2. for i in range(13):
  3. data_X.iloc[:,i] = (data_X.iloc[:,i]-data_X.iloc[:,i].min())/ (data_X.iloc[:,i].max()-data_X.iloc[:,i].min())

k折交叉验证

  当进行机器学习任务时,我们通常将数据集分为训练集和测试集,以评估模型的性能。然而,单次划分可能会导致模型在特定数据集上表现良好,但在其他数据集上表现较差。为了更准确地评估模型的性能,并确保其对不同数据集的泛化能力,可以使用k折交叉验证。

   k折交叉验证将数据集分成k个子集,称为折(fold)。然后,我们迭代k次,每次使用其中一个子集作为测试集,而其他k-1个子集作为训练集。这样可以更全面地评估模型的性能,因为我们对整个数据集的不同部分都进行了训练和测试。

Some metrics in ML classification

  本次实验是一个有标签的分类任务(非回归、聚类),所使用的指标有accuracy(准确率)、recall(召回率)、f1_score(综合前两者)

   Accuracy(准确率):表示分类器正确预测的样本数占总样本数的比例。计算公式为:准确率 = (TP+TN)/(TP+TN+FP+FN),其中TP表示真正类的样本数,TN表示真负类的样本数,FP表示将负类预测为正类的样本数,FN表示将正类预测为负类的样本数。准确率越高,分类器性能越好。

  Recall(召回率):又称为灵敏度或真正类率,表示所有真正类中被分类器正确预测出来的样本数占真正类样本数的比例。计算公式为:召回率 = TP/(TP+FN)。召回率越高,分类器对正类的识别能力越好。

   F1-score(F1值):综合考虑了准确率和召回率,是一个综合评价指标。计算公式为:F1 = 2 * (precision * recall) / (precision + recall),其中precision表示精确率,计算公式为:precision = TP/(TP+FP)。F1-score的取值范围是[0,1],当F1-score接近1时,分类器的性能较好。

recall_score()中的average参数:

需要注意的是,在多分类任务中,我们需要指定recall_score()中的average参数:

`average` 参数有几个可选值:

- `None`:即不进行任何平均化,会返回每个类别的召回率分数。这在多分类任务中,特别是类别不平衡的情况下,可以提供对每个类别的详细召回率信息。

- `micro`:计算总体的召回率。将每个类别的真阳性、假阴性、真阴性和假阳性的数量加总,然后计算总体的召回率。适用于类别不平衡的情况,尤其是当所有类别的重要性相同时。

- `macro`:对每个类别分别计算召回率,然后取平均值。不考虑类别的样本数或重要性,每个类别的召回率对最后的平均值都有同等权重。适用于各个类别的重要性相当,并且您希望每个类别都有相同影响力的情况。

- `weighted`:对每个类别分别计算召回率,然后加权平均。将每个类别的召回率乘以该类别的样本数占总体样本数的比例,然后将它们加总。适用于类别不平衡的情况,更关注样本较多的类别。

- `samples`:对每个样本的实际类别和预测类别分别计算召回率,并取平均值。适用于多标签分类任务,每个样本可以属于多个类别。

f1_score由acc与recall共同计算,故也有average参数

注:下文代码选用micro方式计算f1score

代码

  1. # 包引入
  2. import os
  3. import math
  4. import time
  5. import random
  6. import datetime
  7. import numpy as np
  8. import pandas as pd
  9. import matplotlib.pyplot as plt
  10. # ML库
  11. from sklearn.preprocessing import OneHotEncoder
  12. from sklearn.datasets import load_wine
  13. from sklearn.ensemble import GradientBoostingClassifier
  14. from sklearn.tree import DecisionTreeClassifier
  15. from sklearn import tree
  16. from sklearn.preprocessing import StandardScaler
  17. from sklearn.model_selection import KFold
  18. from sklearn.metrics import mean_squared_error
  19. from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
  1. # 载入wine数据集
  2. wine = load_wine()
  3. data_X = pd.DataFrame(wine.data,columns=wine.feature_names)
  4. data_Y = pd.DataFrame(wine.target)
  5. # 数据初步探索
  6. data_X.info()
  7. data_Y.info()
  8. print(data_X.head())
  9. print(data_Y.head())
  10. print(data_Y.iloc[:,:].value_counts())
  11. # X 数据的归一化
  12. for i in range(13):
  13. data_X.iloc[:,i] = (data_X.iloc[:,i]-data_X.iloc[:,i].min())/ (data_X.iloc[:,i].max()-data_X.iloc[:,i].min())
  14. print(data_X.head())
  15. # k折交叉验证
  16. X = data_X.values
  17. Y = data_Y.values
  18. k = 5
  19. kf = KFold(n_splits = k)
  20. # 找最优参数
  21. for hp_para in range(1,14):
  22. f1 = 0
  23. acc = 0
  24. for train_index, test_index in kf.split(X):
  25. # 在k折中取出训练集和验证集
  26. X_train, X_test = X[train_index], X[test_index]
  27. Y_train,Y_test = Y[train_index], Y[test_index]
  28. # 决策树模型
  29. model = DecisionTreeClassifier(max_depth = hp_para)
  30. # 模型训练
  31. model.fit(X_train, Y_train)
  32. # 预测数据
  33. Y_pred = pd.DataFrame(model.predict(X_test))
  34. Y_true = Y_test
  35. # 计算正确率
  36. f1 += f1_score(Y_true, Y_pred, average='macro')
  37. acc += accuracy_score(Y_true, Y_pred)
  38. f1/=5
  39. acc/=5
  40. # print loss
  41. print(hp_para, ' ', f1, " ", acc)
  42. # 训练与评测
  43. X = data_X.values
  44. Y = data_Y.values
  45. hp_para = 5
  46. model = DecisionTreeClassifier(max_depth = hp_para)
  47. model.fit(X,Y)
  48. Y_pred = model.predict(X)
  49. acc = accuracy_score(Y,Y_pred)
  50. rec = recall_score(Y,Y_pred,average='micro')
  51. f1 = f1_score(Y, Y_pred, average='micro')
  52. print(acc)
  53. print(rec)
  54. print(f1)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/483900
推荐阅读
相关标签
  

闽ICP备14008679号