当前位置:   article > 正文

Python实现决策树(系列文章6)-- 名义型变量属性值分割(修正)_循环拆分名义变量

循环拆分名义变量

1 问题

决策树对于变量的竞争/选择分为两种:

  1. 变量间的比较选择(二分)
  2. 变量内的比较选择(二分)

首先,决策树选定一个变量,进行变量内的比较选择。
变量本身可能是N(Nominal, 名义型),O(Ordinal,有序离散型)以及C(Continuous, 连续型)。
对于C,可以进行百分位的离散化,之后处理的方式类似于O,自小向大进行滑动,寻找使得gini最小的值即可(由于变量收到「序」的约束)。因此每个变量的搜索过程不超过100次。
对于N, 名义型的变量值不存在「序」,因此最初的版本使用组合方式,穷举变量值可能的二分方式。具体的方法是用C(n,1) , C(n,2)… 一直搜索到C(n, n//2) 。因此不可避免出现了组合爆炸的问题。
例如,对与一个有15种值的变量二分:

from scipy.special import comb, perm 
In [57]: for i in range(7): 
    ...:     print('Combinations 15  %s' %(i+1)) 
    ...:     print(comb(15, i+1)) 
    ...:                                                                        
Combinations 15  1
15.0
Combinations 15  2
105.0
Combinations 15  3
455.0
Combinations 15  4
1365.0
Combinations 15  5
3003.0
Combinations 15  6
5005.0
Combinations 15  7
6435.0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

成千上万的组合不仅需要更多的迭代,也直接撑爆了内存。

2 修改思路

既然一次性穷举所有组合的方式不可取,因此考虑一种渐进搜索的方式。
首先通过只选择一个值的方式对整个变量进行二分(A or not A),通过遍历找到最好的单值二分点。
然后再添加一个变量,进行遍历和搜索。如果gini的减少(或者说改进)低于一定的速率,那么终止循环。
对于有15种值的N型变量,最多进行120次循环就可以完成计算。而且通常来说,由于gini改进速率的限制,很可能50次以内的循环就完成了。
结合上篇文章矩阵并行的方法,只要数次的矩阵计算就可以完成。

3 程序对比

  • 3.1 原来的程序

整体上来说,首先判断变量的类型,如果是N型,使用nbcut函数完成组合上的穷举。
然后使用穷举生成的n的序列(res_dict.data_list)进行循环选优。找到最优的组合后,将选中和未选中的变量值组合分别输出(condition_left、condition_right)。

    # 函数13:
    # 对于分类树,寻找最小gini
    # x是某个变量, y是目标列
    @staticmethod
    def find_min_gini(x=None, y=None, varname=None, vartype=None, start=1, pstart=0.1, pend=0.9):
        # 先获得可能的划分
        assert vartype in [
            'C', 'O', 'N'], 'Only Accept Vartype C(continuous), O(Oridinal), N(Nominal)'
        # 》》》 Note Here
        if vartype == 'N':
            res_dict = iTree.nbcut(data=x)
        elif vartype == 'O':
            res_dict = iTree.obcut(data=x, start=start)
        else:
            res_dict = iTree.cbcut(data=x, pstart=pstart, pend=pend)
        # 在循环上几种方式是一致的
        tem_gini_list = []
        for i in range(len(res_dict['data_list'])):
            tem_gini = iTree.get_gini(res_dict['data_list'][i], y)
            tem_gini_list.append(tem_gini)
        # index + min的方法得到最小值的位置
        min_gini = min(tem_gini_list)
        mpos = tem_gini_list.index(min_gini)
        # 》》》 Note Here
        if vartype == 'N':
            # 左边的部分(in),仅仅是组合,筛选的时候还要套一次字典
            condition_left = res_dict['comb_list'][mpos]
            condition_right = res_dict['not_comb_list'][mpos]
        else:
            # 找到了某个值,将数据分为两份 ( < q 和 >= q)
            # 最终返回了最小的基尼指数,以及对应的划分条件
            condition_left = '<' + str(res_dict['qtiles'][mpos])
            condition_right = '>=' + str(res_dict['qtiles'][mpos])
        new_res_dict = {}
        new_res_dict[varname] = {}
        new_res_dict[varname]['gini'] = min_gini
        new_res_dict[varname]['condition_left'] = condition_left
        new_res_dict[varname]['condition_right'] = condition_right
        return new_res_dict
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 3.2 补充函数
# 补充函数1
def cal_min_gini_mat(mat=None, y=0):
    # 左子树
    totals = len(y)
    # left
    left_total = (1 - mat).sum(axis=0)

    left_weight = left_total / totals
    left_p1 = (1-mat)@y / left_total
    left_p0 = 1 - left_p1

    l_gini = left_weight * (1 - left_p1 ** 2 - left_p0 ** 2)

    # right
    right_total = mat.sum(axis=0)

    right_weight = right_total / totals
    right_p1 = mat @y / right_total
    right_p0 = 1 - right_p1

    r_gini = right_weight * (1 - right_p1 ** 2 - right_p0 ** 2)

    _gini = l_gini + r_gini
    return min(_gini), _gini.argmin()
# 补充函数2
def Ntype_search(x=x,y=y):

    # ------ 1 初始化:第一次计算
    var_set = set(x)
    comb_list = list(combinations(var_set, 1))  # 取出一个变量

    comb_sel_list = []
    selectd_dict = {}
    for clist in comb_list:
        tem_selected_dict = copy.deepcopy(selectd_dict)
        tem_selected_dict.update(list_key_dict(data=clist))
        comb_sel_list.append(tem_selected_dict)

    mat_list = []
    for comb_dict in comb_sel_list:
        tem_list = x.map(comb_dict).fillna(0)
        mat_list.append(tem_list.values)

    # 行=记录数,列= 变量可能的取值
    mat = np.array(mat_list).T

    min_gini, min_pos = cal_min_gini_mat(mat=mat, y=y)
    selectd_dict.update(comb_sel_list[min_pos])

    # ------ 2 迭代计算
    cnt = 0
    iter_limit = 1000
    improve_ratio_thresh = 0.05
    while cnt <= iter_limit:
        selected_set = set(selectd_dict.keys())
        var_set = var_set - selected_set
        comb_list = list(combinations(var_set, 1))  # 取出一个变量
        for clist in comb_list:
            tem_selected_dict = copy.deepcopy(selectd_dict)
            tem_selected_dict.update(list_key_dict(data=clist))
            comb_sel_list.append(tem_selected_dict)
        mat_list = []
        for comb_dict in comb_sel_list:
            tem_list = x.map(comb_dict).fillna(0)
            mat_list.append(tem_list.values)
        mat = np.array(mat_list).T
        min_gini1, min_pos1 = cal_min_gini_mat(mat=mat, y=y)
        improve_ratio = (min_gini - min_gini1)/min_gini
        if improve_ratio < improve_ratio_thresh:
            print('Not enough Improvement')
            break
        min_gini = min_gini1
        selectd_dict.update(comb_sel_list[min_pos1])
    condition_left = list(selectd_dict.keys())
    condition_right = list(set(x) - set(condition_left))
    return min_gini, condition_left, condition_right
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 3.3 修正后的函数
# 函数13:
    # 对于分类树,寻找最小gini
    # x是某个变量, y是目标列
    @staticmethod
    def find_min_gini(x=None, y=None, varname=None, vartype=None, start=1, pstart=0.1, pend=0.9):
        # 先获得可能的划分
        assert vartype in ['C', 'O', 'N'], 'Only Accept Vartype C(continuous), O(Oridinal), N(Nominal)'
        # debug:orange -> 对于N进行穷举计算很容易导致组合爆炸(要求组合数不超过10)-> 动态规划
        if vartype == 'N':
            # res_dict = iTree.nbcut(data=x)
            res_dict = None # 不使用nbcut
        elif vartype == 'O':
            res_dict = iTree.obcut(data=x, start=start)
        else:
            res_dict = iTree.cbcut(data=x, pstart=pstart, pend=pend)

        # debug:orange
        if vartype != 'N':
            tem_gini_list = []
            for i in range(len(res_dict['data_list'])):
                tem_gini = iTree.get_gini(res_dict['data_list'][i], y)
                tem_gini_list.append(tem_gini)
            # index + min的方法得到最小值的位置
            min_gini = min(tem_gini_list)
            mpos = tem_gini_list.index(min_gini)
        else:
            pass
        if vartype == 'N':
            # 左边的部分(in),仅仅是组合,筛选的时候还要套一次字典
            min_gini,condition_left,condition_right = iTree.Ntype_search(x=x,y=y)
        else:
            # 找到了某个值,将数据分为两份 ( < q 和 >= q)
            # 最终返回了最小的基尼指数,以及对应的划分条件
            condition_left = '<' + str(res_dict['qtiles'][mpos])
            condition_right = '>=' + str(res_dict['qtiles'][mpos])
        new_res_dict = {}
        new_res_dict[varname] = {}
        new_res_dict[varname]['gini'] = min_gini
        new_res_dict[varname]['condition_left'] = condition_left
        new_res_dict[varname]['condition_right'] = condition_right
        return new_res_dict

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

附:测试数据
https://download.csdn.net/download/yukai08008/12365486

后续

过了一周看看,发现矩阵计算还是稍微有点抽象的,自己也有点看不懂了。下面做些修改和补充:

# *********************** 数据 ****************************
x = copy.deepcopy(df1['CNT_FAM_MEMBERS'])
y = copy.deepcopy(df1[target_varname])
In [3]: x.shape                                                                 
Out[3]: (99020,)

In [4]: x.value_counts()                                                        
Out[4]: 
2.0     50965
1.0     21916
3.0     16900
4.0      7948
5.0      1125
6.0       121
7.0        26
8.0        10
12.0        2
10.0        2
15.0        1
16.0        1
9.0         1
14.0        1
20.0        1
Name: CNT_FAM_MEMBERS, dtype: int64

# *********************** 新的函数 ****************************
# 矩阵的行代表了组合(例如15个组合),列代表了记录(99020)
# 行 = 一个组合选择

def cal_min_gini_mat1(mat= None, y = None):
    # 总记录数
    totals = len(y)
    # 计算左子树
    left_total = (1-mat).sum(axis=1) # 横向相加,每行的和代表一个左子树的总数
    left_weight = left_total/totals # 每行左子树对应的权重
    left_p1 = np.dot(1-mat, y) / left_total # 每行左子树的p1
    left_p0 = 1 - left_p1 # 每行左子树的p0
    l_gini = left_weight*(1-left_p1**2- left_p0**2) # 每行左子树的gini
    # 同理计算右子树
    right_total = mat.sum(axis=1)
    right_weight = right_total / totals
    right_p1 = np.dot(mat, y) / right_total
    right_p0 = 1- right_p1
    r_gini = right_weight*(1-right_p1 **2 - right_p0 **2)
    _gini = l_gini + r_gini
    return min(_gini), _gini.argmin()

# 先看结果
In [53]: cal_min_gini_mat1(mat, y)                                              
Out[53]: (0.1506995921613467, 1)

# *********************** 函数过程拆解 ****************************

# >>> 1 变量的属性集合
In [8]: var_set = set(x) 

In [9]: len(var_set)                                                            
Out[9]: 15

# >>> 2 取出一个属性值,共有15种组合
In [10]: comb_list = list(combinations(var_set, 1))                             

In [11]: comb_list                                                              
Out[11]: 
[(1.0,),
 (2.0,),
 (3.0,),
 (4.0,),
 (5.0,),
 (6.0,),
 (7.0,),
 (8.0,),
 (9.0,),
 (10.0,),
 (12.0,),
 (14.0,),
 (15.0,),
 (16.0,),
 (20.0,)]
# >>> 3 制造一个选择字典
In [14]:     comb_sel_list = []  # 选择组合用
             selectd_dict = {}   # 后续使用
    ...:     for clist in comb_list: 
    ...:         tem_selected_dict = copy.deepcopy(selectd_dict) 
    ...:         tem_selected_dict.update(list_key_dict(data=clist)) 
    ...:         comb_sel_list.append(tem_selected_dict) 
In [81]: comb_sel_list                                                          
Out[81]: 
[{1.0: 1},
 {2.0: 1},
 {3.0: 1},
 {4.0: 1},
 {5.0: 1},
 {6.0: 1},
 {7.0: 1},
 {8.0: 1},
 {9.0: 1},
 {10.0: 1},
 {12.0: 1},
 {14.0: 1},
 {15.0: 1},
 {16.0: 1},
 {20.0: 1}]
# >>> 4 使用选择字典构造矩阵(选中或未选中)
In [21]:     mat_list = [] 
    ...:     for comb_dict in comb_sel_list: 
    ...:         tem_list = x.map(comb_dict).fillna(0) 
    ...:         mat_list.append(tem_list.values) 
# 因为有15种组合,所以列表元素有15个
In [82]: len(mat_list)                                                          
Out[82]: 15
# 将列表转为矩阵,准备计算
In [23]: mat = np.array(mat_list)                                               
In [24]: mat.shape                                                              
Out[24]: (15, 99020)
# Note:  (15,99020) dot (99020,) -> (15,)  这也就是我们通过矩阵一次计算15种组合的地方

# >>> 5 函数计算gini
In [83]: # 总记录数 
    ...:     totals = len(y)                                                    

In [84]: totals                                                                 
Out[84]: 99020


# 计算左子树 
In [85]: 
    ...:     left_total = (1-mat).sum(axis=1) # 横向相加,每行的和代表一个左子树
    ...: 的总数                                                                 

In [86]: left_total                                                             
Out[86]: 
array([77104., 48055., 82120., 91072., 97895., 98899., 98994., 99010.,
       99019., 99018., 99018., 99019., 99019., 99019., 99019.])

In [87]: left_total.shape                                                       
Out[87]: (15,)

In [88]: left_weight = left_total/totals # 每行左子树对应的权重                 

In [89]: left_weight                                                            
Out[89]: 
array([0.77867098, 0.485306  , 0.82932741, 0.91973339, 0.98863866,
       0.99877802, 0.99973743, 0.99989901, 0.9999899 , 0.9999798 ,
       0.9999798 , 0.9999899 , 0.9999899 , 0.9999899 , 0.9999899 ])

In [90]: left_weight.shape                                                      
Out[90]: (15,)

In [91]: left_p1 = np.dot(1-mat, y) / left_total # 每行左子树的p1               

In [92]:                                                                        

In [92]: left_p1                                                                
Out[92]: 
array([0.08125389, 0.08781604, 0.08047979, 0.08166066, 0.08193473,
       0.08205341, 0.08211609, 0.08209272, 0.08212565, 0.08212648,
       0.08212648, 0.08212565, 0.08212565, 0.08212565, 0.08212565])

In [93]: left_p1.shape                                                          
Out[93]: (15,)

left_p0 = 1 - left_p1 # 每行左子树的p0
In [94]: l_gini = left_weight*(1-left_p1**2- left_p0**2) # 每行左子树的gini     

In [95]:                                                                        

In [95]: l_gini                                                                 
Out[95]: 
array([0.10981055, 0.0679009 , 0.11705811, 0.12964266, 0.139311  ,
       0.14072031, 0.1408452 , 0.1408718 , 0.1408792 , 0.14087764,
       0.14087764, 0.1408792 , 0.1408792 , 0.1408792 , 0.1408792 ])

In [96]: l_gini.shape                                                           
Out[96]: (15,)
# 以上计算了15中组合的左子树的gini, 同理可计算右子树,不再赘复

# 最后加总左右子树,然后取最小值和最小值位置输出
_gini = l_gini + r_gini
return min(_gini), _gini.argmin()

# 也就是上面的结果
In [53]: cal_min_gini_mat1(mat, y)                                              
Out[53]: (0.1506995921613467, 1)

# >>> 6 这个结果意味着什么呢?
In [54]: comb_list                                                              
Out[54]: 
[(1.0,),
 (2.0,),
 (3.0,),
 (4.0,),
 (5.0,),
 (6.0,),
 (7.0,),
 (8.0,),
 (9.0,),
 (10.0,),
 (12.0,),
 (14.0,),
 (15.0,),
 (16.0,),
 (20.0,)]

# 意味着最佳的分割是第二个组合,这个组合的gini是0.1506995921613467
In [55]: comb_list[1]                                                           
Out[55]: (2.0,)

# >>> 以下从这个结论进行验算
In [56]: tem_df = pd.DataFrame()                                                

In [57]: tem_df['x'] = x                                                        

In [58]: tem_df['y'] = y   

In [60]: left_tree = tem_df[tem_df['x'] ==2]    

In [62]: right_tree = tem_df[tem_df['x'] !=2]  

In [63]: left_tree.shape                                                        
Out[63]: (50965, 2)

In [64]: right_tree.shape                                                       
Out[64]: (48055, 2)

In [65]: left_tree['y'].value_counts()                                          
Out[65]: 
0    47053
1     3912
Name: y, dtype: int64

In [66]: left_weight = 50965 / 99020                                            

In [67]: left_p1 = 3912/(3912+47053)  

In [68]: left_p0 = 1- left_p1                                                   

In [69]: left_p1                                                                
Out[69]: 0.07675855979593839

In [70]: left_p0                                                                
Out[70]: 0.9232414402040616

In [71]: (1 - left_p1**2 - left_p0**2) * left_weight                            
Out[71]: 0.07294931355439885

In [72]: right_weight = 48055/ 99020                                            

In [73]: right_tree['y'].value_counts()                                         
Out[73]: 
0    43835
1     4220
Name: y, dtype: int64

In [74]: right_p1 = 4220 / (4220+43835)                                         

In [75]: right_p0 = 1-right_p1                                                  

In [76]: right_p1                                                               
Out[76]: 0.08781604411611695

In [77]: right_p0                                                               
Out[77]: 0.912183955883883

In [78]: (1- right_p1**2 - right_p0**2)* right_weight                           
Out[78]: 0.07775027860694787

# >>> 验算完毕 
In [79]: 0.07775027860694787+0.07294931355439885                                
Out[79]: 0.1506995921613467





  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/591225
推荐阅读
相关标签
  

闽ICP备14008679号