当前位置:   article > 正文

机器学习实战之树回归(CART)python实现(附python3代码)_cart算法python

cart算法python

树回归
CART(Classification And Regression Tree, 分类回归树)

完整代码github
环境 python3

决策树分类

决策树不断将数据切分成小数据集,直到所有的目标变量完全相同,或者数据不能再切分为止。
决策树是一种贪心算法,它要在给定的时间内做出最佳选择,但不关心是否达到全局最优。

ID3的做法是每次选取当前最佳的特征来分割数据,并按照该特征的所有取值进行切分,一旦按某种特征进行切分后,该特征在之后的算法执行过程中就不在起作用。这种切分方式过于迅速且不能处理连续型特征,只有事先将连续型特征转换为离散型特征,才能在ID3算法中使用,但是这种转换会破坏连续型变量的内在性质。
详细介绍可以参照决策树python实现(ID3 和 C4.5)

树回归

优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据

CART使用二元切分法来处理连续型变量,对CART稍作修改就可以处理回归问题。
二元切分法:每次把数据集切分成两份,如果数据的某特征值大于切分所要求的值,那么这些数据进入树的左子树,反之则进入树的右子树。二元切分法节省了树的构建时间,这点意义不大,因为树的构建一般是离线完成。

树回归的一般方法
(1)收集数据:任意方法
(2)准备数据:需要数值型数据,标称型数据应该映射成二值型数据
(3)分析数据:绘出数据的二维可视化显示结果,以字典方式生成树
(4)训练算法:大部分时间都花在叶节点树模型的构建上
(5)测试算法:使用测试数据上的R^2值来分析模型的效果
(6)使用算法:使用训练出的树做预测,预测结果可以用来做许多事情

连续和离散型特征的树构建

用字典来存储树的数据结构,该字典包含以下4元素:

  • 待切分的特征
  • 待切分的特征值
  • 右子树,当不再需要切分的时候,也可以是单个值
  • 左子树,与右子树类似

ID3用一部字典来存储每个切分,该字典可以包含两个或两个以上的值。
CART算法只做二元切分,所以这里可以固定树的数据结构。树包含左键和右键,可以存储另一棵子树或者单个值;字典还包含特征和特征值这两个键,它们给出切分算法所有的特征和特征值。
本章构建两种树:回归树和模型树
回归树:其每个节点包含单个值
模型树:其每个节点包含一个线性方程

函数createTree()的伪代码

找到最佳切分特征:
    如果该节点不能再分,将该节点存为叶节点
    执行二元切分
    在右子树调用createTree()方法
    在左子树调用createTree()方法
  • 1
  • 2
  • 3
  • 4
  • 5

CART算法的实现代码

from numpy import *
import numpy as np

def loadDataSet(filename):
    # 加载数据集
    '''
    :param filename:
    :return: 数据+标签列表
    '''
    dataMat = []
    fr = open(filename)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    '''

    :param dataSet:数据集合
    :param feature: 待切分的特征
    :param value: 该特征的某个值
    :return: 通过数组过滤的方式将上述数据集合切分得到两个子集并返回
    '''
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    '''
    递归函数:树构建
    :param dataSet:数据集 
    :param leafType:建立叶节点函数 
    :param errType: 误差计算函数
    :param ops: 包含树构建所需的其他函数的元组(tolS,tolN), tolS:容许的误差下降值,tolN:切分的最少样本数
    :return: 
    '''
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # 满足停止条件时返回叶节点值
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

将CART算法用于回归

模型树:把叶节点设定为分段线性函数,模型树的可解释性是其优于回归树的特点之一
误差计算:对于给定数据集,应先用线性模型来对它进行拟合,然后计算真实目标值和模型预测值之间的差值,最后将这些差值的平方求和就得到了所需的误差。
补充一些新的代码,使createTree()运行。
实现chooseBestSplit函数:给定某个误差计算方法,该函数会找到数据集上的最佳二元切分方式。该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。即用最佳方式切分数据集和生成相应的叶节点。
伪代码

对每个特征:
	对每个特征值:
		将数据切分成两份
		计算切分误差
		如果当前误差小于当前最小误差,则将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

python代码如下

def regLeaf(dataSet):
    '''
    生成叶节点
    :param dataSet: 数据集
    :return:
    '''
    return mean(dataSet[:, -1])


