当前位置:   article > 正文

Educoder 头歌【人工智能之决策树算法】_决策树模型预测头歌

决策树模型预测头歌

第1关:决策树算法求解分类预测问题

任务描述

本关任务:学习决策树,并基于离散的输入值和输出值数据归纳实现样例的布尔分类。

现有一些是否决定在该饭店等待餐桌吃饭的数据(x,y),其中x是输入属性的值向量,y是单一布尔输出值,学员需要分析数据,构造一棵决策树,学习目标谓词 WillWait 的预测( Yes 或者 No ),每一条数据属性如下:

  • Alternate :附件是否有一个更合适的候选饭店(Yes 和 No);

  • Bar :饭店中是否有舒适的酒吧等待区(Yes 和 No);

  • Fri/Sat :当今天是星期五或星期六时,该属性为真 Yes ,否则为假 No;

  • Hungry :是否饿了(Yes 和 No);

  • Patrons :饭店中有多少客人,取值为 None 、 Some 和 Full;

  • Price :饭店价格区间;

  • Raining :是否下雨(Yes 和 No);

  • Reservation :是否预定(Yes 和 No);

  • Type :饭店类型(French 、 Italian 、 Thai 和 Burger);

  • WaitEstimate :对等待时间的估计(010 、 1030 、 30~60 和 >60 分钟)。

相关知识

为了完成本关任务,你需要掌握:1.决策树,2. ID3 算法,3.求解思路。

决策树

决策树表示一个函数,以属性值向量作为输入,返回一个“决策”,对于输入值是离散的和输出值是二值的情况,通常将这称之为布尔分类,其中样例输入被分类为正例(真)或反例(假),决策树在过程中则是通过一系列的计算测试达到决策的目的。

决策树学习的搜索策略是贪婪搜索策略,近似于极小化搜索树的深度,主要思想就是挑选分叉的属性,以便于尽可能对样例进行正确分类。一个完美属性可以将样例全部划分为正例集合和反例集合,这些集合对应决策树的叶子结点,哪些属性优先被选择就是决策树算法的核心,常见的选择算法有 ID3 算法和 C4.5 算法,本关卡重点介绍和学习 ID3 算法。

ID3算法

ID3(Iterative Dichotomiser 3 迭代二叉树三代) 算法是由 Ross Quinlan 发明的,它建立在奥卡姆剃刀理论的基础上,即越是小型的决策树越是优于大型的决策树,实际上也是一个启发式算法,是一种自顶向下增长树的贪婪算法,在每个结点选取能最好地分类样例的属性,重复这个过程,直到这棵树能完美分类训练样本或所有的属性都使用过了,其算法伪代码如下:
奥卡姆剃刀理论阐述了一个信息熵的概念,以此来选择最优分类属性。设随机变量 V V V具有值 v k v_k vk,各自的概率表示为 P ( v k ) P(v_k) P(vk),则 V V V的熵的定义为:在这里插入图片描述举个例子来理解信息熵的定义,随机抛掷一枚硬币,出现正面和反面的概率都为0.5,根据以上熵的定义,可以得出以下式子:
在这里插入图片描述借助这个例子,设布尔随机变量以 q q q的概率为真,则可定义该变量的熵为:
在这里插入图片描述
那么对于拥有多个属性的数据来说,ID3选择属性的方式则是计算该属性的信息增益(收益),信息增益最大的被优先选择作为决策树的分支属性。

带有 d d d个不同值的属性 A A A将训练集 E E E划分为 E 1 , . . . , E d E_1 ,...,E_d E1,...,Ed,每个子集 E k E_k Ek p k p_k pk个正例和 n k n_k nk个反例,对于属性 A A A的信息熵为 R e m a i n d e r ( A ) Remainder(A) Remainder(A)
在这里插入图片描述对属性A的测试获得的信息增益为 G a i n ( A ) Gain(A) Gain(A)
在这里插入图片描述

求解思路

样例学习的问题首先要做的就是好好的分析数据,详细了解代码 testDecisionTree.py 里的样例数据 data ,然后完成对数据的解析(计算出最优分类属性),并借助解析函数建立决策树,最终能完成对新输入案例的分类预测。

编程要求

本关的编程任务是补全右侧代码片段 buildpredict_parse_data__calc_all_gain__calc_attr_gain__calc_bool_gain__get_targ__is_leaf_BeginEnd 中间的代码,具体要求如下:

  • 在build中,创建一棵决策树,输入参数为根结点;

  • 在predict中,根据归纳好的决策树预测输入样例x的谓词 WillWait 状态(Yes 或者 No);

  • 在_parse_data_中,解析输入矩阵数据(在 Python 里以二维列表数据存储),各参数详见代码中函数注解,然后返回信息增益最大的属性名称及其属性值列表;

  • _calc_all_gain_中,计算所有样本的信息熵并返回,各参数详见代码中函数注解;

  • _calc_attr_gain_中,计算某一特征属性的信息熵并返回,各参数详见代码中函数注解;

  • _calc_bool_gain_中,计算二值随机变量的信息熵并返回,各参数详见代码中函数注解;

  • _get_targ_中,计算叶子结点的决策分类标签并返回,各参数详见代码中函数注解;

  • _is_leaf_中,判断该结点是否为叶子结点,若是则返回 True,否则返回 False。

