当前位置:   article > 正文

实验——基于决策树算法完成鸢尾花卉品种预测任务_鸢尾花数据集决策树实验报告

鸢尾花数据集决策树实验报告


一 实验要求

本实验通过鸢尾花数据集iris.csv来实现对决策树进一步的了解。其中, Iris鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含3类共150条记录,每类各50个数据,每条记录都有4项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于iris-setosa, iris-versicolour, iris-virginica三个类别中的 哪一品种。Iris数据集样例如下图所示:

在这里插入图片描述

本实验将五分之四的数据集作为训练集对决策树模型进行训练;将剩余五 分之一的数据集作为测试集,采用训练好的决策树模型对其进行预测。训练集 与测试集的数据随机选取。本实验采用准确率(accuracy)作为模型的评估函数:预测结果正确的数量占样本总数,(TP+TN)/(TP+TN+FP+FN)。

【实验要求】

  1. 本实验要求输出测试集各样本的预测标签和真实标签,并计算模型准确率。(选做)另外,给出 3 个可视化预测结果。

  2. 决策树算法可以分别尝试 ID3,C4.5,cart树,并评判效果。

  3. (选做):对你的决策树模型进行预剪枝与后剪枝

  4. (选做):分别做 c4.5 和 cart 树的剪枝并比较不同。

二 实验思路

分析数据结构,因为没有每个样本独有的属性(例如学生ID),决定采用ID3决策树(ID3决策树的信息增益偏向于可能值较多的属性)。采用信息增益Information Gain确定划分的最优特征,对于保存树的结构方面,采用字典的形式保存,以下方字典形式为例:

