当前位置:   article > 正文

基于sklearn的决策树模型——以iris数据集为例_利用数据集iris,从中随机抽取50个样品作为testdata,剩下的作为traindata,进行判

利用数据集iris,从中随机抽取50个样品作为testdata,剩下的作为traindata,进行判别

Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

iris以鸢尾花的特征作为数据来源,常用在分类操作中。该数据集由3种不同类型的鸢尾花的各50个样本数据构成。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的。

该数据集包含了4个属性:

& Sepal.Length(花萼长度),单位是cm;

& Sepal.Width(花萼宽度),单位是cm;

& Petal.Length(花瓣长度),单位是cm;

& Petal.Width(花瓣宽度),单位是cm;

种类:Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),以及Iris Virginica(维吉尼亚鸢尾)。

 

  1. # -*- coding:utf-8 -*-
  2. # 导入决策树模型
  3. from sklearn import tree
  4. # 导入数据集
  5. from sklearn.datasets import load_iris
  6. # 导入数据集拆分模块
  7. from sklearn.model_selection import train_test_split
  8. # 导入准确率计算模块
  9. from sklearn.metrics import accuracy_score
  10. # 加载数据集
  11. iris = load_iris()
  12. # 查看数据集描述
  13. print(iris.DESCR)
  14. # 分割数据集
  15. x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
  16. # 训练模型
  17. clf = tree.DecisionTreeClassifier(criterion='entropy')
  18. clf.fit(x_train, y_train)
  19. # 预测
  20. y_pred = clf.predict(x_test)
  21. # 计算准确率
  22. score = accuracy_score(y_test, y_pred)
  23. print("准确率:", score)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/701824
推荐阅读
相关标签
  

闽ICP备14008679号