测试说明

平台将自动编译补全后的代码,并生成若干组测试数据,接着根据程序的输出判断程序是否正确。

以下是平台的测试样例:
测试输入:
[[example, Alt, Bar, Fri, Hun, Pat, Price, Rain, Res, Type, Est],[x1, Yes, No, No, Yes, Some, $$$, No, Yes, French, 0-10]]
预期输出:
Yes

开始你的任务吧,祝你成功!

# -*- coding: UTF-8 -*-

import math
import numpy as np


class TreeNode:
    '''决策树结点数据结构
    成员变量:
    row - int 列表数据的行数,初始13
    col - int 列表数据的列数,初始12
    data - list[[]] 二维列表数据,初始数据形式在testDecisionTree.py里
                    第0行:[第0列:example(样本名字) 中间各列(1-10):各个特征属性名称 第11列:WillW ait(目标分类) ]
                    第1-12行:[样本名字,具体属性值,分类目标]
        data = [
        ['example', 'Alt', 'Bar', 'Fri', 'Hun', 'Pat',  'Price', 'Rain', 'Res', 'Type',   'Est',   'WillW ait'],
        ['x1',      'Yes', 'No',  'No',  'Yes', 'Some', '$$$',   'No',   'Yes', 'French', '0-10',  'y1=Yes'   ],
        ['x2',      'Yes', 'No',  'No',  'Yes', 'Full', '$',     'No',   'No',  'Thai',   '30-60', 'y2=No'    ],
            ........            .....       .....       .........           ............
        ['x12',     'Yes', 'Yes', 'Yes', 'Yes', 'Full', '$',     'No',   'No',  'Burger', '30-60', 'y12=Yes'  ] ]
    targ - string 分类结果 Yes No
    name - string 结点名字:特征属性名称
    attr - list[string] 该特征属性下的各个属性值
    children - list[GameNode] 该特征属性下的各个决策树子结点,与 attr 一一对应
    '''

    def __init__(self, row, col, data):
        self.row = row
        self.col = col
        self.data = data
        self.targ = ''  # target result
        self.name = ''  # attribute name
        self.attr = []  # attribute value list
        self.child = []  # attribute - TreeNode List


