当前位置:   article > 正文

python实现决策树_决策树的基尼系数 numpy python

决策树的基尼系数 numpy python
 决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表某个可能的属性值,而每个叶节点则对应从根节点到该叶节点所经历的路径所表示的对象的值。
  • 1

详细关于决策树的讨论,请自行google。

一、找到最优分割位置

1、针对样本数据,需要在其不同的维度(d)上根据特定数据(v)进行分割

#X:样本数据
#y:样本属性
#d:维度
#v:分割标准
def cut(X , y , d , v):
    ind_left  = (X[:,d] <= v)
    ind_right = (X[:,d] > v)
    return (X[ind_left] , X[ind_right] , y[ind_left] , y[ind_right])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2、将样本数据排序

sorted_index= np.argsort(X[:,d])
  • 1

3、找出中间点

v = (X[sorted_index[i] , d ] + X[sorted_index[i+1] , d ]) / 2
  • 1

4、按照中间点进行分割

X_left , X_right , y_left , y_right = cut(X , y , d , v)
  • 1

5、计算基尼系数

    gini_cur = gini(y_left , y_right )
  • 1

6、找到基尼系数最小的分割位置(维度,分割值)

if gini_cur < gini_best :
   best_g = gini_cur
   best_d = d
   best_v = v
  • 1
  • 2
  • 3
  • 4

二、创建决策树
1、找到原始数据的最优分割点(对于第一次,找的结果是根节点的分割情况)

d , v , g  = try_split(X , y)
  • 1

2、将找的结果保存在结点Node中

node = Node(d,v,g)
  • 1

3、根据最优点将数据分割

X_left , X_right , y_left , y_right = cut(X , y , d , v)
  • 1

4、递归查找下一个结点

node.child_left  = create_tree(X , y)
node.child_right = create_tree(X , y)
  • 1
  • 2

最后对上述过程汇总:
1、实现计算基尼系数
这里写图片描述

from collections import Counter

#y:样本数据的标签
def gini(y):
    counter = Counter(y)
    result = 0
    for v in counter.values():
        result += (v / len(y))**2
    return (1 - result )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2、根据维度(d)和值(v)对数据进行分割

#X:样本数据
#y:样本数据的标签
#d:维度
#v:分割数据
def cut(X , y , d , v):
    ind_left  = (X[:,d] <= v)
    ind_right = (X[:,d] > v)
    return (X[ind_left] , X[ind_right] , y[ind_left] , y[ind_right])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3、查找最优分割点

import numpy as np
#X:样本数据
#y:样本数据的标签
def try_split(X , y):
    best_g = 1
    best_d = -1
    best_v = -1
    for d in range(X.shape[1]):
        #将数据排序
        sorted_index = np.argsort(X[:,d])
        for i in range(len(X) - 1):
            if (X[sorted_index[i],d] == X[sorted_index[i + 1],d]):
                continue
            #计算两点之间的平均值
            v = (X[sorted_index[i],d] + X[sorted_index[i + 1],d]) / 2
            #根据d  v将X  y分割
            X_left , X_right , y_left , y_right = cut(X , y , d , v)
            #计算基尼系数
            gini_cur = gini(y_left) + gini(y_right)
            #计算最优分割点
            if gini_cur < best_g:
                best_g = gini_cur
                best_v = v
                best_d = d
      return (best_d,best_v,best_g)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

4、结点,保存分割信息

class Node():
    def __init__(self,d=None,v=None,g=None,l=None):
        self.dim   = d
        self.value = v
        self.gini  = g
        self.label = l

        self.child_left = None
        self.child_rignt = None
    def __repr__(self):
        return 'Node(d={},v={},g={})'.format(self.dim,self.value,self.gini)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

5、创建决策树

def create_tree(X , y):
    #查找最优分割点
    d,v,g = try_split(X , y)

    #不用再分
    if (d==-1) or (g==0):
        return None

    #实例化结点
    node = Node(d,v,g)

    #按照最优点把数据分割
    X_left , X_right , y_left , y_right = cut(X , y , d , v)

    #递归子结点(左)
    node.child_left = create_tree(X_left , y_left)
    #左边分割完了,保存label
    if node.child_left is None:
        #label
        label = Counter(y_left).most_common(1)[0][0]
        node.label = Node(l = label)

    #递归子结点(右)
    node.child_right = create_tree(X_left , y_left)
    #右边分割完了,保存label
    if node.child_right is None:
        #label
        label = Counter(y_right).most_common(1)[0][0]
        node.label = Node(l = label)

   return node
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

6、绘制决策树
这里写图片描述

def show_tree(node):
    if node is None:
        return ''

    result += '{} [label="{}"]\n'.format(id(node),node)
    if node.child_left is not None:
        result += '{} [label="{}"]\n'.format(id(node.child_left),node.child_left)
        result += '{}->{}\n'.format(id(node),id(node.child_left))
        result += show_tree(node.child_left)

    if node.child_right is not None:
        result += '{} [label="{}"]\n'.format(id(node.child_right),node.child_right)
        result += '{}->{}\n'.format(id(node),id(node.child_right))
        result += show_tree(node.child_right)
    return result
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/549480
推荐阅读
相关标签
  

闽ICP备14008679号