赞
踩
树回归
CART(Classification And Regression Tree, 分类回归树)
完整代码见github
环境 python3
决策树不断将数据切分成小数据集,直到所有的目标变量完全相同,或者数据不能再切分为止。
决策树是一种贪心算法,它要在给定的时间内做出最佳选择,但不关心是否达到全局最优。
ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有取值进行切分,一旦按某种特征进行切分后,该特征在之后的算法执行过程中就不在起作用。这种切分方式过于迅速且不能处理连续型特征,只有事先将连续型特征转换为离散型特征,才能在ID3算法中使用,但是这种转换会破坏连续型变量的内在性质。
详细介绍可以参照决策树python实现(ID3 和 C4.5)
优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据
CART使用二元切分法来处理连续型变量,对CART稍作修改就可以处理回归问题。
二元切分法:每次把数据集切分成两份,如果数据的某特征值大于切分所要求的值,那么这些数据进入树的左子树,反之则进入树的右子树。二元切分法节省了树的构建时间,这点意义不大,因为树的构建一般是离线完成。
树回归的一般方法:
(1)收集数据:任意方法
(2)准备数据:需要数值型数据,标称型数据应该映射成二值型数据
(3)分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
(4)训练算法:大部分时间都花在叶节点树模型的构建上
(5)测试算法:使用测试数据上的R^2值来分析模型的效果
(6)使用算法:使用训练出的树做预测,预测结果可以用来做许多事情
用字典来存储树的数据结构,该字典包含以下4元素:
ID3用一部字典来存储每个切分,该字典可以包含两个或两个以上的值。
CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值;字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。
本章构建两种树:回归树和模型树
回归树:其每个节点包含单个值
模型树:其每个节点包含一个线性方程
函数createTree()的伪代码
找到最佳切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
CART算法的实现代码:
from numpy import * import numpy as np def loadDataSet(filename): # 加载数据集 ''' :param filename: :return: 数据+标签列表 ''' dataMat = [] fr = open(filename) for line in fr.readlines(): curLine = line.strip().split('\t') fltLine = list(map(float, curLine)) dataMat.append(fltLine) return dataMat def binSplitDataSet(dataSet, feature, value): ''' :param dataSet:数据集合 :param feature: 待切分的特征 :param value: 该特征的某个值 :return: 通过数组过滤的方式将上述数据集合切分得到两个子集并返回 ''' mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :] return mat0, mat1 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): ''' 递归函数:树构建 :param dataSet:数据集 :param leafType:建立叶节点函数 :param errType: 误差计算函数 :param ops: 包含树构建所需的其他函数的元组(tolS,tolN), tolS:容许的误差下降值,tolN:切分的最少样本数 :return: ''' feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 满足停止条件时返回叶节点值 if feat == None: return val retTree = {} retTree['spInd'] = feat retTree['spVal'] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree
模型树:把叶节点设定为分段线性函数,模型树的可解释性是其优于回归树的特点之一
误差计算:对于给定数据集,应先用线性模型来对它进行拟合,然后计算真实目标值和模型预测值之间的差值,最后将这些差值的平方求和就得到了所需的误差。
补充一些新的代码,使createTree()运行。
实现chooseBestSplit函数:给定某个误差计算方法,该函数会找到数据集上的最佳二元切分方式。该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。即用最佳方式切分数据集和生成相应的叶节点。
伪代码:
对每个特征:
对每个特征值:
将数据切分成两份
计算切分误差
如果当前误差小于当前最小误差,则将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
python代码如下:
def regLeaf(dataSet): ''' 生成叶节点 :param dataSet: 数据集 :return: ''' return mean(dataSet[:, -1]) def regErr(dataSet): ''' 误差估计函数,在给定数据上计算目标变量的平方误差,总方差 :param dataSet: :return: ''' return var(dataSet[:, -1])*shape(dataSet)[0] def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): ''' 找到数据的最佳二元切分方式 :param dataSet: 数据集 :param leafType:建立叶节点的函数 :param errType: 误差计算函数 :param ops: 包含树构建所需其他参数的元组 :return: 特征编号和切分特征值 ''' tolS = ops[0] # 容许的误差下降值 tolN = ops[1] # 切分的最少样本数 if len(set(dataSet[:, -1].T.tolist()[0])) == 1: return None, leafType(dataSet) m, n = shape(dataSet) # 当前数据集的大小 S = errType(dataSet) # 当前数据集的误差 bestS = inf bestIndex = 0 bestValue = 0 for featIndex in range(n-1): for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) if (shape(mat0)[0])<tolN or shape(mat1)[0]<tolN: continue newS = errType(mat0) + errType(mat1) if newS < bestS: bestIndex = featIndex bestS = newS bestValue = splitVal # 切分后如果误差减少不大,则不应该进行切分操作,直接创建叶节点 if S-bestS < tolS: return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) # 切分后的两个子集大小是否小于用户自定义的大小,则不应切分 if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN: return None, leafType(dataSet) return bestIndex, bestValue
运行测试代码:
myDat = loadDataSet('ex00.txt')
# # plot_db(myDat)
myMat = mat(myDat)
retTree = createTree(myMat)
print('retTree ', retTree)
结果:
retTree {'right': -0.04465028571428572, 'left': 1.0180967672413792, 'spVal': 0.48813, 'spInd': 0}
可以看到,树中包含2个叶节点。
运行测试代码:
myDat1 = loadDataSet('ex0.txt')
plot_db(myDat1)
myMat1 = mat(myDat1)
retTree1 = createTree(myMat1)
print('retTree1 ', retTree1)
结果:
retTree1 {'spVal': 0.39435, 'left': {'spVal': 0.582002, 'left': {'spVal': 0.797583, 'left': 3.9871632, 'spInd': 1, 'right': 2.9836209534883724}, 'spInd': 1, 'right': 1.980035071428571}, 'spInd': 1, 'right': {'spVal': 0.197834, 'left': 1.0289583666666666, 'spInd': 1, 'right': -0.023838155555555553}}
可以看到,树中包含5个叶节点。
如果一棵树的节点过多,表明该模型可能发生了过拟合==。
之前的算法都是使用了测试集上某种交叉验证技术来发现过拟合。决策树也是如此。
通过降低决策树的复杂度来避免过拟合的过程称为剪枝。
在函数chooseBestSplit()中提前终止条件,实际上是所谓的预剪枝操作。
另一种形式的剪枝需要使用测试集和训练集,称作后剪枝。
树构建算法对输入的参数tolS和tolN非常敏感,如果选用其他值,构建的树效果不太好,例如,测试代码:
myDat = loadDataSet('ex00.txt')
# plot_db(myDat)
myMat = mat(myDat)
retTree = createTree(myMat)
print('retTree ', retTree)
retTree_1 = createTree(myMat, ops=(0, 1))
print('retTree_1 ', retTree_1)
结果为:
retTree {'left': 1.0180967672413792, 'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572}
retTree_1 {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 1.035533, 'spInd': 0, 'spVal': 0.993349, 'right': 1.077553}, 'spInd': 0, 'spVal': 0.989888, 'right': {'left': 0.744207, 'spInd': 0, 'spVal': 0.988852, 'right': 1.069062}}, 'spInd': 0, 'spVal': 0.985425, 'right': 1.227946}, 'spInd': 0, 'spVal': 0.976414, 'right': {'left': {'left': 0.862911, 'spInd': 0, 'spVal': 0.975022, 'right': 0.673579}, 'spInd': 0, 'spVal': 0.953112, 'right': {'left': {'left': 1.06469, 'spInd': 0, 'spVal': 0.951949, 'right': {'left': 0.945255, 'spInd': 0, 'spVal': 0.950153, 'right': 1.022906}}, 'spInd': 0, 'spVal': 0.948268, 'right': {'left': 0.631862, 'spInd': 0, 'spVal': 0.936783, 'right': {'left': {'left': 1.026258, 'spInd': 0, 'spVal': 0.930173, 'right': 1.035645}, 'spInd': 0, 'spVal': 0.928097, 'right': 0.883225}}}}}, 'spInd': 0, 'spVal': 0.919384, 'right': {'left': {'left': {'left': {'left': 1.029889, 'spInd': 0, 'spVal': 0.919074, 'right': 1.123413}, 'spInd': 0, 'spVal': 0.902532, 'right': {'left': 0.861601, 'spInd': 0, 'spVal': 0.901056, 'right': {'left': 1.0559, 'spInd': 0, 'spVal': 0.900272, 'right': 0.996871}}}, 'spInd': 0, 'spVal': 0.897094, 'right': {'left': 1.240209, 'spInd': 0, 'spVal': 0.89593, 'right': {'left': 1.077275, 'spInd': 0, 'spVal': 0.884512, 'right': 1.117833}}}, 'spInd': 0, 'spVal': 0.877241, 'right': {'left': {'left': {'left': 0.797005, 'spInd': 0, 'spVal': 0.869077, 'right': 1.114825}, 'spInd': 0, 'spVal': 0.860049, 'right': 0.71749}, 'spInd': 0, 'spVal': 0.848921, 'right': 1.170959}}}, 'spInd': 0, 'spVal': 0.846455, 'right': {'left': 0.72003, 'spInd': 0, 'spVal': 0.845815, 'right': 0.952617}}, 'spInd': 0, 'spVal': 0.837522, 'right': {'left': {'left': 1.229373, 'spInd': 0, 'spVal': 0.834078, 'right': {'left': 1.01058, 'spInd': 0, 'spVal': 0.824442, 'right': {'left': 1.082153, 'spInd': 0, 'spVal': 0.822443, 'right': 1.086648}}}, 'spInd': 0, 'spVal': 0.821648, 'right': {'left': 1.280895, 'spInd': 0, 'spVal': 0.820802, 'right': 1.325907}}}, 'spInd': 0, 'spVal': 0.819823, 'right': {'left': {'left': {'left': {'left': 0.835264, 'spInd': 0, 'spVal': 0.814825, 'right': 1.095206}, 'spInd': 0, 'spVal': 0.813719, 'right': {'left': 0.706601, 'spInd': 0, 'spVal': 0.804586, 'right': {'left': 0.924033, 'spInd': 0, 'spVal': 0.795072, 'right': 0.965721}}}, 'spInd': 0, 'spVal': 0.79024, 'right': {'left': 0.533214, 'spInd': 0, 'spVal': 0.789625, 'right': 0.552614}}, 'spInd': 0, 'spVal': 0.785541, 'right': {'left': {'left': {'left': {'left': {'left': 1.165296, 'spInd': 0, 'spVal': 0.782167, 'right': {'left': {'left': 0.886049, 'spInd': 0, 'spVal': 0.78193, 'right': 1.074488}, 'spInd': 0, 'spVal': 0.774301, 'right': 0.836763}}, 'spInd': 0, 'spVal': 0.773422, 'right': {'left': {'left': 1.125943, 'spInd': 0, 'spVal': 0.773168, 'right': 1.140917}, 'spInd': 0, 'spVal': 0.772083, 'right': 1.299018}}, 'spInd': 0, 'spVal': 0.768784, 'right': {'left': {'left': {'left': {'left': 0.899705, 'spInd': 0, 'spVal': 0.768596, 'right': 0.760219}, 'spInd': 0, 'spVal': 0.761474, 'right': 1.058262}, 'spInd': 0, 'spVal': 0.750918, 'right': {'left': 0.748104, 'spInd': 0, 'spVal': 0.750078, 'right': 0.906291}}, 'spInd': 0, 'spVal': 0.742527, 'right': {'left': {'left': 1.087056, 'spInd': 0, 'spVal': 0.737189, 'right': 1.200781}, 'spInd': 0, 'spVal': 0.729234, 'right': {'left': 0.931956, 'spInd': 0, 'spVal': 0.727098, 'right': {'left': 1.000567, 'spInd': 0, 'spVal': 0.726828, 'right': 1.017112}}}}}, 'spInd': 0, 'spVal': 0.72312, 'right': 1.307248}, 'spInd': 0, 'spVal': 0.712503, 'right': {'left': {'left': 0.93349, 'spInd': 0, 'spVal': 0.712386, 'right': 0.564858}, 'spInd': 0, 'spVal': 0.703755, 'right': {'left': {'left': {'left
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。