class DecisionTree:
    '''决策树
    成员变量:
    root - TreeNode 博弈树根结点
    成员函数:
    buildTree - 创建决策树
    predict - 预测样本分类标签
    _parse_data_ - 解析数据中最大信息增益的特性属性
    _calc_all_gain_ - 计算整个样本的信息熵
    _calc_attr_gain_ - 计算某一特征属性的信息熵
    _calc_bool_gain_ - 通用计算函数:计算二值随机变量的信息熵
    _get_targ_ - 获取叶子结点的决策分类标签
    _is_leaf_ - 判断该结点是否为叶子结点
    '''

    def __init__(self, row, col, data):
        self.root = TreeNode(row, col, data)

    def build(self, root):
        '''递归法创建博弈树
        参数:
        root - TreeNode 初始为决策树根结点
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        name, label = self._parse_data_(root.row, root.col, root.data)
        root.name = name
        root.attr = label
        for value in root.attr:
            subData = [root.data[0]] + [x for x in root.data if x[root.data[0].index(root.name)] == value]
            if len(set([x[-1].split('=')[-1] for x in subData[1:]])) != 1:
                dtree = DecisionTree(len(subData), root.col, subData)
                dtree.build(dtree.root)
                # dNode = TreeNode(len(subData), root.col, subData)
                # self.build()
            else:
                dtree = DecisionTree(len(subData), root.col, subData)
                dtree.root.targ = dtree.root.data[1][-1].split('=')[-1]
            root.child.append(dtree)
        # ********** End **********#

    def predict(self, root, x):
        '''分类预测
        参数:
        root - TreeNode 决策树根结点
        x - [[]] 测试数据,形如:
           [ ['example', 'Alt', 'Bar', 'Fri', 'Hun', 'Pat', 'Price', 'Rain', 'Res', 'Type',  'Est'],
             ['x1',      'Yes', 'No',  'No',  'Yes', 'Some', '$$$',  'No',   'Yes', 'French','0-10'] ]
        返回值:
        clf - string 分类标签 Yes No
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        #self.printtree(root)
        now = root
        while not self._is_leaf_(now):
            i = x[0].index(now.name)
            i = now.attr.index(x[1][i])
            now = now.child[i].root
        return self._get_targ_(now)
        # ********** End **********#

    def _parse_data_(self, row, col, data):
        '''解析数据:计算数据中最大信息增益的特性属性
        参数:
        row - int 列表数据的行数
        col - int 列表数据的列数
        data - list[[]] 二维列表数据,形如:
                第0行:[第0列:example(样本名字) 中间各列(1-10):各个特征属性名称 第11列:WillW ait(目标分类) ]
                第1-12行:[样本名字,具体属性值,分类目标]
        data = [
        ['example', 'Alt', 'Bar', 'Fri', 'Hun', 'Pat',  'Price', 'Rain', 'Res', 'Type',   'Est',   'WillW ait'],
        ['x1',      'Yes', 'No',  'No',  'Yes', 'Some', '$$$',   'No',   'Yes', 'French', '0-10',  'y1=Yes'   ],
        ['x2',      'Yes', 'No',  'No',  'Yes', 'Full', '$',     'No',   'No',  'Thai',   '30-60', 'y2=No'    ],
            ........            .....       .....       .........           ............
        ['x12',     'Yes', 'Yes', 'Yes', 'Yes', 'Full', '$',     'No',   'No',  'Burger', '30-60', 'y12=Yes'  ] ]
        返回值:
        clf - string, list[] 信息增益最大的属性名称 及其 属性值列表
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        numFeature = col - 1
        d = [x[-1] for x in data[1:]]
        baseEntropy = self._calc_all_gain_(1, d)
        bestInforGain = -1
        bestFeature = -1

        for i in range(1, numFeature):
            d = [[x[i], x[-1]] for x in data[1:]]
            inforGain = baseEntropy - self._calc_attr_gain_(row - 1, d)
            if inforGain > bestInforGain:
                bestInforGain = inforGain
                bestFeature = i

        return data[0][bestFeature], list(set([x[bestFeature] for x in data[1:]]))
        # ********** End **********#

    def _calc_all_gain_(self, row, data):
        '''计算整个样本的信息熵
        参数:
        row - int 列表数据的行数
        data - list[] 一维列表数据,形如:[分类目标]
                data = ['y1=Yes', 'y2=No', ........, 'y12=Yes']
        返回值:
        clf - float 信息熵
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        numEntries = len(data)
        labelCounts = {}
        for Feature in data:
            feature = Feature.split('=')[-1]
            if feature not in labelCounts.keys():
                labelCounts[feature] = 0;
            labelCounts[feature] += 1
        shannonEnt = 0.0
        if len(labelCounts.keys()) == 2:
            shannonEnt = self._calc_bool_gain_(float(list(labelCounts.values())[0]) / numEntries)
        return shannonEnt
        # ********** End **********#

    def _calc_attr_gain_(self, row, data):
        '''计算某一特征属性的信息熵
        参数:
        row - int 列表数据的行数
        data - list[[]] 二维列表数据(2列),形如:[[某一属性值,分类目标]]
                  [ ['0-10',  'y1=Yes'   ],
                    ['30-60', 'y2=No'    ],
                      ........
                    ['30-60', 'y12=Yes'  ] ]
        返回值:
        clf - float 信息熵
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        featList = [x[0] for x in data]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = [x[1] for x in data if x[0] == value]
            prob = len(subDataSet) / float(len(data))
            newEntropy += prob * self._calc_all_gain_(1, subDataSet)
        return newEntropy
        # ********** End **********#

    def _calc_bool_gain_(self, p):
        '''通用计算函数:计算二值随机变量的信息熵
        参数:
        p - float 二值随机变量的概率 在[0, 1]之间
        返回值:
        clf - float 信息熵
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        return -(p * math.log(p, 2) + (1 - p) * math.log(1 - p, 2))
        # ********** End **********#

    def _get_targ_(self, node):
        '''计算叶子结点的决策分类标签
        参数:
        node - TreeNode 决策树结点
        返回值:
        clf - string 分类标签 Yes No
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        return node.targ
        # ********** End **********#

    def _is_leaf_(self, node):
        '''判断该结点是否为叶子结点
        参数:
        node - TreeNode 决策树结点
        返回值:
        clf - bool 叶子结点True 非叶子结点False
        '''
        # 请在这里补充代码,完成本关任务
        # ********** Begin **********#
        if len(node.child) == 0:
            clf = True
        else:
            clf = False
        return clf
        # ********** End **********#

    def printtree(self, root):
        print(root.name)
        for i in range(len(root.child)):
            if self._is_leaf_(root.child[i].root):
                print(root.name, '-- ', root.attr[i], ' -->', root.child[i].root.targ)
            else:
                print(root.name, '-- ', root.attr[i], ' -->', root.child[i].root.name)
        for i in range(len(root.child)):
            self.printtree(root.child[i].root)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230

总结

在这里插入图片描述
总是因为一些小问题出错,浪费了太多时间,虽然也有网上编译器的问题,但是还是我的粗心,没有一个合适的思路。
博客也有很长时间没更新了,特地发上来警醒自己,等寒假或者什么时候有空的时候,会统一归档一下。
不过我的报告不都是抄的吗(笑)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/643532
推荐阅读
相关标签
  

闽ICP备14008679号