Generating a decision tree is a recursive process. In the basic decision-tree algorithm, three situations cause the recursion to return:
(1) all samples at the current node belong to the same class, so no split is needed;
(2) the attribute set is empty, or all samples take the same value on every attribute, so no split is possible;
(3) the sample set at the current node is empty, so no split can be made.
In case (2), we mark the current node as a leaf and set its class to the majority class among the samples it contains; that is, we use the posterior distribution at the current node.
In case (3), we likewise mark the current node as a leaf, but set its class to the majority class among the samples of its parent node; that is, we use the parent's sample distribution as the prior distribution of the current node.
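A minimal MATLAB sketch of these three return cases (the function and helper names here are hypothetical; the full implementation appears further below):

function node = treeGenerate(dataset, labels, parentMajority)
% Sketch of the three recursion-return cases only.
    if isempty(dataset)                    % case (3): empty sample set,
        node = parentMajority; return      % use the parent's majority class (prior)
    end
    classList = dataset(:, end);           % last column holds the class labels
    if numel(unique(classList)) == 1       % case (1): all samples in one class
        node = char(classList(1)); return
    end
    if isempty(labels)                     % case (2): attribute set empty,
        node = majorityClass(classList);   % use this node's majority class (posterior);
        return                             % the same-values-on-all-attributes check is omitted
    end
    % ... otherwise: pick the best attribute, split, and recurse ...
    node = majorityClass(classList);       % placeholder for the recursive step
end

function c = majorityClass(classList)
% Majority class of a cell array of label strings.
    [classes, ~, idx] = unique(classList);
    [~, k] = max(accumarray(idx, 1));
    c = char(classes(k));
end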
ID3 (Iterative Dichotomiser 3) is a classic decision-tree learning algorithm proposed by Ross Quinlan in 1986. ID3 is designed for classification problems and builds a decision tree by recursively partitioning the dataset.
The basic idea of ID3 is to select, at each node, the attribute whose split leaves the resulting subsets as "pure" as possible. ID3 measures purity with information gain, which is based on information entropy; related algorithms use other impurity measures (CART, for example, uses the Gini index).
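For reference, the standard definitions (not spelled out in the original post): for a dataset $D$ with class proportions $p_k$, the information entropy is

$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k,$$

and the information gain of splitting on attribute $a$ with values $v = 1, \dots, V$ is

$$\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v).$$

ID3 picks the attribute with the largest $\mathrm{Gain}(D, a)$. Since $\mathrm{Ent}(D)$ is the same for every attribute, this is equivalent to picking the attribute with the smallest conditional entropy (the second term), which is what the code below does.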
The ID3 algorithm proceeds as follows:
1. Compute the information entropy of the current dataset.
2. For each candidate attribute, compute the information gain of splitting on it.
3. Split the dataset on the attribute with the highest information gain (equivalently, the lowest conditional entropy), and remove that attribute from the candidate set.
4. Recurse on each subset until one of the three return conditions above is met.
The MATLAB code is as follows (main.m):

clc; clear;

data_name = 'xigua';   % dataset name
data_r = 'csv';        % data format
dir_ = cd;             % directory, defaults to the current folder

%% Data preprocessing
filename = fullfile(dir_, [data_name '.' data_r]);   % build the file name
% Read the attribute labels
data = readtable(filename, "VariableNamingRule", "preserve");
size_data = size(data);   % size of the data
if isempty(data.Properties.VariableDescriptions)   % plain attribute names, no descriptions
    labels = data.Properties.VariableNames(1:size_data(2)-1);   % all columns except the class
else   % use the original column headings to support non-ASCII attribute names
    labels = cell(1, size_data(2)-1);
    for i = 1:size_data(2)-1
        VariableDescriptions = data.Properties.VariableDescriptions;   % original headings
        % strip the locale-dependent "original column heading: '...'" wrapper text
        labels{i} = VariableDescriptions{i}(9:length(VariableDescriptions{i})-1);
    end
end

% Read the dataset
opts = detectImportOptions(filename);                 % inspect the file
opts = setvartype(opts, opts.VariableNames, 'char');  % read every column as text
data = readtable(filename, opts);                     % read the data
dataset = data{:,:};                                  % extract the raw data array
% Call the function
myTree = ID3(dataset, labels);   % build the decision tree and plot it

Create a second file, ID3.m, containing the ID3 code:
function myTree = ID3(dataset, labels)
% Inputs:
%   dataset: the data, a cell array or string array
%   labels:  the attribute labels, a cell array or string array

    myTree = createTree(dataset, labels);                    % build the decision tree
    [nodeids, nodevalue, branchvalue] = print_tree(myTree);  % traverse the tree
    tree_plot(nodeids, nodevalue, branchvalue);              % draw it
end

%% Build the decision tree by minimizing conditional entropy
function myTree = createTree(dataset, labels)

    % No data: raise an error
    if isempty(dataset)
        error('Data must be provided!')
    end
    size_data = size(dataset);
    % Mismatch between data width and number of attributes: raise an error
    if (size_data(2)-1) ~= length(labels)
        error('Number of attributes does not match the dataset!')
    end

    classList = dataset(:, size_data(2));
    % All samples in one class: entropy is 0, return a leaf
    if length(unique(classList)) == 1
        myTree = char(classList(1));
        return
    end
    % Attribute set is empty: return a leaf labeled with the majority class
    if size_data(2) == 1
        [classes, ~, idx] = unique(classList);
        [~, k] = max(accumarray(idx, 1));
        myTree = char(classes(k));
        return
    end
    % Choose the best attribute
    bestFeature = chooseFeature(dataset);
    bestFeatureLabel = char(labels(bestFeature));
    % Build the tree
    myTree = containers.Map;
    leaf = containers.Map;
    % The distinct values of the chosen attribute
    featValues = dataset(:, bestFeature);
    uniqueVals = unique(featValues);
    % Remove the chosen attribute from the label list
    labels = [labels(1:bestFeature-1) labels(bestFeature+1:length(labels))];
    % Recurse on each value of the chosen attribute
    for i = 1:length(uniqueVals)
        subLabels = labels(:)';
        value = char(uniqueVals(i));
        subdata = splitDataset(dataset, bestFeature, value);  % split the dataset
        leaf(value) = createTree(subdata, subLabels);         % recursive call
        myTree(char(bestFeatureLabel)) = leaf;
    end
end
-
%% Compute the information entropy
function shannonEnt = calShannonEnt(dataset)
    data_size = size(dataset);
    labels = dataset(:, data_size(2));   % the class label column
    numEntries = data_size(1);
    labelCounts = containers.Map;
    % Count the occurrences of each class label
    for i = 1:length(labels)
        label = char(labels(i));
        if labelCounts.isKey(label)
            labelCounts(label) = labelCounts(label) + 1;
        else
            labelCounts(label) = 1;
        end
    end
    shannonEnt = 0.0;
    % Ent = -sum(p * log2(p)) over all classes
    for key = labelCounts.keys
        key = char(key);
        prob = labelCounts(key) / numEntries;
        shannonEnt = shannonEnt - prob * (log(prob)/log(2));
    end
end
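% Example (hypothetical data): if the last column of dataset holds the
% labels {'yes';'yes';'no'}, then calShannonEnt returns
%   -(2/3)*log2(2/3) - (1/3)*log2(1/3) = 0.9183 bits (approximately).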
-
% Choose the attribute with the smallest conditional entropy
% (equivalent to the largest information gain, since the parent entropy
% is the same for every candidate attribute)
function bestFeature = chooseFeature(dataset)
    data_size = size(dataset);
    numFeatures = data_size(2) - 1;
    minEntropy = inf;   % entropy can exceed 2 bits with many classes
    bestFeature = 0;
    for i = 1:numFeatures
        uniqueVals = unique(dataset(:, i));
        newEntropy = 0.0;
        % Conditional entropy: weighted sum of the subset entropies
        for j = 1:length(uniqueVals)
            value = uniqueVals(j);
            subDataset = splitDataset(dataset, i, value);
            size_sub = size(subDataset);
            prob = size_sub(1) / data_size(1);
            newEntropy = newEntropy + prob * calShannonEnt(subDataset);
        end
        if newEntropy < minEntropy
            minEntropy = newEntropy;
            bestFeature = i;
        end
    end
end
% Split the dataset: keep all samples whose value on attribute `axis`
% equals `value`, and remove that attribute column
function subDataset = splitDataset(dataset, axis, value)
    subDataset = {};
    data_size = size(dataset);
    for i = 1:data_size(1)
        data = dataset(i, :);
        if string(data(axis)) == string(value)
            subDataset = [subDataset; [data(1:axis-1) data(axis+1:length(data))]];
        end
    end
end
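% Example (hypothetical data):
%   splitDataset({'a','x','yes'; 'b','x','no'; 'c','y','no'}, 2, 'x')
% returns {'a','yes'; 'b','no'}: the rows with value 'x' in column 2,
% with that column removed.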
% Level-order traversal of the decision tree; returns nodeids, nodevalue, branchvalue
function [nodeids_, nodevalue_, branchvalue_] = print_tree(tree)
    nodeids(1) = 0;   % parent index of each node; the root's parent is 0
    nodeid = 0;
    nodevalue = {};
    branchvalue = {};

    queue = {tree};   % create the queue
    while ~isempty(queue)
        node = queue{1};   % read the front element
        queue(1) = [];     % dequeue
        if string(class(node)) ~= "containers.Map"   % leaf node
            nodeid = nodeid + 1;
            nodevalue = [nodevalue, {node}];
        elseif length(node.keys) == 1                % internal (attribute) node
            nodevalue = [nodevalue, node.keys];
            node_info = node(char(node.keys));
            nodeid = nodeid + 1;
            branchvalue = [branchvalue, node_info.keys];   % branch (edge) labels
            for i = 1:length(node_info.keys)
                nodeids = [nodeids, nodeid];   % record the parent of each child
            end
        end

        if string(class(node)) == "containers.Map"
            keys = node.keys();
            for i = 1:length(keys)
                key = keys{i};
                queue = [queue, {node(key)}];   % enqueue the children
            end
        end
    end
    nodeids_ = nodeids;
    nodevalue_ = nodevalue;
    branchvalue_ = branchvalue;
end
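% Note: nodeids is the parent-pointer vector that treelayout expects.
% Entry i holds the index of node i's parent, and the root holds 0;
% for example, [0 1 1] describes a root with two children.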
%% Plot the tree (adapted from treeplot)
function tree_plot(p, nodevalue, branchvalue)

    [x, y, h] = treelayout(p);   % x, y: node coordinates; h: tree depth
    f = find(p ~= 0);            % indices of non-root nodes
    pp = p(f);                   % their parent indices
    X = [x(f); x(pp); NaN(size(f))];   % line segments from child to parent
    Y = [y(f); y(pp); NaN(size(f))];
    X = X(:);
    Y = Y(:);
    n = length(p);
    if n < 500
        hold on;
        set(gcf, 'Position', get(0, 'ScreenSize'))   % maximize the figure window
        plot(X, Y, 'r-');
        nodesize = length(x);
        % Draw the node labels
        for i = 1:nodesize
            t = text(x(i), y(i), nodevalue{1,i}, 'HorizontalAlignment', 'center');
            t.EdgeColor = 'blue';
            t.BackgroundColor = 'w';
        end
        % Draw the branch labels at the midpoints of the edges
        for i = 2:nodesize
            j = 3*i - 5;   % index of this edge's line-segment coordinates
            t = text((X(j)+X(j+1))/2, (Y(j)+Y(j+1))/2, branchvalue{1,i-1}, ...
                'HorizontalAlignment', 'center');
            t.BackgroundColor = 'w';
        end
        hold off
    else
        plot(X, Y, 'r-');
    end
    xlabel(['height = ' int2str(h)]);
    axis([0 1 0 1]);
end
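To sanity-check the tree builder without a file, ID3 can also be called directly on an in-memory cell array (hypothetical toy data: two attributes plus a class column):

dataset = {'sunny','high','no'; 'sunny','low','yes'; 'rainy','high','no'};
labels = {'outlook','humidity'};
myTree = ID3(dataset, labels);   % builds and plots a small tree split on 'humidity'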

The dataset is a CSV spreadsheet (editable in Excel) named xigua.csv.
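As an illustration of the expected layout (hypothetical rows and English column names; the classic watermelon dataset uses Chinese attribute names, which the code above also supports), xigua.csv might begin:

color,root,knock,texture,class
green,curled,dull,clear,yes
black,curled,dull,clear,yes
white,stiff,crisp,blurry,no

The last column is the class label; every earlier column is a categorical attribute.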