当前位置:   article > 正文

基于决策树构建鸢尾花数据的分类模型并绘制决策树模型_基于决策树对鸢尾花分类主要研究内容

基于决策树对鸢尾花分类主要研究内容

    决策树模型是一种很简单但是却很经典的机器学习模型,经历多次的改进和发展,现在已经有很多成熟的树模型,比如早期的ID3算法、现在的C45模型、CART树模型等等,决策树一个很大的优点就是可解释性比较强,当然这也是相对于其他模型来说的,决策树模型在训练完成后还可以通过绘制模型图片,详细了解在树中每一个分裂节点的位置是使用什么属性进行的,本文是硕士论文撰写期间一个简单的小实验,这里整理出来,留作学习记录,下面是具体的实现:

  1. #!usr/bin/env python
  2. #encoding:utf-8
  3. from __future__ import division
  4. '''
  5. __Author__:沂水寒城
  6. 功能:使用决策树模型来对鸢尾花数据进行分析预测
  7. 绘制DT模型
  8. '''
  9. import os
  10. import csv
  11. import csv
  12. from sklearn.tree import *
  13. from sklearn.model_selection import train_test_split
  14. import matplotlib as mpl
  15. mpl.use('Agg')
  16. import matplotlib.pyplot as plt
  17. import pydotplus
  18. from sklearn.externals.six import StringIO #生成StringIO对象
  19. import graphviz
  20. os.environ["PATH"]+=os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin/'
  21. from sklearn.datasets import load_iris
  22. from sklearn import tree
  23. iris = load_iris()
  24. def read_data(test_data='fake_result/features_cal.csv',n=1,label=1):
  25. '''
  26. 加载数据的功能
  27. n:特征数据起始位
  28. label:是否是监督样本数据
  29. '''
  30. csv_reader=csv.reader(open(test_data))
  31. data_list=[]
  32. for one_line in csv_reader:
  33. data_list.append(one_line)
  34. x_list=[]
  35. y_list=[]
  36. label_dict={'setosa':0,'versicolor':1,'virginica':2}
  37. for one_line in data_list[1:]:
  38. if label==1:
  39. biaoqian=label_dict[one_line[-1]]
  40. #biaoqian=int(one_line[-1])
  41. y_list.append(int(biaoqian)) #标志位
  42. one_list=[float(o) for o in one_line[n:-1]]
  43. x_list.append(one_list)
  44. else:
  45. one_list=[float(o) for o in one_line[n:]]
  46. x_list.append(one_list)
  47. return x_list, y_list
  48. def split_data(data_list, y_list, ratio=0.30):
  49. '''
  50. 按照指定的比例,划分样本数据集
  51. ratio: 测试数据的比率
  52. '''
  53. X_train, X_test, y_train, y_test = train_test_split(data_list, y_list, test_size=ratio, random_state=50)
  54. print '--------------------------------split_data shape-----------------------------------'
  55. print len(X_train), len(y_train)
  56. print len(X_test), len(y_test)
  57. return X_train, X_test, y_train, y_test
  58. def DT_model(data='XD_new_encoding.csv',rationum=0.20):
  59. '''
  60. 使用决策树模型
  61. '''
  62. x_list,y_list=read_data(test_data=data,n=1,label=1)
  63. X_train,X_test,y_train,y_test=split_data(x_list, y_list, ratio=rationum)
  64. DT=DecisionTreeClassifier()
  65. DT.fit(X_train,y_train)
  66. y_predict=DT.predict(X_test)
  67. print 'DT model accuracy: ', DT.score(X_test,y_test)
  68. dot_data=StringIO()
  69. export_graphviz(DT,out_file=dot_data,class_names=iris.target_names,feature_names=iris.feature_names,filled=True,
  70. rounded=True,special_characters=True)
  71. graph=pydotplus.graph_from_dot_data(dot_data.getvalue())
  72. graph.write_png('iris_result.png')
  73. if __name__ == '__main__':
  74. DT_model(data='iris.csv',rationum=0.30)

输出结果为:

  1. --------------------------------split_data shape-----------------------------------
  2. 105 105
  3. 45 45
  4. DT model accuracy: 0.9555555555555556
  5. [Finished in 1.6s]

    其中,iris.csv是sklearn中的鸢尾花数据,具体的保存方法在我之前的博客中已经有了,感兴趣的话可以看一下

    DT模型如下:

    

     直观看起来还是挺漂亮的,仔细看的话足够清晰了,对于详细分析数据而言是很有帮助的。

    下面是决策树模型图构建过程中的原始数据

  1. digraph Tree {
  2. node [shape=box] ;
  3. 0 [label="X[3] <= 0.8\ngini = 0.666\nsamples = 105\nvalue = [36, 33, 36]"] ;
  4. 1 [label="gini = 0.0\nsamples = 36\nvalue = [36, 0, 0]"] ;
  5. 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
  6. 2 [label="X[3] <= 1.65\ngini = 0.499\nsamples = 69\nvalue = [0, 33, 36]"] ;
  7. 0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
  8. 3 [label="X[2] <= 5.0\ngini = 0.157\nsamples = 35\nvalue = [0, 32, 3]"] ;
  9. 2 -> 3 ;
  10. 4 [label="gini = 0.0\nsamples = 31\nvalue = [0, 31, 0]"] ;
  11. 3 -> 4 ;
  12. 5 [label="X[0] <= 6.05\ngini = 0.375\nsamples = 4\nvalue = [0, 1, 3]"] ;
  13. 3 -> 5 ;
  14. 6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
  15. 5 -> 6 ;
  16. 7 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
  17. 5 -> 7 ;
  18. 8 [label="X[2] <= 4.85\ngini = 0.057\nsamples = 34\nvalue = [0, 1, 33]"] ;
  19. 2 -> 8 ;
  20. 9 [label="X[1] <= 3.1\ngini = 0.375\nsamples = 4\nvalue = [0, 1, 3]"] ;
  21. 8 -> 9 ;
  22. 10 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
  23. 9 -> 10 ;
  24. 11 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
  25. 9 -> 11 ;
  26. 12 [label="gini = 0.0\nsamples = 30\nvalue = [0, 0, 30]"] ;
  27. 8 -> 12 ;
  28. }

    如果需要pdf版本的模型图也可以,下面是生成的PDF数据(因无法上传文件,这里添加的图片的后缀名,使用时直接删除图片后缀即可)

    

    

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/701807
推荐阅读
相关标签
  

闽ICP备14008679号