def regErr(dataSet):
    '''
    误差估计函数,在给定数据上计算目标变量的平方误差,总方差
    :param dataSet:
    :return:
    '''
    return var(dataSet[:, -1])*shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    找到数据的最佳二元切分方式
    :param dataSet: 数据集
    :param leafType:建立叶节点的函数
    :param errType: 误差计算函数
    :param ops: 包含树构建所需其他参数的元组
    :return: 特征编号和切分特征值
    '''
    tolS = ops[0]  # 容许的误差下降值
    tolN = ops[1]  # 切分的最少样本数
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = shape(dataSet)  # 当前数据集的大小
    S = errType(dataSet)   # 当前数据集的误差
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0])<tolN or shape(mat1)[0]<tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestS = newS
                bestValue = splitVal
    # 切分后如果误差减少不大,则不应该进行切分操作,直接创建叶节点
    if S-bestS < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # 切分后的两个子集大小是否小于用户自定义的大小,则不应切分
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)
    return bestIndex, bestValue
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

运行测试代码:

 	 myDat = loadDataSet('ex00.txt')
    # # plot_db(myDat)
     myMat = mat(myDat)
     retTree = createTree(myMat)
     print('retTree  ', retTree)
  • 1
  • 2
  • 3
  • 4
  • 5

结果:

retTree   {'right': -0.04465028571428572, 'left': 1.0180967672413792, 'spVal': 0.48813, 'spInd': 0}
  • 1

可以看到,树中包含2个叶节点

运行测试代码:

	myDat1 = loadDataSet('ex0.txt')
    plot_db(myDat1)
    myMat1 = mat(myDat1)
    retTree1 = createTree(myMat1)
    print('retTree1  ', retTree1)
  • 1
  • 2
  • 3
  • 4
  • 5

结果:

retTree1   {'spVal': 0.39435, 'left': {'spVal': 0.582002, 'left': {'spVal': 0.797583, 'left': 3.9871632, 'spInd': 1, 'right': 2.9836209534883724}, 'spInd': 1, 'right': 1.980035071428571}, 'spInd': 1, 'right': {'spVal': 0.197834, 'left': 1.0289583666666666, 'spInd': 1, 'right': -0.023838155555555553}}
  • 1

可以看到,树中包含5个叶节点

树剪枝

如果一棵树的节点过多,表明该模型可能发生了过拟合==。
之前的算法都是使用了测试集上某种交叉验证技术来发现过拟合。决策树也是如此。
通过降低决策树的复杂度来避免过拟合的过程称为剪枝
在函数chooseBestSplit()中提前终止条件,实际上是所谓的预剪枝操作。
另一种形式的剪枝需要使用测试集和训练集,称作后剪枝

预剪枝

树构建算法对输入的参数tolS和tolN非常敏感,如果选用其他值,构建的树效果不太好,例如,测试代码:

	myDat = loadDataSet('ex00.txt')
    # plot_db(myDat)
    myMat = mat(myDat)
    retTree = createTree(myMat)
    print('retTree  ', retTree)
	retTree_1 = createTree(myMat, ops=(0, 1))
    print('retTree_1  ', retTree_1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

结果为:

retTree   {'left': 1.0180967672413792, 'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572}
retTree_1   {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 1.035533, 'spInd': 0, 'spVal': 0.993349, 'right': 1.077553}, 'spInd': 0, 'spVal': 0.989888, 'right': {'left': 0.744207, 'spInd': 0, 'spVal': 0.988852, 'right': 1.069062}}, 'spInd': 0, 'spVal': 0.985425, 'right': 1.227946}, 'spInd': 0, 'spVal': 0.976414, 'right': {'left': {'left': 0.862911, 'spInd': 0, 'spVal': 0.975022, 'right': 0.673579}, 'spInd': 0, 'spVal': 0.953112, 'right': {'left': {'left': 1.06469, 'spInd': 0, 'spVal': 0.951949, 'right': {'left': 0.945255, 'spInd': 0, 'spVal': 0.950153, 'right': 1.022906}}, 'spInd': 0, 'spVal': 0.948268, 'right': {'left': 0.631862, 'spInd': 0, 'spVal': 0.936783, 'right': {'left': {'left': 1.026258, 'spInd': 0, 'spVal': 0.930173, 'right': 1.035645}, 'spInd': 0, 'spVal': 0.928097, 'right': 0.883225}}}}}, 'spInd': 0, 'spVal': 0.919384, 'right': {'left': {'left': {'left': {'left': 1.029889, 'spInd': 0, 'spVal': 0.919074, 'right': 1.123413}, 'spInd': 0, 'spVal': 0.902532, 'right': {'left': 0.861601, 'spInd': 0, 'spVal': 0.901056, 'right': {'left': 1.0559, 'spInd': 0, 'spVal': 0.900272, 'right': 0.996871}}}, 'spInd': 0, 'spVal': 0.897094, 'right': {'left': 1.240209, 'spInd': 0, 'spVal': 0.89593, 'right': {'left': 1.077275, 'spInd': 0, 'spVal': 0.884512, 'right': 1.117833}}}, 'spInd': 0, 'spVal': 0.877241, 'right': {'left': {'left': {'left': 0.797005, 'spInd': 0, 'spVal': 0.869077, 'right': 1.114825}, 'spInd': 0, 'spVal': 0.860049, 'right': 0.71749}, 'spInd': 0, 'spVal': 0.848921, 'right': 1.170959}}}, 'spInd': 0, 'spVal': 0.846455, 'right': {'left': 0.72003, 'spInd': 0, 'spVal': 0.845815, 'right': 0.952617}}, 'spInd': 0, 'spVal': 0.837522, 'right': {'left': {'left': 1.229373, 'spInd': 0, 'spVal': 0.834078, 'right': {'left': 1.01058, 'spInd': 0, 'spVal': 0.824442, 'right': {'left': 1.082153, 'spInd': 0, 'spVal': 0.822443, 'right': 1.086648}}}, 'spInd': 0, 'spVal': 0.821648, 'right': {'left': 1.280895, 'spInd': 0, 'spVal': 0.820802, 'right': 1.325907}}}, 'spInd': 0, 'spVal': 0.819823, 'right': {'left': {'left': {'left': {'left': 0.835264, 'spInd': 0, 'spVal': 0.814825, 'right': 1.095206}, 'spInd': 0, 'spVal': 0.813719, 'right': {'left': 0.706601, 'spInd': 0, 'spVal': 0.804586, 'right': {'left': 0.924033, 'spInd': 0, 'spVal': 0.795072, 'right': 0.965721}}}, 'spInd': 0, 'spVal': 0.79024, 'right': {'left': 0.533214, 'spInd': 0, 'spVal': 0.789625, 'right': 0.552614}}, 'spInd': 0, 'spVal': 0.785541, 'right': {'left': {'left': {'left': {'left': {'left': 1.165296, 'spInd': 0, 'spVal': 0.782167, 'right': {'left': {'left': 0.886049, 'spInd': 0, 'spVal': 0.78193, 'right': 1.074488}, 'spInd': 0, 'spVal': 0.774301, 'right': 0.836763}}, 'spInd': 0, 'spVal': 0.773422, 'right': {'left': {'left': 1.125943, 'spInd': 0, 'spVal': 0.773168, 'right': 1.140917}, 'spInd': 0, 'spVal': 0.772083, 'right': 1.299018}}, 'spInd': 0, 'spVal': 0.768784, 'right': {'left': {'left': {'left': {'left': 0.899705, 'spInd': 0, 'spVal': 0.768596, 'right': 0.760219}, 'spInd': 0, 'spVal': 0.761474, 'right': 1.058262}, 'spInd': 0, 'spVal': 0.750918, 'right': {'left': 0.748104, 'spInd': 0, 'spVal': 0.750078, 'right': 0.906291}}, 'spInd': 0, 'spVal': 0.742527, 'right': {'left': {'left': 1.087056, 'spInd': 0, 'spVal': 0.737189, 'right': 1.200781}, 'spInd': 0, 'spVal': 0.729234, 'right': {'left': 0.931956, 'spInd': 0, 'spVal': 0.727098, 'right': {'left': 1.000567, 'spInd': 0, 'spVal': 0.726828, 'right': 1.017112}}}}}, 'spInd': 0, 'spVal': 0.72312, 'right': 1.307248}, 'spInd': 0, 'spVal': 0.712503, 'right': {'left': {'left': 0.93349, 'spInd': 0, 'spVal': 0.712386, 'right': 0.564858}, 'spInd': 0, 'spVal': 0.703755, 'right': {'left': {'left': {'left
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/726864
推荐阅读
相关标签
  

闽ICP备14008679号