当前位置:   article > 正文

决策树——DT分类wine数据_wine数据集决策树

wine数据集决策树
  1. #导入相应的包
  2. from sklearn import datasets #导入方法类,datasets是sklearn自带的数据集模块,里面有许多机器学习经典的数据集
  3. from sklearn.model_selection import train_test_split #导入将数据划分为训练数据和测试数据的模块
  4. from sklearn.preprocessing import LabelEncoder, OneHotEncoder #导入sklearn前处理模块,对数据进行打标签
  5. #from sklearn.externals.six import StringIO
  6. from six import StringIO #StringIO顾名思义就是在内存中读写str,为了后面用Graphviz进行可视化
  7. from sklearn import tree #导入决策树模块
  8. import pandas as pd
  9. import numpy as np
  10. import pydotplus ##pydot模块提供了一个完整的界面,用于在图表语言中的计算机处理和过程图表。pydotplus是旧pydot项目的一个改进版本,它为graphviz的点语言提供了一个python接口。
  11. from sklearn.metrics import accuracy_score
  12. # 获取所需数据集
  13. wine = datasets.load_wine() #此时的wine是一个dataframe形式
  14. print(wine)
  15. wine.keys() # 数据集关键字
  16. X = wine.data
  17. Y = wine.target
  18. print(X)
  19. print('==========')
  20. print(Y)
  21. #划分训练集和测试集,按照7:3的比例划分
  22. X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=2) #random_state= 随机数的种子,相同的种子,产生相同的随机数
  23. #而且random_state的数值会影响预测的结果,原因是不同种子,随机数不一样,导致测试集会发生变化。测试集变化,正确率有变化,说明了这个模型有提升的空间
  24. print(X_train.shape)
  25. print(X_test.shape)
  26. #构建决策树
  27. clf = tree.DecisionTreeClassifier(criterion="entropy")
  28. clf = clf.fit(X_train, Y_train)
  29. #决策树可视化
  30. dot_data = StringIO() #要把str写入StringIO,我们需要先创建一个StringIO
  31. tree.export_graphviz(clf, out_file = dot_data, #绘制决策树
  32. feature_names = wine.feature_names,
  33. class_names = wine.target_names,
  34. filled=True, rounded=True,
  35. special_characters=True)
  36. graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
  37. graph.write_pdf("treewine.pdf") #将画的树保存为PDF
  38. predict_results = clf.predict(X_test) # 使用模型对测试集进行预测
  39. print(accuracy_score(predict_results, Y_test))
  40. print(predict_results)

这是预测的准确率和结果:

划分的树:

对了,决策树的可视化需要先安装一个库和一个软件(没有的话),详情看下面的链接。

http://t.csdn.cn/qjlkD

环境:python3+Jupyter notebook

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号