当前位置:   article > 正文

机器学习——西瓜树决策树id3算法,matlab代码不能运行,来砍我_决策树西瓜集matlab

决策树西瓜集matlab

决策树的生成是一个递归过程。在决策树基本算法中,有三种情形会导致递归返回:

当前结点包含的样本全属于同一类别,无需划分

当前属性集为空,或是所有样本在所有属性上取值相同,无法划分

(3)当前结点包含的样本集合为空,不能划分。

在第(2)中情况下,我们把当前结点标记为叶结点,并将其类别设定为该结点所含样本最多的类别,即在利用当前结点的后验分布;

在第(3)种情况下,同样把当前结点标记为叶结点,但将其类别设定为其父结点所含样本最多的类别,即把父结点的样本分布作为当前结点的先验分布。

ID3(Iterative Dichotomiser 3)是一种经典的决策树学习算法,由 Ross Quinlan 在 1986 年提出。ID3 算法主要用于解决分类问题,它通过对数据集进行递归划分来构建决策树。

ID3 算法的基本思想是在每个节点上选择最佳的特征进行分割,以使得得到的子集尽可能地“纯净”。纯净度通常用信息增益(Information Gain)或基尼指数(Gini Index)等指标来衡量,这些指标可以反映数据集的纯度或不确定性程度。

ID3 算法的步骤如下:

  1. 若所有实例属于同一类,则将当前节点标记为叶节点,并以该类别作为节点的类别标签。
  2. 若特征集为空集,或者当前节点的所有实例属于同一类,则将当前节点标记为叶节点,并以当前节点中实例数最多的类别作为节点的类别标签。
  3. 否则,计算每个特征的信息增益(或基尼指数),选择信息增益(或基尼指数)最大的特征作为当前节点的划分特征。
  4. 根据选定的划分特征将数据集划分为多个子集,并为每个子集递归地应用上述步骤,构建子节点。

