
GBDT Regression (Decision Stumps): A Python Implementation on sine.txt

1. Code

import pandas as pd
import numpy as np

# Read the data: column 0 is the feature, column 1 is the target
def load_data(filename):

    data = pd.read_table(filename, sep='\t', header=None)
    X_data = data[0]
    Y_data = data[1]

    return X_data, Y_data


class weakLearner():

    def __init__(self, min_sample, min_err):

        self.min_sample = min_sample  # minimum number of samples required to split
        self.min_err = min_err        # stop splitting once the error falls below this

    # Sum of squared errors of a node's labels around their mean
    def __mse(self, node):
        if len(node) == 0:
            return 0.0
        return np.sum((np.array(node) - np.average(node)) ** 2)

    # Sort the feature values in ascending order, then average each pair of
    # neighbours (a shifted element-wise sum) to generate candidate split points
    def Get_stump_list(self, X):
        tmp1 = sorted(X)  # the shifted-sum trick below requires sorted values
        tmp2 = list(tmp1)
        tmp1.insert(0, 0)
        tmp2.append(0)
        stump_list = ((np.array(tmp1) + np.array(tmp2)) / 2.0)[1:-1]

        return stump_list
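
    # Example for Get_stump_list (hypothetical values): sorted X = [0.1, 0.2, 0.9]
    # gives tmp1 = [0, 0.1, 0.2, 0.9] and tmp2 = [0.1, 0.2, 0.9, 0]; their
    # element-wise mean is [0.05, 0.15, 0.55, 0.45], and the [1:-1] slice keeps
    # the true midpoints [0.15, 0.55].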

    # Compare each x against a single split value and bisect the labels
    def __binSplitData(self, split_val, X, Y):
        left_node = []
        right_node = []
        for j in range(np.shape(X)[0]):
            if X[j] < split_val:
                left_node.append(Y[j])
            else:
                right_node.append(Y[j])

        return left_node, right_node

    # Pick the split point with the smallest total squared error
    def __bestSplit(self, stump_list, X, Y):
        best_mse = np.inf
        best_f_val = stump_list[0]
        for i in range(np.shape(stump_list)[0]):
            left_node, right_node = self.__binSplitData(stump_list[i], X, Y)
            left_mse = self.__mse(left_node)
            right_mse = self.__mse(right_node)

            if best_mse > (left_mse + right_mse):
                best_mse = left_mse + right_mse
                best_f_val = stump_list[i]
        return best_f_val

    # Build a depth-one CART tree (a decision stump)
    def __CART(self, X, Y):
        tree = dict()

        if len(X) <= self.min_sample:  # too few samples to split: return a leaf
            return np.mean(Y)

        stump_list = self.Get_stump_list(X)  # generate candidate split points
        best_f_val = self.__bestSplit(stump_list, X, Y)  # choose the best split point

        tree['cut_f_val'] = best_f_val  # best split point
        left_node, right_node = self.__binSplitData(best_f_val, X, Y)

        tree['left_tree'] = left_node  # labels of all samples in the left branch
        tree['right_tree'] = right_node  # labels of all samples in the right branch
        left_mse = self.__mse(left_node)
        right_mse = self.__mse(right_node)
        tree['left_mse'] = left_mse
        tree['right_mse'] = right_mse

        # if the split leaves almost no error, collapse the stump to a single leaf
        now_mse = left_mse + right_mse
        if now_mse <= self.min_err:
            return np.mean(Y)

        return tree

    # Fit the stump
    def train(self, X, Y):

        self.tree = self.__CART(X, Y)

        return self.tree

    # Predict a batch of samples
    def predict(self, X):
        return np.array([self.__predict_one(x, self.tree) for x in X])

    # Predict a single sample
    def __predict_one(self, x, tree):
        if not isinstance(tree, dict):  # __CART may return a plain leaf value
            return tree
        cut_val = tree['cut_f_val']
        Y_left = np.average(tree['left_tree'])
        Y_right = np.average(tree['right_tree'])

        result = Y_left if x < cut_val else Y_right  # '<' matches __binSplitData

        return result
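
# Standalone usage sketch for weakLearner (hypothetical values):
#   wl = weakLearner(min_sample=1, min_err=0.01)
#   wl.train(np.array([0.1, 0.2, 0.9]), np.array([1.0, 0.8, -1.0]))
#   wl.predict(np.array([0.15, 0.8]))  # -> array([ 0.9, -1. ])
# The best split is 0.55, so 0.15 falls in the left leaf (mean 0.9)
# and 0.8 in the right leaf (mean -1.0).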



class GBDT():
    def __init__(self, n_estimators: int = 10, classifier = weakLearner):
        self.n_estimators = n_estimators
        self.weakLearner = classifier
        self.Trees = []
        self.learn_rate = 1
        self.init_value = None



    # Initialise the prediction with the mean of the labels
    def get_init_value(self, Y):
        """
        :param Y: list of sample labels
        :return: average (float), the mean of the sample labels
        """
        average = sum(Y) / len(Y)
        return average

    # Approximate the residuals with the negative gradient of the loss. This
    # project uses squared loss L = (Y - y_hat)^2 / 2, whose negative gradient
    # -dL/dy_hat = Y - y_hat is exactly the residual.
    def get_residuals(self, Y, y_hat):

        y_residuals = []

        for i in range(len(Y)):
            y_residuals.append(Y[i] - y_hat[i])

        return y_residuals


    def fit(self, X, Y, n_estimators, learn_rate, min_sample, min_err):
        self.n_estimators = n_estimators  # defaults to 10
        self.learn_rate = learn_rate

        X = np.array(X)  # convert to arrays
        Y = np.array(Y)

        # Initialise the model with the mean of the labels
        self.init_value = self.get_init_value(Y)
        n = len(Y)  # number of samples
        y_hat = [self.init_value] * n

        # Initial residuals
        residual = self.get_residuals(Y, y_hat)

        # List of fitted weak learners that make up the GBDT
        self.Trees = []

        # Train the GBDT iteratively, growing n_estimators trees
        for num in range(self.n_estimators):
            wl = self.weakLearner(min_sample, min_err)  # instantiate a weak learner
            wl.train(X, residual)  # fit this tree to the previous tree's residuals

            # Update each prediction: new value = old value + learning rate *
            # this tree's prediction of the residual. Using wl.predict keeps
            # every sample in its original order (concatenating the left and
            # right leaves would scramble the sample order).
            tree_pred = wl.predict(X)
            for i in range(n):
                y_hat[i] = y_hat[i] + self.learn_rate * tree_pred[i]

            # Recompute the residuals for the next tree
            residual = self.get_residuals(Y, y_hat)
            self.Trees.append(wl)  # store the fitted model

        return self.Trees

    # Model prediction: the initial value plus the scaled sum of all trees
    def predict(self, X):

        # start from the initial value once (not once per tree)
        y_ = np.full(len(X), self.init_value, dtype=float)
        for m in range(self.n_estimators):
            y_ = y_ + self.learn_rate * self.Trees[m].predict(X)

        return y_

    # Mean squared prediction error
    def error(self, Y, y_predict):

        error = np.square(np.array(Y) - y_predict).sum() / len(Y)

        return error






if __name__ == "__main__":
    print("-----------------------------------1. Load data-----------------------------------------")
    X, Y = load_data("./sine.txt")

    # Split the dataset 3:1 into training and test sets
    X_train = X[0:150]
    Y_train = Y[0:150]
    X_test = X[150:200]
    Y_test = Y[150:200]
    print("-----------------------------------2. Parameter settings--------------------------------")
    n_estimators = 4  # number of base learners
    learn_rate = 0.5  # learning rate
    min_sample = 30  # minimum number of samples per split
    min_err = 0.3  # minimum error threshold
    print("-----------------------------------3. Build the GBDT------------------------------------")
    Trees_reg = GBDT()  # instantiate the boosted-tree model
    Trees = Trees_reg.fit(X_train, Y_train, n_estimators, learn_rate, min_sample, min_err)  # fit the model
    print("-----------------------------------4. Prediction results--------------------------------")
    y_predict = Trees_reg.predict(X_test)  # predict
    print("Y_test: ", np.array(Y_test))
    print("predict_results: ", y_predict)
    print("-----------------------------------5. Prediction error----------------------------------")
    error = Trees_reg.error(Y_test, y_predict)  # compute the loss
    print("The error is ", error)



2. Dataset: sine.txt

0.190350	0.878049
0.306657	-0.109413
0.017568	0.030917
0.122328	0.951109
0.076274	0.774632
0.614127	-0.250042
0.220722	0.807741
0.089430	0.840491
0.278817	0.342210
0.520287	-0.950301
0.726976	0.852224
0.180485	1.141859
0.801524	1.012061
0.474273	-1.311226
0.345116	-0.319911
0.981951	-0.374203
0.127349	1.039361
0.757120	1.040152
0.345419	-0.429760
0.314532	-0.075762
0.250828	0.657169
0.431255	-0.905443
0.386669	-0.508875
0.143794	0.844105
0.470839	-0.951757
0.093065	0.785034
0.205377	0.715400
0.083329	0.853025
0.243475	0.699252
0.062389	0.567589
0.764116	0.834931
0.018287	0.199875
0.973603	-0.359748
0.458826	-1.113178
0.511200	-1.082561
0.712587	0.615108
0.464745	-0.835752
0.984328	-0.332495
0.414291	-0.808822
0.799551	1.072052
0.499037	-0.924499
0.966757	-0.191643
0.756594	0.991844
0.444938	-0.969528
0.410167	-0.773426
0.532335	-0.631770
0.343909	-0.313313
0.854302	0.719307
0.846882	0.916509
0.740758	1.009525
0.150668	0.832433
0.177606	0.893017
0.445289	-0.898242
0.734653	0.787282
0.559488	-0.663482
0.232311	0.499122
0.934435	-0.121533
0.219089	0.823206
0.636525	0.053113
0.307605	0.027500
0.713198	0.693978
0.116343	1.242458
0.680737	0.368910
0.484730	-0.891940
0.929408	0.234913
0.008507	0.103505
0.872161	0.816191
0.755530	0.985723
0.620671	0.026417
0.472260	-0.967451
0.257488	0.630100
0.130654	1.025693
0.512333	-0.884296
0.747710	0.849468
0.669948	0.413745
0.644856	0.253455
0.894206	0.482933
0.820471	0.899981
0.790796	0.922645
0.010729	0.032106
0.846777	0.768675
0.349175	-0.322929
0.453662	-0.957712
0.624017	-0.169913
0.211074	0.869840
0.062555	0.607180
0.739709	0.859793
0.985896	-0.433632
0.782088	0.976380
0.642561	0.147023
0.779007	0.913765
0.185631	1.021408
0.525250	-0.706217
0.236802	0.564723
0.440958	-0.993781
0.397580	-0.708189
0.823146	0.860086
0.370173	-0.649231
0.791675	1.162927
0.456647	-0.956843
0.113350	0.850107
0.351074	-0.306095
0.182684	0.825728
0.914034	0.305636
0.751486	0.898875
0.216572	0.974637
0.013273	0.062439
0.469726	-1.226188
0.060676	0.599451
0.776310	0.902315
0.061648	0.464446
0.714077	0.947507
0.559264	-0.715111
0.121876	0.791703
0.330586	-0.165819
0.662909	0.379236
0.785142	0.967030
0.161352	0.979553
0.985215	-0.317699
0.457734	-0.890725
0.171574	0.963749
0.334277	-0.266228
0.501065	-0.910313
0.988736	-0.476222
0.659242	0.218365
0.359861	-0.338734
0.790434	0.843387
0.462458	-0.911647
0.823012	0.813427
0.594668	-0.603016
0.498207	-0.878847
0.574882	-0.419598
0.570048	-0.442087
0.331570	-0.347567
0.195407	0.822284
0.814327	0.974355
0.641925	0.073217
0.238778	0.657767
0.400138	-0.715598
0.670479	0.469662
0.069076	0.680958
0.294373	0.145767
0.025628	0.179822
0.697772	0.506253
0.729626	0.786519
0.293071	0.259997
0.531802	-1.095833
0.487338	-1.034481
0.215780	0.933506
0.625818	0.103845
0.179389	0.892237
0.192552	0.915516
0.671661	0.330361
0.952391	-0.060263
0.795133	0.945157
0.950494	-0.071855
0.194894	1.000860
0.351460	-0.227946
0.863456	0.648456
0.945221	-0.045667
0.779840	0.979954
0.996606	-0.450501
0.632184	-0.036506
0.790898	0.994890
0.022503	0.386394
0.318983	-0.152749
0.369633	-0.423960
0.157300	0.962858
0.153223	0.882873
0.360068	-0.653742
0.433917	-0.872498
0.133461	0.879002
0.757252	1.123667
0.309391	-0.102064
0.195586	0.925339
0.240259	0.689117
0.340591	-0.455040
0.243436	0.415760
0.612755	-0.180844
0.089407	0.723702
0.469695	-0.987859
0.943560	-0.097303
0.177241	0.918082
0.317756	-0.222902
0.515337	-0.733668
0.344773	-0.256893
0.537029	-0.797272
0.626878	0.048719
0.208940	0.836531
0.470697	-1.080283
0.054448	0.624676
0.109230	0.816921
0.158325	1.044485
0.976650	-0.309060
0.643441	0.267336
0.215841	1.018817
0.905337	0.409871
0.154354	0.920009
0.947922	-0.112378
0.201391	0.768894
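
Plotted, the pairs trace a noisy sine-shaped curve over x in [0, 1]. A quick visual check (a sketch, assuming matplotlib is installed and load_data from Section 1 is in scope):

# Visualise the raw sine.txt data (matplotlib assumed installed)
import matplotlib.pyplot as plt

X, Y = load_data("./sine.txt")
plt.scatter(X, Y, s=10)
plt.xlabel("x")
plt.ylabel("y")
plt.title("sine.txt: noisy sine-shaped samples")
plt.show()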


3. Experimental Results (the accuracy is not high; a deeper tree is worth trying next)

[Figure: predicted values on the test set (original image not included)]
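
One way to act on that note is to let the weak learner recurse instead of stopping at depth one. Below is a minimal sketch of such a deeper regression tree (the class name deeperLearner and the max_depth parameter are mine, not the original author's); it reuses the midpoint-split idea from Section 1 and keeps the same train/predict interface, so it can be dropped into GBDT via the classifier argument.

# A deeper CART regression tree as a drop-in weak learner (a sketch, not the
# original code). It recurses until max_depth, min_sample, or min_err stops it.
import numpy as np

class deeperLearner:
    def __init__(self, min_sample, min_err, max_depth=3):
        self.min_sample = min_sample
        self.min_err = min_err
        self.max_depth = max_depth

    # sum of squared errors of a label set around its mean
    def _sse(self, y):
        return 0.0 if len(y) == 0 else np.sum((np.array(y) - np.mean(y)) ** 2)

    # best midpoint split between consecutive distinct feature values
    def _best_split(self, X, Y):
        xs = np.sort(np.unique(X))
        best_sse, best_cut = np.inf, None
        for c in (xs[:-1] + xs[1:]) / 2.0:
            mask = X < c
            sse = self._sse(Y[mask]) + self._sse(Y[~mask])
            if sse < best_sse:
                best_sse, best_cut = sse, c
        return best_sse, best_cut

    def _build(self, X, Y, depth):
        if depth >= self.max_depth or len(X) <= self.min_sample:
            return np.mean(Y)  # leaf value
        sse, cut = self._best_split(X, Y)
        if cut is None or sse <= self.min_err:
            return np.mean(Y)
        mask = X < cut
        return {'cut': cut,
                'left': self._build(X[mask], Y[mask], depth + 1),
                'right': self._build(X[~mask], Y[~mask], depth + 1)}

    def train(self, X, Y):
        self.tree = self._build(np.array(X), np.array(Y), 0)
        return self.tree

    def predict(self, X):
        return np.array([self._predict_one(x, self.tree) for x in X])

    def _predict_one(self, x, node):
        while isinstance(node, dict):  # walk down until we hit a leaf value
            node = node['left'] if x < node['cut'] else node['right']
        return node

Swapping it in needs no other changes: Trees_reg = GBDT(classifier=deeperLearner) followed by the same fit call as in Section 1 (max_depth keeps its default, since GBDT constructs the learner with only min_sample and min_err).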
