The full implementation: a GBDT regressor built from scratch, using a single-split CART stump as the weak learner and squared loss.

import pandas as pd
import numpy as np


# Load the data: a tab-separated file with the feature in column 0
# and the target in column 1
def load_data(filename):
    data = pd.read_table(filename, sep='\t', header=None)
    X_data = data[0]
    Y_data = data[1]
    return X_data, Y_data


class weakLearner():
    def __init__(self, min_sample, min_err):
        self.min_sample = min_sample
        self.min_err = min_err

    # Squared error of a node: sum of squared deviations from the node mean
    def __mse(self, node):
        return np.sum((np.average(node) - np.array(node)) ** 2)

    # Sort the feature values ascending, then add the sorted list to a
    # shifted copy of itself: the halved sums are the midpoints of adjacent
    # values, which serve as the candidate split points
    def Get_stump_list(self, X):
        tmp1 = sorted(X)
        tmp2 = sorted(X)
        tmp1.insert(0, 0)
        tmp2.append(0)
        stump_list = ((np.array(tmp1) + np.array(tmp2)) / float(2))[1:-1]
        return stump_list

    # Compare each x with the candidate split value and split the samples
    # into two groups
    def __binSplitData(self, split_val, X, Y):
        left_node = []
        right_node = []
        for j in range(np.shape(X)[0]):
            if X[j] < split_val:
                left_node.append(Y[j])
            else:
                right_node.append(Y[j])
        return left_node, right_node

    # Choose the best split point by minimum total squared error
    def __bestSplit(self, stump_list, X, Y):
        best_mse = np.inf
        best_f_val = stump_list[0]
        for i in range(np.shape(stump_list)[0]):
            left_node, right_node = self.__binSplitData(stump_list[i], X, Y)
            left_mse = self.__mse(left_node)
            right_mse = self.__mse(right_node)
            if best_mse > (left_mse + right_mse):
                best_mse = left_mse + right_mse
                best_f_val = stump_list[i]
        return best_f_val

    # Build the CART stump
    def __CART(self, X, Y):
        tree = dict()
        if len(X) <= self.min_sample:
            return np.mean(Y)
        stump_list = self.Get_stump_list(X)              # candidate split points
        best_f_val = self.__bestSplit(stump_list, X, Y)  # best split point
        tree['cut_f_val'] = best_f_val
        left_node, right_node = self.__binSplitData(best_f_val, X, Y)
        tree['left_tree'] = left_node    # labels of all samples in the left branch
        tree['right_tree'] = right_node  # labels of all samples in the right branch
        left_mse = self.__mse(left_node)
        right_mse = self.__mse(right_node)
        tree['left_mse'] = left_mse
        tree['right_mse'] = right_mse
        now_mse = left_mse + right_mse
        if now_mse <= self.min_err:
            return np.mean(Y)
        return tree

    # Train the CART stump
    def train(self, X, Y):
        self.tree = self.__CART(X, Y)
        return self.tree

    # Predict with the CART stump
    def predict(self, X):
        return np.array([self.__predict_one(x, self.tree) for x in X])

    # Predict one sample: return the mean label of the leaf it falls into
    def __predict_one(self, x, tree):
        if not isinstance(tree, dict):  # the stump degenerated to a single leaf
            return tree
        cut_val = tree['cut_f_val']
        Y_left = np.average(tree['left_tree'])
        Y_right = np.average(tree['right_tree'])
        result = Y_left if x < cut_val else Y_right
        return result


class GBDT():
    def __init__(self, n_estimators: int = 10, classifier=weakLearner):
        self.n_estimators = n_estimators
        self.weakLearner = classifier
        self.Trees = []
        self.learn_rate = 1
        self.init_value = None

    # Initialize the prediction with the mean of the labels
    def get_init_value(self, Y):
        """
        :param Y: list of sample labels
        :return: average (float), the mean of the sample labels
        """
        average = sum(Y) / len(Y)
        return average

    # Approximate the residuals with the negative gradient of the loss;
    # this project uses squared loss, whose negative gradient is Y - y_hat
    def get_residuals(self, Y, y_hat):
        y_residuals = []
        for i in range(len(Y)):
            y_residuals.append(Y[i] - y_hat[i])
        return y_residuals

    def fit(self, X, Y, n_estimators, learn_rate, min_sample, min_err):
        self.n_estimators = n_estimators  # defaults to 10
        self.learn_rate = learn_rate
        X = np.array(X)  # convert the lists to arrays
        Y = np.array(Y)
        # Initialize the model
        self.init_value = self.get_init_value(Y)
        n = len(Y)  # number of samples
        y_hat = [self.init_value] * n
        # Initialize the residuals
        residual = self.get_residuals(Y, y_hat)
        # Initialize the list of GBDT tree models
        self.Trees = []
        # Iterate to train the GBDT, generating n_estimators trees; after each
        # new tree is grown, the residuals must be updated again
        for num in range(self.n_estimators):
            wl = self.weakLearner(min_sample, min_err)  # instantiate a weak learner
            wl.train(X, residual)  # fit the next tree to the previous residuals
            # New prediction = old prediction + (learning rate * tree prediction)
            update = wl.predict(X)
            for i in range(n):
                y_hat[i] = y_hat[i] + self.learn_rate * update[i]
            # Update the residuals
            residual = self.get_residuals(Y, y_hat)
            self.Trees.append(wl)  # store the trained model
        return self.Trees

    # Model prediction: the initial value (added once) plus the shrunken sum
    # of all the weak learners' predictions
    def predict(self, X):
        y_ = np.array([self.init_value] * len(X))
        for m in range(self.n_estimators):
            y_ = y_ + self.learn_rate * self.Trees[m].predict(X)
        return y_

    # Mean squared prediction error of the model
    def error(self, Y, y_predict):
        error = np.square(np.array(Y) - y_predict).sum() / len(Y)
        return error


if __name__ == "__main__":
    print("-----------------------------------1.load data------------------------------------------")
    X, Y = load_data("./sine.txt")
    # Split the data 3:1 into training and test sets
    X_train = X[0:150]
    Y_train = Y[0:150]
    X_test = X[150:200]
    Y_test = Y[150:200]
    print("-----------------------------------2. Parameters Setting--------------------------------")
    n_estimators = 4  # number of base learners
    learn_rate = 0.5  # learning rate
    min_sample = 30   # minimum number of samples per node
    min_err = 0.3     # minimum error
    print("----------------------------------3.build GBDT--------------------------------------------")
    Trees_reg = GBDT()  # instantiate the boosted trees
    Trees = Trees_reg.fit(X_train, Y_train, n_estimators, learn_rate, min_sample, min_err)  # fit the model
    print("----------------------------------4.Predict Result-----------------------------------------")
    y_predict = Trees_reg.predict(X_test)  # predict
    print("Y_test: ", np.array(Y_test))
    print("predict_results: ", y_predict)
    print("----------------------------------5.Predict error-------------------------------------------")
    error = Trees_reg.error(Y_test, y_predict)  # compute the loss
    print("The error is ", error)
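The one-line comment on get_residuals compresses the key step of gradient boosting, so it is worth spelling out (this is the standard squared-loss derivation, matching what the code computes). For the squared loss, the negative gradient of the loss with respect to the current prediction is exactly the residual, which is why each new stump is fit to Y - y_hat:

L(y, F) = \frac{1}{2}\,(y - F)^2, \qquad -\frac{\partial L(y, F)}{\partial F} = y - F

and the model is built additively,

F_0(x) = \bar{y}, \qquad F_m(x) = F_{m-1}(x) + \nu\, h_m(x),

where F_0 is the label mean (init_value in the code), \nu is learn_rate, and h_m is the m-th stump fit to the residuals of F_{m-1}.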
Contents of sine.txt (200 samples, one per line, tab-separated: x, then y):

0.190350	0.878049
0.306657	-0.109413
0.017568	0.030917
0.122328	0.951109
0.076274	0.774632
0.614127	-0.250042
0.220722	0.807741
0.089430	0.840491
0.278817	0.342210
0.520287	-0.950301
0.726976	0.852224
0.180485	1.141859
0.801524	1.012061
0.474273	-1.311226
0.345116	-0.319911
0.981951	-0.374203
0.127349	1.039361
0.757120	1.040152
0.345419	-0.429760
0.314532	-0.075762
0.250828	0.657169
0.431255	-0.905443
0.386669	-0.508875
0.143794	0.844105
0.470839	-0.951757
0.093065	0.785034
0.205377	0.715400
0.083329	0.853025
0.243475	0.699252
0.062389	0.567589
0.764116	0.834931
0.018287	0.199875
0.973603	-0.359748
0.458826	-1.113178
0.511200	-1.082561
0.712587	0.615108
0.464745	-0.835752
0.984328	-0.332495
0.414291	-0.808822
0.799551	1.072052
0.499037	-0.924499
0.966757	-0.191643
0.756594	0.991844
0.444938	-0.969528
0.410167	-0.773426
0.532335	-0.631770
0.343909	-0.313313
0.854302	0.719307
0.846882	0.916509
0.740758	1.009525
0.150668	0.832433
0.177606	0.893017
0.445289	-0.898242
0.734653	0.787282
0.559488	-0.663482
0.232311	0.499122
0.934435	-0.121533
0.219089	0.823206
0.636525	0.053113
0.307605	0.027500
0.713198	0.693978
0.116343	1.242458
0.680737	0.368910
0.484730	-0.891940
0.929408	0.234913
0.008507	0.103505
0.872161	0.816191
0.755530	0.985723
0.620671	0.026417
0.472260	-0.967451
0.257488	0.630100
0.130654	1.025693
0.512333	-0.884296
0.747710	0.849468
0.669948	0.413745
0.644856	0.253455
0.894206	0.482933
0.820471	0.899981
0.790796	0.922645
0.010729	0.032106
0.846777	0.768675
0.349175	-0.322929
0.453662	-0.957712
0.624017	-0.169913
0.211074	0.869840
0.062555	0.607180
0.739709	0.859793
0.985896	-0.433632
0.782088	0.976380
0.642561	0.147023
0.779007	0.913765
0.185631	1.021408
0.525250	-0.706217
0.236802	0.564723
0.440958	-0.993781
0.397580	-0.708189
0.823146	0.860086
0.370173	-0.649231
0.791675	1.162927
0.456647	-0.956843
0.113350	0.850107
0.351074	-0.306095
0.182684	0.825728
0.914034	0.305636
0.751486	0.898875
0.216572	0.974637
0.013273	0.062439
0.469726	-1.226188
0.060676	0.599451
0.776310	0.902315
0.061648	0.464446
0.714077	0.947507
0.559264	-0.715111
0.121876	0.791703
0.330586	-0.165819
0.662909	0.379236
0.785142	0.967030
0.161352	0.979553
0.985215	-0.317699
0.457734	-0.890725
0.171574	0.963749
0.334277	-0.266228
0.501065	-0.910313
0.988736	-0.476222
0.659242	0.218365
0.359861	-0.338734
0.790434	0.843387
0.462458	-0.911647
0.823012	0.813427
0.594668	-0.603016
0.498207	-0.878847
0.574882	-0.419598
0.570048	-0.442087
0.331570	-0.347567
0.195407	0.822284
0.814327	0.974355
0.641925	0.073217
0.238778	0.657767
0.400138	-0.715598
0.670479	0.469662
0.069076	0.680958
0.294373	0.145767
0.025628	0.179822
0.697772	0.506253
0.729626	0.786519
0.293071	0.259997
0.531802	-1.095833
0.487338	-1.034481
0.215780	0.933506
0.625818	0.103845
0.179389	0.892237
0.192552	0.915516
0.671661	0.330361
0.952391	-0.060263
0.795133	0.945157
0.950494	-0.071855
0.194894	1.000860
0.351460	-0.227946
0.863456	0.648456
0.945221	-0.045667
0.779840	0.979954
0.996606	-0.450501
0.632184	-0.036506
0.790898	0.994890
0.022503	0.386394
0.318983	-0.152749
0.369633	-0.423960
0.157300	0.962858
0.153223	0.882873
0.360068	-0.653742
0.433917	-0.872498
0.133461	0.879002
0.757252	1.123667
0.309391	-0.102064
0.195586	0.925339
0.240259	0.689117
0.340591	-0.455040
0.243436	0.415760
0.612755	-0.180844
0.089407	0.723702
0.469695	-0.987859
0.943560	-0.097303
0.177241	0.918082
0.317756	-0.222902
0.515337	-0.733668
0.344773	-0.256893
0.537029	-0.797272
0.626878	0.048719
0.208940	0.836531
0.470697	-1.080283
0.054448	0.624676
0.109230	0.816921
0.158325	1.044485
0.976650	-0.309060
0.643441	0.267336
0.215841	1.018817
0.905337	0.409871
0.154354	0.920009
0.947922	-0.112378
0.201391	0.768894
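As a sanity check, not part of the original post, the same split and hyper-parameters can be run through scikit-learn's GradientBoostingRegressor, which implements the same squared-loss boosting; max_depth=1 makes each base learner a stump, matching weakLearner above. A minimal sketch, assuming scikit-learn is installed and reusing load_data from the script above:

# Hypothetical cross-check against scikit-learn (not in the original post)
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

X, Y = load_data("./sine.txt")
X_train = np.array(X[0:150]).reshape(-1, 1)  # sklearn expects a 2-D feature matrix
X_test = np.array(X[150:200]).reshape(-1, 1)
Y_train = np.array(Y[0:150])
Y_test = np.array(Y[150:200])

# max_depth=1: depth-1 trees (stumps), like the weakLearner class above
gbr = GradientBoostingRegressor(n_estimators=4, learning_rate=0.5, max_depth=1)
gbr.fit(X_train, Y_train)
y_sk = gbr.predict(X_test)
print("sklearn MSE:", np.square(Y_test - y_sk).sum() / len(Y_test))

The two models will not agree exactly (sklearn's stopping rules differ from min_sample/min_err), but the test errors should be of the same order.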