当前位置:   article > 正文

决策树_机器学习_cart算法是否购买计算机

cart算法是否购买计算机

前言:   

     决策树也是一种简单的分类模型,主要应用在反信用卡诈骗,骚扰邮件过滤。

这里主要结合:

    1: 乳腺癌预测例子

     2  破腹产预测例子

         https://archive.ics.uci.edu/ml/datasets/Caesarian+Section+Classification+Dataset

   3   学生购买电脑的例子讲述一下相关的原理

       生成的树如下:

     

 

   

   

 

数据集2,预测是否会购买电脑

   

    

 

目录:

  1.      决策树简介
  2.      决策树算法流程
  3.      决策树三种主要算法(ID3,C4.5, CART)
  4.      决策树剪枝
  5.      例子
  6.      问题

一  决策树简介

     垃圾邮件分类决策树模型:

    

    

   决策树主要由四部分组成:

      根节点(属性)

      子节点(属性)

       分枝(属性值)

      叶子节点(标签

二  决策树流程:

        2.1 CreateNode(创建节点)

              

  

     2.2  总体流程:

           

 

    2.3 熵

        反应离散随机事件出现的概率

       

       熵也是一个有约束的凸函数

 

     

     可以通过琴声不等式证明p 取1/n 时候值最大,且范围值为

     证明:

       由定理1:

       

       其中

       

        因为是上凸,所有

       

       

        

      

        可以看出下面三点

  1:     

   2:       时候为熵的极大值

   3:         分类种类越多,熵值越大。

 

  

      

         

   2.4 信息增益

         假设有K类,A代表具体属性,比如上面 年龄,D_i 是按照年龄中一个具体的划分,比如老年

取的数据集

       

          其中:

         

              

  2.5 信息增益比

          信息增益划分,会偏向选择较多特征的一列,为了纠正这一问题,可以选择信息增益比

           

              

 

    

三  决策树三种主要算法

     3.1 ID3

           输入:

                   训练数据集D , 和 特征集A,  阀值

           输出:

                   决策树

 

           步骤:

                   1: 若D中所有的实例都是同一类,则T为单节点树,作为叶节点返回

                   2: 若特征集为空集, 则T为单节点树,将D中实例数最多的类作为叶节点返回

                   3: 计算特征集A中信息增益,选择信息增益最大的特征Ag

                    3: 计算特征集A中信息增益比,选择信息增益比最大的特征Ag

                   4:   如果该信息增益小于阀值e,  则T为单节点树,返回

                    5   对A中每个属性a,将D 按照a 划分为若干个非空集合Di, 构建新的节点树,该节点和其子节点构成

                 树T,返回T     例如:{A2:{a21:value, a22:{...},...}

                     6: 对第i个子节点,以Di为训练集,以A-{Ag}为特征集,递归上面步骤1--5

                 

              

          

     3.2 C4.5

             

                     唯一的区别选择了信息增益比

 

     3.3 CART

            最小二乘回归树生成算法

            可以对连续数据处理

             输入:

                      训练集D

             输出:

                     回归树f(x),也是二分树

           

             1: 选择最优切分变量j 和 切分点s, 求解

                      

                      遍历遍历j,对固定的切分遍历j扫描切分点s, 选择使得上面达到最小的(j,s)

                    

 

              2: 对选定的对(j,s)划分区域 ,并决定对应的输出值

                

             3:  继续对两个子区域调用步骤1,2 直到满足停止条件

              4: 将输入空间划分为M个区域 生成决策树

                     

 

    停止条件:

           当前的每个叶节点方差和小于一定值,例如上面乳腺癌按照年龄分,只分一次

         

      

"""
根据最小标准差方差获得最优的切分点
原理上就是数据尽量是同类
Args
   j: 属性,某一列
   dataList: 数据集

retrun
    j, s 最优切分点,这里考虑的只是连续数据
"""
def  MiniTree(j, dataList):
    
    #print("j: ",j, "\n data",dataList)
    sList =[int(item[j]) for item in dataList]
    sList.sort()
    setS = set(sList)
    print("\n ******setS******** ",setS )
    miniVar = float("inf")   
    miniS = 0
    leftR  =  []
    rightR =  []
    for s in setS:
        
        leftR.clear()
        rightR.clear()
        for item in dataList:
            a = int(item[j])
            y = int(item[-1])
            if a <s:    
                leftR.append(y)
            else:
                rightR.append(y)
         
        var1 = 0
        var2 = 0
        C1 = None
        C2 = None
        if len(leftR)>0:
            var1 = np.var(leftR)*len(leftR)
            C1 = np.mean(leftR)
                
        if len(rightR)>0:
            var2 = np.var(rightR)*len(rightR)
            C2 = np.mean(rightR)
        var = var1+var2
        #print("s: ",s, "\t currQ ",var)
        if var <miniVar:       
                miniVar = var
                miniS = s
                
                print("\n  miniS: \t  ",s,"\t C1: ",C1, "\t C2: ",C2)

 

   
                

      基尼指数      : CART中用来选择最优特征,反应样本集合的不确定性

       

        

 

        如何证明Gini 极大值是

       证明:

                根据约束条件,加上一个仿射函数,配成

              

        对        求偏导数

               

                 

     重新带入上式

            

          对 求偏导数, 则

             

            

         则:

              为极值点

 

     CART生成算法:

       输入:

           训练数据集D,停止计算条件

      输出:

          CART决策树

 

        s1: 设结点的训练数据集为D,计算现有特征对该数据集的基尼系数,对每一个特征A,对每一个取值a

             根据样本点取值对A=a ,将数据集分割成D1,D2两部分

        s2:  在所有可能的特征A 以及它们的切分点a,选择基尼指数最小的特征A及对应的切分点a作为最优切分点

      生成两个子结点,将数据集划分到对应的两个子结点中去,

        s3: 对两个子结点,递归调用s1,s2,直到满足停止条件

         s4: 生成CART树

             

              

四  剪枝   

  

    主要防止过拟合,增加泛化能力

 

    分为预剪枝,和后剪枝

    常见的后剪枝算法为: REP, PEP, CCP等算法


  4.1  决策树剪枝:  

   输入:

          算法生成整个树T, 参数

    输出:

           修剪后的树 Tn

   

    4.1 计算每个结点的经验熵

    4.2  递归从树的叶结点回缩(结点)

    4.3  设一组叶结点回缩前和回缩后整树分别为,对应的损失函数分别为

        如果,则进行剪枝,父结点变成新的叶结点

     4.4  返回4.2 直到不能继续

 

    损失函数:T为叶节点个数

      

    经验熵:

       

      则上面也可以写为:

       

         

     算法里面:

        需要先找到叶节点,然后递归回父节点,遍历出同一个父节点下面的其它叶节点

      才能得到剪枝后T数目

 

    

          剪枝后

          


  4.2   CART 剪枝

    

          子树序列形成:

                 

             对于固定的 ,存在不同的剪枝方案,但存在一个最优的子树

          

          以 t为单结点的子树t损失函数:

        

         以 t 为根结点损失函数为:

          

          可以调整参数a,使得二者相等

          

 

                此刻:

           表示剪枝后,整体损失函数减少程度

 

              在剪枝得到的子树序列T0,T1.....Tn中交叉验证选取最优子树

 

              利用独立数据集,测试子树序列的基尼系数,或者平方误差,这里面的每一轮生成一个最优的Tree

下一轮在前一轮基础上,重新剪枝

 

         算法流程:


          输入:

                 CART算法生成的决策树T0

          输出:

                  最优决策树

 

          1 : 设k = 0, T= T0

          2:  

          3     自下而上对内部各个结点t计算以及

               

 

             这里Tt表示以t为根结点的子树,C(Tt)是对训练集的预测误差,|T_t|是Tt的叶结点个数

 

             4:

的内部结点的子树,进行剪枝,以多数表决法得到树

 

   5: 设k=k+1,,

 

   6: Tk 不是由根结点以及两个叶结点构成的树,否则Tk= Tn

   7: 采用交叉验证得到子树T0,T1,...Tn选择最优子树

 这里注意:

                        剪枝后,一定要保证各个叶结点的精度。

 

 

               

        代码实现:

     为了方便,生成树后先给不同叶结点打上唯一的Tag

    

          

  1. """
  2. 给叶结点打上唯一的标签
  3. """
  4. def AddTag(self, tree):
  5. keys =['left', 'right']
  6. for key in keys:
  7. subTree = tree[key]
  8. if self.IsTree(subTree):
  9. self.AddTag(subTree)
  10. else:
  11. Tag = "TAG_"+key
  12. if Tag not in tree:
  13. x = self.LeafTag.pop(0)
  14. tree[Tag] = x

        

    这里面给出一种单轮选择最优Gini系数的方案,不同轮即上面k=k+1,最好使用不同的数据集

   同时也要注意,剪枝后,要使用数据集对新生成的Tree验证

   

  1. """
  2. 剪枝
  3. Args
  4. dataList: 数据集
  5. tree: 树
  6. return
  7. None
  8. """
  9. def Prunch(self, dataList, dataLabel,tree):
  10. ct = 0.0
  11. c_Beaf = 0.0
  12. if not self.IsTree(tree) or len(dataLabel)<1:
  13. return
  14. BeafNum = self.GetLeafs(tree) ##叶结点数目|T|
  15. ct = self.GetGini(dataLabel)
  16. if BeafNum<2: ###本身已经是叶结点了
  17. return
  18. dict_beaf={}
  19. i = 0
  20. for data in dataList:
  21. perdict_Label, Tag_ID = self.Perdict(tree,data)
  22. if Tag_ID not in dict_beaf:
  23. dict_beaf[Tag_ID] =[]
  24. test_label = dataLabel[i]
  25. beaf_List = dict_beaf[Tag_ID]
  26. beaf_List.append(test_label)
  27. i = i+1
  28. #keyList = dict_beaf.keys()
  29. # print("\n ****keyList***\n\t ",keyList, "\t len: ",len(keyList))
  30. c_Beaf = self.GetBeafGini(dict_beaf)
  31. alpha = (ct-c_Beaf)/(BeafNum-1)
  32. name = tree['name']
  33. feat = self.train_feature.index(name)
  34. val = tree['feature']
  35. # print("\n ct: \t",ct, "\t c_Beaf: \t",c_Beaf, "\t alpha: ",alpha)
  36. print("\n name: ",name , "\t val : ",val)
  37. print("\n t为根结点Gini :\t %.3f"%ct, "\t t为树的Gini:\t %.3f"%c_Beaf, "\t alpha: %.3f"%alpha)
  38. L_Data, L_Label, R_Data, R_Label =self.SplitData(dataList, feat, val, dataLabel)
  39. LTree= tree['left']
  40. RTree =tree['right']
  41. self.Prunch(L_Data,L_Label ,LTree)
  42. self.Prunch(R_Data, R_Label,RTree )
  43. return None

 name:  worst perimeter           val :  106.2

 t为根结点Gini :     0.371   t为树的Gini:       0.089   alpha: 0.017

 name:  worst smoothness          val :  0.1777

 t为根结点Gini :     0.042   t为树的Gini:       0.032   alpha: 0.001

 name:  worst concave points      val :  0.1607

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  mean area         val :  698.8

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  perimeter error           val :  4.138

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  worst texture     val :  30.25

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.001

 name:  mean fractal dimension    val :  0.05628

 t为根结点Gini :     0.124   t为树的Gini:       0.114   alpha: 0.003

 name:  smoothness error          val :  0.007499

 t为根结点Gini :     0.077   t为树的Gini:       0.073   alpha: 0.002

 name:  mean radius       val :  12.77

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  mean radius       val :  12.77

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  mean texture      val :  15.56

 t为根结点Gini :     0.349   t为树的Gini:       0.222   alpha: 0.021

 name:  mean perimeter    val :  102.5

 t为根结点Gini :     0.346   t为树的Gini:       0.267   alpha: 0.079

 name:  worst smoothness          val :  0.1021

 t为根结点Gini :     0.262   t为树的Gini:       0.217   alpha: 0.011

 name:  mean radius       val :  18.08

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  worst concavity           val :  0.1932

 t为根结点Gini :     0.265   t为树的Gini:       0.220   alpha: 0.022

 name:  mean radius       val :  16.02

 t为根结点Gini :     0.375   t为树的Gini:       0.333   alpha: 0.042

五  例子

     1: 乳腺癌例子

         

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Mon Sep 9 16:49:01 2019
  4. @author: chengxf2
  5. """
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import sys,os
  9. from sklearn.datasets import load_breast_cancer
  10. class CART:
  11. def IsTree(self, obj):
  12. bTree =(type(obj).__name__=='dict')
  13. return bTree
  14. """
  15. 获得叶结点个数
  16. Args
  17. tree: 树
  18. return
  19. num
  20. """
  21. def GetLeafs(self, tree):
  22. numLeaf = 0
  23. keys =['left', 'right']
  24. for key in keys:
  25. subTree = tree[key]
  26. if self.IsTree(subTree):
  27. numLeaf += self.GetLeafs(subTree)
  28. else:
  29. numLeaf +=1
  30. return numLeaf
  31. """
  32. 获得树的深度
  33. Args
  34. tree: 树
  35. return
  36. depth
  37. """
  38. def GetDepth(self, tree):
  39. print("====depth##########")
  40. maxDepth = 0
  41. keys =['left', 'right']
  42. for key in keys:
  43. subTree = tree[key]
  44. if self.IsTree(subTree):
  45. depth =1+ self.GetDepth(subTree)
  46. else:
  47. depth = 1
  48. if depth >maxDepth:
  49. maxDepth = depth
  50. return maxDepth
  51. """
  52. 加载数据集
  53. Args:
  54. None
  55. return
  56. feature_names:
  57. ['mean radius' 'mean texture' 'mean perimeter' 'mean area'
  58. 'mean smoothness' 'mean compactness' 'mean concavity'
  59. 'mean concave points' 'mean symmetry' 'mean fractal dimension'
  60. 'radius error' 'texture error' 'perimeter error' 'area error'
  61. 'smoothness error' 'compactness error' 'concavity error'
  62. 'concave points error' 'symmetry error' 'fractal dimension error'
  63. 'worst radius' 'worst texture' 'worst perimeter' 'worst area'
  64. 'worst smoothness' 'worst compactness' 'worst concavity'
  65. 'worst concave points' 'worst symmetry' 'worst fractal dimension']
  66. target_Name:
  67. ['malignant' 恶性 'benign' 良性]
  68. ### 数据总长 569 ,300用来Train ,269 用来测试,剪枝
  69. """
  70. def LoadData(self):
  71. cancer = load_breast_cancer()
  72. data = cancer['data']
  73. target = cancer['target']
  74. target_Name = cancer['target_names']
  75. DESCR = cancer['DESCR']
  76. feature_names = cancer['feature_names']
  77. #print("\n data: \n ", data)
  78. #print("\n target: \n ", target)
  79. #print("\n target_Name: \n ", target_Name)
  80. #print("\n DESCR: \n ", DESCR)
  81. #print("\n feature_names: \n ", feature_names)
  82. self.trainData = data[0:300]
  83. self.train_target = target[0:300]
  84. self.target_name= target_Name
  85. self.train_feature = feature_names.tolist()
  86. self.testData = data[300:-1]
  87. self.test_target= target[300:-1]
  88. """
  89. 分割数据
  90. Args
  91. dataMat: 输入矩阵
  92. col: 列
  93. val: 值
  94. Returns
  95. L_Data: 小于该值的矩阵
  96. L_Labels: 左边标签
  97. R_Data: 右边矩阵
  98. R_Labels: 右边标签
  99. """
  100. def SplitData(self, dataList, col, val, Labels):
  101. L_Data =[]; L_Label=[]
  102. R_Data =[]; R_Label=[]
  103. m,n = np.shape(dataList)
  104. # print("\n m: ",m, "n: ",n)
  105. for i in range(m):
  106. cur = dataList[i][col]
  107. data = dataList[i]
  108. #print("\n data ",data)
  109. label = Labels[i]
  110. if cur<val:
  111. L_Data.append(data)
  112. L_Label.append(label)
  113. else:
  114. R_Data.append(data)
  115. R_Label.append(label)
  116. # print("\n =R_Mat= \n ",np.shape(R_Data))
  117. #print("======================\n")
  118. #print("\n L_Data: \n ",np.shape(L_Data))
  119. return L_Data, L_Label, R_Data, R_Label
  120. """
  121. 分割数据
  122. Args
  123. dataMat: 输入矩阵
  124. col: 列
  125. val: 值
  126. Returns
  127. L_Data: 小于该值的矩阵
  128. L_Labels: 左边标签
  129. R_Data: 右边矩阵
  130. R_Labels: 右边标签
  131. """
  132. def GetSubLabel(self, dataList, col, val, Labels):
  133. L_Label=[]
  134. R_Label=[]
  135. m,n = np.shape(dataList)
  136. for i in range(m):
  137. cur = dataList[i][col]
  138. label = Labels[i]
  139. if cur<val:
  140. L_Label.append(label)
  141. else:
  142. R_Label.append(label)
  143. return L_Label,R_Label
  144. def GetGini(self, Labels):
  145. m = len(Labels)
  146. dictItem ={}
  147. gini = 0.0
  148. if m <1:
  149. return None
  150. for label in Labels:
  151. if label not in dictItem:
  152. dictItem[label]=0
  153. dictItem[label]=dictItem[label]+1
  154. for key in dictItem.keys():
  155. prob = np.power(dictItem[key]/m,2)
  156. gini +=prob
  157. return 1-gini
  158. """
  159. 选择最优特征
  160. Args
  161. dataList: 数据集
  162. Labels: 标签集
  163. return
  164. L_Data
  165. R_Data
  166. L_Label
  167. R_Label
  168. """
  169. def ChooseBestFeatures(self, dataList, Labels):
  170. m,n = np.shape(dataList)
  171. miniGini = float("inf") ##Gini 系数选择最小的
  172. bestCol= 0 ##最佳特征
  173. bestFeature = 0 ##划分点
  174. for i in range(n):
  175. item = [data[i] for data in dataList]
  176. setFeature = set(item)
  177. for feature in setFeature:
  178. L_Label, R_Label= self.GetSubLabel(dataList, i, feature, Labels)
  179. m1 = len(L_Label)
  180. m2 = len(R_Label)
  181. if m1==0 or m2 ==0: continue
  182. gini = (m1*self.GetGini(L_Label)+m2*self.GetGini(R_Label))/m
  183. if gini<miniGini:
  184. miniGini = gini
  185. bestCol = i
  186. bestFeature = feature
  187. print("\n minGini ",miniGini, "bestCol ",self.train_feature[bestCol], "\t bestFeature ",bestFeature)
  188. return bestCol, bestFeature
  189. """
  190. 创建树
  191. Args
  192. dataList: 数据集
  193. Labels: 分类结果
  194. return
  195. Tree
  196. """
  197. def CreateTree(self, dataList, Labels):
  198. setLabel = set(Labels)
  199. m = len(dataList) ##样本个数
  200. if 1 == len(setLabel): ###只有一个分类
  201. return Labels[0]
  202. elif 1 ==m: ##只有一个样本
  203. return Labels[0]
  204. feat, val = self.ChooseBestFeatures(dataList, Labels)
  205. tree ={}
  206. tree['feature']=val
  207. tree['name'] = self.train_feature[feat]
  208. L_Data, L_Label, R_Data, R_Label =self.SplitData(dataList, feat, val, Labels)
  209. tree['left']=self.CreateTree(L_Data,L_Label )
  210. tree['right']=self.CreateTree(R_Data, R_Label )
  211. return tree
  212. """
  213. 分类
  214. Args
  215. myTree 决策树
  216. Labels 标签
  217. testVec 当前数据集
  218. return
  219. 测试出来的标签
  220. """
  221. def Perdict(self, myTree, data):
  222. keyList= ['left','right']
  223. name = myTree['name']
  224. feature = myTree['feature']
  225. col = self.train_feature.index(name)
  226. cur = data[col]
  227. if cur<feature:
  228. subTree = myTree['left']
  229. else:
  230. subTree = myTree['right']
  231. if self.IsTree(subTree):
  232. label = self.Perdict(subTree, data)
  233. else:
  234. label =subTree
  235. return label
  236. """
  237. 训练
  238. """
  239. def Train(self):
  240. tree = self.CreateTree(self.trainData, self.train_target)
  241. depth = self.GetDepth(tree)
  242. num = self.GetLeafs(tree)
  243. print("\n tree: \n ",tree)
  244. print("\n 叶结点数: \t",num)
  245. print("\n 树深度: \t ",depth )
  246. print("\n ***************树已生成*******************\n")
  247. num = len(self.testData)
  248. err = 0
  249. for i in range(num):
  250. data = self.testData[i]
  251. true_label = self.test_target[i]
  252. label= self.Perdict(tree, data)
  253. print("\n label ",label, "\t target ",true_label)
  254. if true_label != label:
  255. err=err+1
  256. print("\n 测试样本: \t", num, "测试错误率: \t ",err/num, "\t err: ",err)
  257. """
  258. 初始化
  259. Args
  260. None
  261. return
  262. None
  263. """
  264. def __init__(self):
  265. self.LoadData()
  266. self.Train()
  267. cart =CART()

    

         2: 剖腹产预测例子

       

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Aug 28 13:59:01 2019
  4. @author: chengxf2
  5. """
  6. """
  7. 决策树
  8. """
  9. import numpy as np
  10. from mpl_toolkits.mplot3d import Axes3D
  11. import matplotlib.pyplot as plt
  12. from matplotlib import cm
  13. import os,sys
  14. import operator
  15. import math
  16. from decimal import *
  17. import treePlot
  18. from imp import reload
  19. import copy
  20. import miniTree as mini
  21. #import csvFun
  22. """
  23. 数据集2: 根据学生情况,预测是否购买电脑
  24. self.dataList=[['teenager' ,'high', 'no' ,'same', 'no'],
  25. ['teenager', 'high', 'no', 'good', 'no'],
  26. ['middle_aged' ,'high', 'no', 'same', 'yes'],
  27. ['old_aged', 'middle', 'no' ,'same', 'yes'],
  28. ['old_aged', 'low', 'yes', 'same' ,'yes'],
  29. ['old_aged', 'low', 'yes', 'good', 'no'],
  30. ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
  31. ['teenager' ,'middle' ,'no', 'same', 'no'],
  32. ['teenager', 'low' ,'yes' ,'same', 'yes'],
  33. ['old_aged' ,'middle', 'yes', 'same', 'yes'],
  34. ['teenager' ,'middle', 'yes', 'good', 'yes'],
  35. ['middle_aged' ,'middle', 'no', 'good', 'yes'],
  36. ['middle_aged', 'high', 'yes', 'same', 'yes'],
  37. ['old_aged', 'middle', 'no' ,'good' ,'no']]
  38. self.Labels=['age','input','student','level']
  39. RID age income student credit_rating class:buys_computer
  40. """
  41. #
  42. from csvFun import LoadFile
  43. class Shannon:
  44. '''
  45. 学生购买电脑例子
  46. Args
  47. None
  48. return
  49. data
  50. '''
  51. def LoadDataStudent(self):
  52. self.dataList=[['teenager' ,'high', 'no' ,'same', 'no'],
  53. ['teenager', 'high', 'no', 'good', 'no'],
  54. ['middle_aged' ,'high', 'no', 'same', 'yes'],
  55. ['old_aged', 'middle', 'no' ,'same', 'yes'],
  56. ['old_aged', 'low', 'yes', 'same' ,'yes'],
  57. ['old_aged', 'low', 'yes', 'good', 'no'],
  58. ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
  59. ['teenager' ,'middle' ,'no', 'same', 'no'],
  60. ['teenager', 'low' ,'yes' ,'same', 'yes'],
  61. ['old_aged' ,'middle', 'yes', 'same', 'yes'],
  62. ['teenager' ,'middle', 'yes', 'good', 'yes'],
  63. ['middle_aged' ,'middle', 'no', 'good', 'yes'],
  64. ['middle_aged', 'high', 'yes', 'same', 'yes'],
  65. ['old_aged', 'middle', 'no' ,'good' ,'no']]
  66. self.dataLabel =['age','input','student','level','result']
  67. """
  68. 降维
  69. Args
  70. age:连续数据
  71. return
  72. 离散数据
  73. """
  74. def ReduceDimension(self, age):
  75. #print("age ",min(age), max(age))
  76. DiAge =[]
  77. for x in age:
  78. if x<32:
  79. DiAge.append("Y")
  80. else:
  81. DiAge.append("M")
  82. return DiAge
  83. """
  84. 转换为高斯分布
  85. Arg:
  86. age: list
  87. u: 平均值
  88. return
  89. 复合高斯分布的数据
  90. """
  91. def GetInfo(self, dataList):
  92. age0 = []
  93. age1 = []
  94. for item in dataList:
  95. print("-1 ",item)
  96. if item[-1]=='0':
  97. age0.append(int(item[0]))
  98. else:
  99. age1.append(int(item[0]))
  100. #求方差
  101. mean0 = np.mean(age0)
  102. var0 = np.var(age0)
  103. mean1 = np.mean(age1)
  104. var1 = np.var(age1)
  105. #求标准差
  106. std0 = np.std(age0,ddof=1)
  107. std1 = np.std(age1,ddof=1)
  108. #print("mean0 ", mean0, " var0: ", var0, "\t mean1 ",mean1, "\t var1: ",var1)
  109. #print("\n std0: ",std0, "\t std1: ",std1)
  110. """
  111. 如果属性集只有一个属性了,返回概率最高的那个
  112. Args
  113. 数据
  114. return
  115. 分类结果
  116. """
  117. def MajorCnt(self, classList):
  118. dictCount ={}
  119. for item in classList:
  120. # print("\n vote ",vote)
  121. vote = item[-1]
  122. if vote not in dictCount:
  123. dictCount[vote] = 0
  124. dictCount[vote]+=1
  125. # print("\n classCount ",dictCount)
  126. sortKey = sorted(dictCount.items(), key =lambda d:d[1], reverse= True)
  127. return sortKey[0][0],dictCount
  128. """
  129. 获得数据集路径,以及保存Tree路径
  130. arg:
  131. None
  132. return:
  133. File Path
  134. """
  135. def GetPath(self):
  136. ##
  137. fileName = "caesarian.csv"
  138. path = os.path.abspath(fileName)
  139. self.path = path
  140. self.treePath = os.path.abspath("tree")
  141. """
  142. 加载数据集
  143. Args:
  144. None
  145. return
  146. dataList
  147. """
  148. def LoadData(self):
  149. #labels = ['no surfacing','flippers'] ###相当于属性
  150. #Caesarian 剖腹产
  151. dataList = LoadFile(self.path)
  152. self.dataList = dataList[1:]
  153. print("\n ****************************\n")
  154. #mini.MiniTree(0, self.dataList)
  155. age = [int(item[0]) for item in self.dataList]
  156. age = self.ReduceDimension(age)
  157. for i in range(len(self.dataList)):
  158. self.dataList[i][0]=age[i]
  159. self.dataLabel = ['Age','Num','time','Blood','Heart','Result']
  160. """
  161. 计算单个熵
  162. Args
  163. prob
  164. return
  165. H: 熵
  166. """
  167. def GetEnt(self, prob):
  168. ent = 0.0
  169. if prob==0 or prob ==1:
  170. return 0
  171. else:
  172. ent = -prob*np.log2(prob)
  173. return ent
  174. """
  175. 获得对应的熵
  176. Args:
  177. dataList
  178. return
  179. ent:熵
  180. """
  181. def GetHD(self, dataList):
  182. num = len(dataList)
  183. # print("样本个数: ",num)
  184. labelDict ={}
  185. for i in range(num):
  186. label = dataList[i][-1]
  187. if label not in labelDict:
  188. labelDict[label]=0
  189. labelDict[label]= labelDict[label]+1
  190. ent = 0.0
  191. for key in labelDict:
  192. prob = labelDict[key]/num
  193. ent += self.GetEnt(prob)
  194. #print("\n ent:::: ",ent)
  195. return ent
  196. """
  197. 调试
  198. """
  199. def Debug(self):
  200. n = np.arange(2,30)
  201. shanList =[]
  202. for i in range(2,30):
  203. p = 1/i
  204. y = self.GetShan(p)
  205. shanList.append(y*i)
  206. plt.plot(n, shanList)
  207. plt.show()
  208. """
  209. 根据指定的属性,获取对应的属性
  210. Args
  211. col: 对应的一列
  212. attr: 属性
  213. return
  214. 匹配的数据集
  215. """
  216. def SpliteData(self,dataList, col, attr):
  217. subData = []
  218. for item in dataList:
  219. curAttr = item[col]
  220. #print("\n curAttr :",curAttr)
  221. data= []
  222. if curAttr == attr:
  223. data1 = item[:col]
  224. data2 = item[col+1:]
  225. data.extend(data1)
  226. data.extend(data2)
  227. #print("\n data1: ",data1,"\t data2: ",data2, "curAttr: ",curAttr)
  228. subData.append(data)
  229. n = len(subData)
  230. #print("n: ",n ,"\n subData ",subData)
  231. return subData
  232. """
  233. 使用信息增益
  234. args
  235. dataList
  236. return
  237. 对应列
  238. """
  239. def ChooseAttr(self, dataList):
  240. numAttr = len(dataList[0])-1 ##最后一个是属性
  241. num = float(len(dataList))
  242. baseHD = self.GetHD(dataList)
  243. bestGain = 0.0
  244. bestAttr = 0
  245. for col in range(numAttr):
  246. attrList = [data[col] for data in dataList]
  247. setAttr = set(attrList)
  248. HDa = 0.0
  249. HA =0.0
  250. for attr in setAttr:
  251. subData = self.SpliteData(dataList, col, attr)
  252. prob = len(subData)/num
  253. ent = self.GetHD(subData)
  254. HA += self.GetEnt(prob)
  255. HDa +=prob*ent
  256. gDA = (baseHD-HDa)/(HA+1.0) ##除以HA 就是隐形增益比
  257. #print("HA ",HA)
  258. #print("infoGain ",infoGain)
  259. if (gDA> bestGain):
  260. bestGain = gDA
  261. bestAttr = col
  262. #print("\n ******bestAttr*********", bestAttr)
  263. return bestAttr,bestGain
  264. """
  265. 创建树
  266. Args
  267. DataSet, labels
  268. return
  269. treeDict
  270. """
  271. def CreateTree(self, dataList, labels):
  272. kind = [item[-1] for item in dataList] ##最后一列Result
  273. if kind.count(kind[0])== len(kind):
  274. return kind[0]
  275. ###只剩下一个属性了
  276. if len(dataList[0]) ==1: ###只有一个属性了
  277. label,dictCount = self.MajorCnt(dataList)
  278. return label
  279. bestAttr,bestGain = self.ChooseAttr(dataList) ##最佳一列
  280. if bestGain<self.epsilon :
  281. label,dictCount = self.MajorCnt(dataList)
  282. if label =='Y':
  283. print("***********error**********",dataList, "\t dict ",dictCount)
  284. return label
  285. #print("\n ",bestGain)
  286. bestLabel = labels[bestAttr]
  287. del labels[bestAttr] ##删除某一列
  288. myTree ={bestLabel:{}}
  289. ###创建分枝
  290. branch = [item[bestAttr] for item in dataList]
  291. setBranch = set(branch)
  292. for key in setBranch:
  293. subLabel = labels[:]
  294. subData = self.SpliteData(dataList, bestAttr, key)
  295. subTree = self.CreateTree(subData, subLabel)
  296. myTree[bestLabel][key] = subTree
  297. return myTree
  298. """
  299. 训练
  300. Args
  301. None
  302. return
  303. """
  304. def Train(self):
  305. label = copy.deepcopy(self.dataLabel)
  306. tree = self.CreateTree(self.dataList, label)
  307. #print("\n Train: dataLabel: ",self.dataLabel)
  308. self.storeTree(tree)
  309. self.grabTree()
  310. #print("\n tree ",tree)
  311. """
  312. 获得叶节点数目
  313. Args:
  314. Tree
  315. return
  316. numLeaf
  317. """
  318. def GetNumLeaf(self, myTree):
  319. numLeaf = 0
  320. firstNode = list(myTree.keys())[0]
  321. #print("\n type::: ",type(firstNode))
  322. secondDict = myTree[firstNode]
  323. keys = list(secondDict.keys())
  324. #print("\n keys::: ",keys)
  325. for key in keys:
  326. if type(secondDict[key]).__name__=='dict':
  327. numLeaf += self.GetNumLeaf(secondDict[key])
  328. else:
  329. numLeaf+=1
  330. return numLeaf
  331. """
  332. 获取树的深度
  333. Args
  334. Tree
  335. return
  336. Depth
  337. """
  338. def GetTreeDepth(self, myTree):
  339. maxDepth = 0
  340. firstNode = list(myTree.keys())[0]
  341. secondDict = myTree[firstNode]
  342. keys = list(secondDict.keys())
  343. for key in keys:
  344. if type(secondDict[key]).__name__=='dict':
  345. depth =1+ self.GetTreeDepth(secondDict[key])
  346. else:
  347. depth = 1
  348. if depth>maxDepth:
  349. maxDepth = depth
  350. return maxDepth
  351. """
  352. 保存模型
  353. Args
  354. tree
  355. fileName
  356. return
  357. None
  358. """
  359. def storeTree(self, tree):
  360. import pickle
  361. fileName = self.treePath
  362. fw = open(fileName,'wb')
  363. pickle.dump(tree, fw)
  364. #print("\n storeTree label: ",self.dataLabel)
  365. fw.close()
  366. """
  367. 加载树
  368. Args
  369. fileNmae
  370. return
  371. tree
  372. """
  373. def grabTree(self):
  374. import pickle
  375. fileName = self.treePath
  376. fr = open(fileName,'rb')
  377. tree = pickle.load(fr)
  378. leaf = self.GetNumLeaf(tree)
  379. print("leaf ",leaf)
  380. #print("\n tree: ",tree)
  381. #print("\n **********************\n ",tree)
  382. reload(treePlot)
  383. treePlot.createPlot(tree)
  384. errorNum = 0
  385. num = len(self.dataList)
  386. print("\n grabTree label: ",self.dataLabel)
  387. for item in self.dataList:
  388. classifyLabel = self.classify(tree, self.dataLabel, item)
  389. if classifyLabel != item[-1]:
  390. errorNum= errorNum+1
  391. #print("\n classifyLabel:", classifyLabel, " real: ",item[-1])
  392. print("\n errorNUm ",errorNum, " total: ",num)
  393. return tree
  394. """
  395. 分类
  396. Args
  397. myTree 决策树
  398. Labels 标签
  399. testVec 当前数据集
  400. return
  401. 测试出来的标签
  402. """
  403. def classify(self, myTree, Labels, testVec):
  404. father = list(myTree.keys())[0]
  405. childDict = myTree[father]
  406. index = Labels.index(father)
  407. for key in list(childDict.keys()):
  408. if testVec[index]==key:
  409. if type(childDict[key]).__name__=='dict':
  410. testLabel = self.classify(childDict[key], Labels, testVec)
  411. else:
  412. testLabel = childDict[key]
  413. return testLabel
  414. def __init__(self):
  415. self.epsilon = 0.01
  416. self.m = 0
  417. self.n = 0
  418. self.fileName =""
  419. self.GetPath()
  420. self.LoadData()
  421. #self.LoadDataStudent()
  422. self.Train()
  423. shannon = Shannon()

 六   问题

   1     有1000笔贷款,其中部分是10000以上,有部分是100以下,决策树准确率只有80%多,

      如何设计决策树,保证发放贷款是盈利的。

 

   2   剪枝主要是为了解决过拟合,但是当部分结点本身就欠拟合,或者精度不高,还需要剪枝?

 

  3    生成结点的终止条件

   

参考文档:

   

      《机器学习实战》

        《机器学习与应用》

       《统计学习方法》

        https://blog.csdn.net/tkkzc3E6s4Ou4/article/details/83829616

         https://www.cnblogs.com/paisenpython/p/10371644.html
          https://blog.csdn.net/hot7732788/article/details/90070618

         https://blog.csdn.net/ccblogger/article/details/82656185

         https://www.cnblogs.com/beiyan/p/8321329.html

          https://wenku.baidu.com/view/671df33631126edb6f1a101c.html

           https://www.cnblogs.com/lpworkstudyspace1992/p/8030186.html

https://cuijiahua.com/blog/2017/12/ml_13_regtree_1.html

 

   

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/768583
推荐阅读
相关标签
  

闽ICP备14008679号