{
    "PetalWidth": {
        "0": "Iris-setosa",
        "1": {
            "PetalLength": {
                "0": "Iris-setosa",
                "1": "Iris-versicolor",
                "2": {
                    "SepalLength": {
                        "1": "Iris-virginica",
                        "2": "Iris-versicolor"
                    }
                }
            }
        },
        "2": {
            "SepalLength": {
                "0": "Iris-virginica",
                "1": "Iris-virginica",
                "2": {
                    "SepalWidth": {
                        "0": "Iris-virginica",
                        "1": "Iris-virginica"
                    }
                }
            }
        }
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

由于数据是连续的,在训练决策树之前需要将其离散化,利用python中seaborn库观察数据分布:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb

Data = pd.read_csv('iris.csv')
sb.pairplot(Data.dropna(), hue='Species')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述

因此我们确定数据的离散边界:

离散特征\离散后的值012
SepalLength0-5.55.5-6.36.3-inf
SepalWidth0.3.23.2-inf
PetalLength0-22-4.94.9-inf
PetalWidth0-0.60.6-1.71.7-inf

为了保证结果的随机性,训练集与测试集的划分采用随机采样,每种花随机抽取10个共30个测试集样本,其余为训练集样本:

# 分割训练集与测试集
def split_train_test(data:pd.DataFrame):
    # 随机采样
    test_index = random.sample(range(50),10)
    test_index.extend(random.sample(range(50,100),10))
    test_index.extend(random.sample(range(100,150),10))
    print(test_index)
    testSet = data.iloc[test_index]
    train_index = list(range(150))
    for index in test_index:
        train_index.remove(index)
    # 划分训练集
    trainSet = data.iloc[train_index]
    return trainSet,testSet
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

三 实验代码

  • split_train_test:分割训练集与测试集
  • ShannonEntropy:计算一个Dataframe关于forecast_label列的信息熵
  • InformationGain:计算一个Dataframelabel标签关于forecast_label的信息增益
  • createTree:递归生成决策树的过程
  • decision:对一个样本进行决策
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import json
import random

def split_train_test(data:pd.DataFrame):
    """
    分割训练集与测试集
    :param data: 总的数据集
    :return: 返回训练集与测试集
    """
    # 随机采样
    test_index = random.sample(range(50),10)
    test_index.extend(random.sample(range(50,100),10))
    test_index.extend(random.sample(range(100,150),10))
    print(test_index)
    testSet = data.iloc[test_index]
    train_index = list(range(150))
    for index in test_index:
        train_index.remove(index)
    # 划分训练集
    trainSet = data.iloc[train_index]
    return trainSet,testSet


def ShannonEntropy(data:pd.DataFrame,forecast_label:str)->float:
    """
    计算一个数据集关于某个标签(这里是预测标签)的信息熵
    :param data: 数据集
    :param forecast_label: 预测标签
    :return: 返回这个数据集的信息熵
    """
    total = data.shape[0]
    kinds = data[forecast_label].value_counts()
    Entropy = 0.0
    # 对于每种预测标签 计算pk*log(pk)
    for i in range(kinds.shape[0]):
        # 计算每种预测标签的比例
        prior_probability = kinds[i]/total
        # 计算信息熵
        Entropy += (prior_probability * np.log2(prior_probability))
    return -Entropy


def InformationGain(data:pd.DataFrame,label:str,forecast_label:str)->float:
    """
    计算label标签关于forecast_label的信息增益
    :param data: 数据集
    :param label: 计算标签
    :param forecast_label: 预测标签
    :return: 信息增益
    """
    # 计算总的信息熵Entropy(S)
    total_entropy = ShannonEntropy(data,forecast_label)
    # 初始化信息增益
    gain = total_entropy
    # 按照计算标签分组
    sub_frame = data[[label,'Species']]
    group = sub_frame.groupby(label)
    # 计算信息增益
    for key, df in group:
        gain -= (df.shape[0]/data.shape[0]) * ShannonEntropy(df,'Species')
    return gain


def createTree(data:pd.DataFrame)->dict:
    """
    递归的创建决策树
    :param data: 训练集数据
    :return: 返回一个字典表示决策树
    """
    # 该分支下的实例只有一种分类
    if len(data['Species'].value_counts()) == 1:
        return data['Species'].iloc[0]

    # 初始化最优信息增益与最优属性
    bestGain = 0
    bestFeature = -1

    # 对于每种属性计算信息增益,选出信息增益最大的一列
    for column in data:
        if column != 'Species':
            gain = InformationGain(data, column, 'Species')
            if bestGain < gain:
                bestGain = gain
                bestFeature = column

    # 数据集中所有数据都相同,但种类不同,返回最多数量的种类
    if bestFeature == -1:
        valueCount = data['Species'].value_counts()
        return valueCount.index[0]


    # 初始化一个字典
    myTree = {bestFeature: {}}
    # 统计出最佳属性的所有可能取值
    valueList = set(data[bestFeature])
    for value in valueList:
        # 递归的构造子树
        myTree[bestFeature][value] = createTree(data[data[bestFeature] == value])

    return myTree


def decision(tree:dict,testVector:pd.Series):
    """
    预测一个测试集样本的类别
    :param tree: 生成的决策树
    :param testVector: 测试数据向量
    :return: 返回预测标签
    """
    # 初始化预测标签
    forecastLabel = 0
    # 获取当前决策树第一个节点属性
    firstFeature = next(iter(tree))
    # 获取子树
    childTree = tree[firstFeature]
    # 对子树中不同的可能值检测是否相等
    for key in childTree.keys():
        # 满足条件深入到下一层
        if testVector[firstFeature] == key:
            # 下一层是分支节点
            if type(childTree[key]) == dict :
                forecastLabel = decision(childTree[key],testVector)
            # 下一层是叶节点
            else:
                forecastLabel = childTree[key]
    return forecastLabel


if __name__ == '__main__' :
    Data = pd.read_csv('iris.csv')

    # 画出统计分布图,统计每种类别的特征
    # sb.pairplot(Data.dropna(), hue='Species')
    # plt.show()

    # 数据离散化处理
    Data['SepalLength'] = np.digitize(Data['SepalLength'],bins=[5.5,6.3])
    Data['SepalWidth'] = np.digitize(Data['SepalWidth'],bins=[3.2])
    Data['PetalLength'] = np.digitize(Data['PetalLength'],bins=[2,4.9])
    Data['PetalWidth'] = np.digitize(Data['PetalWidth'],bins=[0.6,1.7])
    # 数据离散化的字典
    discrete_dict = {
        'SepalLength' : {'0-5.5':0,'5.5-6.3':1,'6.3-inf':2},
        'SepalWidth' : {'0.3.2':0,'3.2-inf':1},
        'PetalLength' : {'0-2':0,'2-4.9':1,'4.9-inf':2},
        'PetalWidth' : {'0-0.6':0,'0.6-1.7':1,'1.7-inf':2}
    }
    print('数据离散化字典:')
    print(json.dumps(discrete_dict, indent=4))
    # 分出训练集与测试集
    train_set,test_set = split_train_test(Data)
    # 训练出决策树
    tree = createTree(train_set)
    print("决策树字典表示:")
    print(json.dumps(tree, indent=4, sort_keys=True))
    # 初始化统计参数
    T = 0
    N = 0
    for i in range(test_set.shape[0]):
        print('================================================')
        vector = test_set.iloc[i, :-1]
        sl = vector['SepalLength']
        sw = vector['SepalWidth']
        pl = vector['PetalLength']
        pw = vector['PetalWidth']
        trueLabel = test_set.iloc[i]['Species']
        print(f'离散化后的测试数据:SepalLength={sl},SepalWidth={sw},PetalLength={pl},PetalWidth={pw},真实标签={trueLabel}')
        forecastLabel= decision(tree,vector)
        if forecastLabel == trueLabel:
            T+=1
            print(f'预测为{forecastLabel},预测正确')
        else:
            N+=1
            print(f'预测为{forecastLabel},预测错误')

    print('-------------------------------------------------------------')
    print(f'决策树预测准确率为:'+str(T/(T+N)))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181

四 实验结果与分析

运行程序后的控制台输出:

数据离散化字典:
{
    "SepalLength": {
        "0-5.5": 0,
        "5.5-6.3": 1,
        "6.3-inf": 2
    },
    "SepalWidth": {
        "0.3.2": 0,
        "3.2-inf": 1
    },
    "PetalLength": {
        "0-2": 0,
        "2-4.9": 1,
        "4.9-inf": 2
    },
    "PetalWidth": {
        "0-0.6": 0,
        "0.6-1.7": 1,
        "1.7-inf": 2
    }
}
随机训练集index: 
[38, 28, 26, 44, 40, 19, 2, 18, 7, 46, 68, 79, 50, 97, 65, 88, 69, 81, 92, 95, 126, 132, 137, 131, 110, 124, 133, 116, 125, 143]
决策树字典表示:
{
    "PetalLength": {
        "0": "Iris-setosa",
        "1": {
            "PetalWidth": {
                "1": "Iris-versicolor",
                "2": {
                    "SepalWidth": {
                        "0": "Iris-virginica",
                        "1": "Iris-versicolor"
                    }
                }
            }
        },
        "2": {
            "PetalWidth": {
                "1": {
                    "SepalLength": {
                        "1": "Iris-virginica",
                        "2": "Iris-versicolor"
                    }
                },
                "2": {
                    "SepalLength": {
                        "1": "Iris-virginica",
                        "2": {
                            "SepalWidth": {
                                "0": "Iris-virginica",
                                "1": "Iris-virginica"
                            }
                        }
                    }
                }
            }
        }
    }
}
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=0,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=0,SepalWidth=1,PetalLength=0,PetalWidth=0,真实标签=Iris-setosa
预测为Iris-setosa,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=1,真实标签=Iris-versicolor
预测为Iris-versicolor,预测正确
================================================
离散化后的测试数据:SepalLength=1,SepalWidth=0,PetalLength=1,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=0,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=0,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=0,PetalLength=2,PetalWidth=1,真实标签=Iris-virginica
预测为Iris-versicolor,预测错误
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=0,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
================================================
离散化后的测试数据:SepalLength=2,SepalWidth=1,PetalLength=2,PetalWidth=2,真实标签=Iris-virginica
预测为Iris-virginica,预测正确
-------------------------------------------------------------
决策树预测准确率为:0.9666666666666667
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154

观察决策树的结构,SepalWidth=0SepalWidth=1时的决策结果相同,可以通过剪枝的操作减去多余的子树。同时对数据的分布进行分析,SepalWidthSepalLength的数据重合部分大,对这两个数据进行预剪枝的效果可能更好。

参考

【1】机器学习实战

【2】 机器学习项目实战–基于鸢尾花数据集(python代码,多种算法对比:决策树、SVM、k近邻)_西南交大-Liu_z的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/604237
推荐阅读
相关标签
  

闽ICP备14008679号