matlab代码如下:main.m

  1. clc;clear;
  2. data_name = 'xigua'; %数据名称,
  3. data_r = 'csv'; %数据格式
  4. dir_ = cd; %目录,默认同文件下
  5. %% 数据预处理
  6. filename = fullfile([dir_ '\' data_name '.' data_r]);%文件名
  7. % 获取属性标签
  8. data = readtable(filename,"VariableNamingRule","preserve");
  9. size_data = size(data); %数据大小
  10. if isempty(data.Properties.VariableDescriptions) %英文属性值,无描述
  11. labels = data.Properties.VariableNames(1,1:size(data,2)-1); %获取属性值,必须是英文
  12. else %使用原始列标题以支持中文属性值
  13. labels = cell(1,size_data(2)-1);
  14. for i = 1:size_data(2)-1
  15. VariableDescriptions = data.Properties.VariableDescriptions;%获取原始名称
  16. labels{i} = VariableDescriptions{i}(9:length(VariableDescriptions{1})-1);%添加标签
  17. end
  18. end
  19. % 获取数据集
  20. opts = detectImportOptions(filename);%检查数据
  21. opts = setvartype(opts,opts.VariableNames,'char');
  22. data = readtable(filename,opts) %读入数据
  23. dataset = data{:,:}; %获取数据集
  24. % 调用函数
  25. myTree = ID3(dataset,labels);%生成决策树,并画出来

另创一个文件ID3.m,  ID3代码:

  1. function myTree = ID3(dataset,labels)
  2. % 输入参数:
  3. % dataset:数据集,元胞数组或字符串数组
  4. % labels:属性标签,元胞数组或字符串数组
  5. myTree = createTree(dataset,labels); %生成决策树
  6. [nodeids,nodevalue,branchvalue] = print_tree(myTree); %解析决策树
  7. tree_plot(nodeids,nodevalue,branchvalue); %画出
  8. end
  9. %% 使用熵最小策略构建决策树
  10. function myTree = createTree(dataset,labels)
  11. % 数据为空,则报错
  12. if(isempty(dataset))
  13. error('必须提供数据!')
  14. end
  15. size_data = size(dataset);
  16. % 数据大小与属性数量不一致,则报错
  17. if (size_data(2)-1)~=length(labels)
  18. error('属性数量与数据集不一致!')
  19. end
  20. classList = dataset(:,size_data(2));
  21. %全为同一类,熵为0,返回
  22. if length(unique(classList))==1
  23. myTree = char(classList(1));
  24. return
  25. end
  26. %%属性集为空,应该用找最多数的那一类,这里取值NONE
  27. if size_data(2) == 1
  28. myTree = 'NONE';
  29. %myTree = char(classList(1));
  30. return
  31. end
  32. % 选取特征属性
  33. bestFeature = chooseFeature(dataset);
  34. bestFeatureLabel = char(labels(bestFeature));
  35. % 构建树
  36. myTree = containers.Map;
  37. leaf = containers.Map;
  38. % 该属性下的不同取值
  39. featValues = dataset(:,bestFeature);
  40. uniqueVals = unique(featValues);
  41. % 删除该属性
  42. labels=[labels(1:bestFeature-1) labels(bestFeature+1:length(labels))]; %删除该属性
  43. % 对该属性下不同取值,递归调用ID3函数
  44. for i=1:length(uniqueVals)
  45. subLabels = labels(:)';
  46. value = char(uniqueVals(i));
  47. subdata = splitDataset(dataset,bestFeature,value);%数据集分割
  48. leaf(value) = createTree(subdata,subLabels); %递归调用
  49. myTree(char(bestFeatureLabel)) = leaf;
  50. end
  51. end
  52. %% 计算信息熵
  53. function shannonEnt = calShannonEnt(dataset)
  54. data_size = size(dataset);
  55. labels = dataset(:,data_size(2));
  56. numEntries = data_size(1);
  57. labelCounts = containers.Map;
  58. for i = 1:length(labels)
  59. label = char(labels(i));
  60. if labelCounts.isKey(label)
  61. labelCounts(label) = labelCounts(label)+1;
  62. else
  63. labelCounts(label) = 1;
  64. end
  65. end
  66. shannonEnt = 0.0;
  67. for key = labelCounts.keys
  68. key = char(key);
  69. labelCounts(key);
  70. prob = labelCounts(key) / numEntries;
  71. shannonEnt = shannonEnt - prob*(log(prob)/log(2));
  72. end
  73. end
  74. % 选择熵最小的属性特征
  75. function bestFeature=chooseFeature(dataset,~)
  76. baseEntropy = calShannonEnt(dataset);
  77. data_size = size(dataset);
  78. numFeatures = data_size(2) - 1;
  79. minEntropy = 2.0;
  80. bestFeature = 0;
  81. for i = 1:numFeatures
  82. uniqueVals = unique(dataset(:,i));
  83. newEntropy = 0.0;
  84. for j=1:length(uniqueVals)
  85. value = uniqueVals(j);
  86. subDataset = splitDataset(dataset,i,value);
  87. size_sub = size(subDataset);
  88. prob = size_sub(1)/data_size(1);
  89. newEntropy = newEntropy + prob*calShannonEnt(subDataset);
  90. end
  91. if newEntropy<minEntropy
  92. minEntropy = newEntropy;
  93. bestFeature = i;
  94. end
  95. end
  96. end
  97. % 分割数据集,取出该特征值为value的所有样本,并去除该属性
  98. function subDataset = splitDataset(dataset,axis,value)
  99. subDataset = {};
  100. data_size = size(dataset);
  101. for i=1:data_size(1)
  102. data = dataset(i,:);
  103. if string(data(axis)) == string(value)
  104. subDataset = [subDataset;[data(1:axis-1) data(axis+1:length(data))]];
  105. end
  106. end
  107. end
  108. % 层序遍历决策树,返回nodeids,nodevalue,branchvalue
  109. function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
  110. nodeids(1) = 0;
  111. nodeid = 0;
  112. nodevalue={};
  113. branchvalue={};
  114. queue = {tree} ;%创建队列
  115. while ~isempty(queue)
  116. node = queue{1}; %取数据
  117. queue(1) = []; %出队
  118. if string(class(node))~="containers.Map" %叶节点
  119. nodeid = nodeid+1;
  120. nodevalue = [nodevalue,{node}];
  121. elseif length(node.keys)==1 %节点
  122. nodevalue = [nodevalue,node.keys];
  123. node_info = node(char(node.keys));
  124. nodeid = nodeid+1;
  125. branchvalue = [branchvalue,node_info.keys];
  126. for i=1:length(node_info.keys)
  127. nodeids = [nodeids,nodeid];
  128. end
  129. end
  130. if string(class(node))=="containers.Map"
  131. keys = node.keys();
  132. for i = 1:length(keys)
  133. key = keys{i};
  134. queue=[queue,{node(key)}]; %入队
  135. end
  136. end
  137. nodeids_=nodeids;
  138. nodevalue_=nodevalue;
  139. branchvalue_ = branchvalue;
  140. end
  141. end
  142. %% 参考treeplot,画图
  143. function tree_plot(p,nodevalue,branchvalue)
  144. [x,y,h] = treelayout(p); %x:横坐标,y:纵坐标;h:树的深度
  145. f = find(p~=0); %非0节点
  146. pp = p(f); %非0值
  147. X = [x(f); x(pp); NaN(size(f))];
  148. Y = [y(f); y(pp); NaN(size(f))];
  149. X = X(:);
  150. Y = Y(:);
  151. n = length(p);
  152. if n<500
  153. hold on;
  154. %plot(x,y,'ro',X,Y,'r-')
  155. set(gcf,'Position',get(0,'ScreenSize'))
  156. plot(X,Y,'r-');
  157. nodesize = length(x);
  158. for i=1:nodesize
  159. t = text(x(i),y(i),nodevalue{1,i},'HorizontalAlignment','center');
  160. t.EdgeColor = 'blue';
  161. t.BackgroundColor = 'w';
  162. end
  163. for i=2:nodesize
  164. j = 3*i-5;%获取连线坐标
  165. t=text((X(j)+X(j+1))/2,(Y(j)+Y(j+1))/2,branchvalue{1,i-1},'HorizontalAlignment','center');
  166. t.BackgroundColor = 'w';
  167. end
  168. hold off
  169. else
  170. plot(X,Y,'r-');
  171. end
  172. xlabel(['height = ' int2str(h)]);
  173. axis([0 1 0 1]);
  174. end

数据集用excel表格,表格名称用xigua.csv

 

仔细按照流程去设置,不能运行,请来砍我!!!!!!!!!

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/天景科技苑/article/detail/768546
推荐阅读
相关标签
  

闽ICP备14008679号