赞
踩
function shannonEnt=CalcEntropy(typeslist)
% CALCENTROPY Shannon entropy (base 2) of a list of class labels.
%   shannonEnt = CalcEntropy(typeslist) counts how often each distinct
%   value of typeslist occurs and accumulates -p*log2(p), where
%   p = count / length(typeslist).
%
%   typeslist  : labels to measure, compared with ==; presumably a numeric
%                vector (TODO confirm: for a matrix, length() uses the
%                largest dimension, not the element count)
%   shannonEnt : entropy in bits; 0 for an empty or single-valued list
total = length(typeslist);
classes = unique(typeslist);   % distinct label values
shannonEnt = 0;
for k = 1:length(classes)
    % nnz counts matching elements, exactly as length(find(...)) does
    p = nnz(typeslist == classes(k)) / total;
    shannonEnt = shannonEnt - p*log2(p);
end
1.2. SplitDataset
function retData=SplitDataset(data,axis,value)
% SPLITDATASET Rows of data whose column `axis` equals `value`, minus that column.
%   retData = SplitDataset(data, axis, value) keeps every row i with
%   data(i,axis) == value, in the original row order, and then deletes
%   column `axis`. The result has size(data,2)-1 columns.
%
%   data    : numeric matrix, one sample per row
%   axis    : index of the feature column to split on
%   value   : feature value that selects the rows
%   retData : matching rows without column `axis`; a
%             0-by-(size(data,2)-1) matrix when no row matches
%
% Logical indexing replaces the row-by-row append, which copied the array
% on every match (quadratic). The 0-row result also lets the column deletion
% succeed when nothing matches. The old code deleted a column from [],
% which errors.
retData = data(data(:,axis) == value, :);
retData(:,axis) = [];
1.3. chooseBestFeatureToSplit
function bestFeature=chooseBestFeatureToSplit(dataset)
% CHOOSEBESTFEATURETOSPLIT Index of the feature with the largest information gain.
%   bestFeature = chooseBestFeatureToSplit(dataset) treats columns 1..n-1 of
%   dataset as features and column n as the class label. For each feature
%   it splits the data on every distinct value of that feature. It then
%   computes the size-weighted entropy of the label column in each part.
%   The result is the feature whose split lowers the label entropy the most.
%
%   dataset     : m-by-n numeric matrix, last column = class label
%   bestFeature : column index of the best feature, or -1 when no feature
%                 gives a strictly positive information gain
%
%   Depends on CalcEntropy and SplitDataset.
[m,n] = size(dataset);
numFeatures = n-1;
originEntropy = CalcEntropy(dataset(:,n));   % label entropy before splitting
bestInfoGain = 0.0;                          % largest gain seen so far
bestFeature = -1;
for i = 1:numFeatures
    uniqueVals = unique(dataset(:,i));
    tmpEntropy = 0.0;
    for j = 1:length(uniqueVals)
        subDataset = SplitDataset(dataset,i,uniqueVals(j));
        prob = size(subDataset,1)/m;
        % SplitDataset drops column i, so the label is still the LAST
        % column. Entropy must be measured on the labels only; passing the
        % whole sub-matrix mixed feature values into the count.
        tmpEntropy = tmpEntropy + prob*CalcEntropy(subDataset(:,end));
    end
    infoGain = originEntropy-tmpEntropy;
    if infoGain > bestInfoGain
        bestInfoGain = infoGain;
        bestFeature = i;
    end
end
1.4. createTree
function tree=createTree(fatherNode, level, Edge, dataset, labels)
% CREATETREE Recursively grow an ID3 decision tree into the global struct array `tree`.
%   Each call appends one node to the global `tree`. A node is a struct with
%   fields level, fatherNode, Edge and Node. The call then recurses into one
%   child per distinct value of the chosen feature. Because a node is
%   appended before its children, nodes end up in depth-first pre-order.
%
%   fatherNode : Node value (split-feature label) of the parent node
%   level      : level of the parent; this node is stored at level+1
%   Edge       : feature value on the edge from the parent to this node
%   dataset    : m-by-n numeric matrix, last column = class label
%   labels     : names of the remaining features, one per column 1..n-1
%   tree       : the global node array after this call
%
%   Leaf nodes store a class value in .Node. Internal nodes store the label
%   of the feature they split on.
%   Depends on chooseBestFeatureToSplit and SplitDataset.
global tree;
branch = struct('level',level+1,'fatherNode',fatherNode,'Edge',Edge,'Node',[]);
[~,n] = size(dataset);
typesList = dataset(:,n);

% Case 1: every remaining sample has the same class -> leaf
if length(unique(typesList)) == 1
    branch.Node = typesList(1);
    tree = [tree branch];
    return;
end

% Case 2: all features used up (only the label column is left) -> majority leaf
if n == 1
    branch.Node = mode(typesList);   % majority class
    tree = [tree branch];
    return;
end

% Case 3: split on the feature with the largest information gain
bestFeat = chooseBestFeatureToSplit(dataset);
if bestFeat == -1
    % No feature lowers the entropy (e.g. identical feature rows with
    % conflicting labels). labels(-1) would error, so stop with a majority leaf.
    branch.Node = mode(typesList);
    tree = [tree branch];
    return;
end
bestFeatLabel = labels(bestFeat);
branch.Node = bestFeatLabel;
tree = [tree branch];
% Remove the used feature by position so labels stays aligned with the
% columns SplitDataset keeps. Removing by name with strcmp removes nothing
% for numeric labels, and removes too much when labels are duplicated.
labels(bestFeat) = [];
featVals = unique(dataset(:,bestFeat));
for i = 1:length(featVals)
    createTree(branch.Node, branch.level, featVals(i), ...
        SplitDataset(dataset,bestFeat,featVals(i)), labels);
end
1.5. decisionTree
function decisionTreeModel=decisionTree(dataset,labels)
% DECISIONTREE Train an ID3 decision tree and wrap its nodes in a model struct.
%   decisionTreeModel = decisionTree(dataset, labels) resets the global node
%   array `tree` to a single placeholder entry. It then lets createTree
%   append the root and all descendants, and finally removes the first two
%   entries: the placeholder and the root node itself.
%
%   dataset           : numeric matrix, last column = class label
%   labels            : feature names, one per column 1..n-1
%   decisionTreeModel : struct whose .Node field holds the remaining nodes
%
%   NOTE(review): because the root node is removed, a training set with a
%   single class (the root is a leaf) leaves .Node empty.
global tree;
placeholder = struct('level',-1,'fatherNode',[],'Edge',[],'Node',[]);
tree = placeholder;
createTree('root',-1,-1,dataset,labels);
% Drop the placeholder and the root in one deletion. The root's split
% feature survives as the fatherNode of its children.
tree(1:2) = [];
decisionTreeModel.Node = tree;
1.6. modelPredict
function type=modelPredict(model,sample,labels,typesName)
% MODELPREDICT Classify one sample with a tree produced by decisionTree.
%   type = modelPredict(model, sample, labels, typesName) walks model.Node
%   (nodes in depth-first pre-order, root removed) one level at a time.
%   At each level it looks for a child of the current split feature `head`
%   whose Edge equals the sample's value for that feature. If that child's
%   Node is one of the class values in typesName, it is the prediction.
%   Otherwise the child is a feature, and the walk continues one level deeper.
%
%   model     : struct with field .Node, as returned by decisionTree
%   sample    : feature vector, ordered like labels
%   labels    : feature names used for training (compared with ==)
%   typesName : vector of possible class values
%   type      : predicted class value, or [] when no path matches the
%               sample (an unseen feature value, or an empty model)
type = [];
Nodes = model.Node;
if isempty(Nodes)
    return;
end
head = Nodes(1).fatherNode;   % split feature of the removed root
level = 1;
for i = 1:length(Nodes)
    node = Nodes(i);
    % Only follow children of the current node. Without the fatherNode
    % check, a same-level node from a different subtree could be taken.
    if node.level ~= level || ~isequal(node.fatherNode, head)
        continue;
    end
    if node.Edge == sample(labels == head)
        if nnz(typesName == double(node.Node)) == 1
            type = node.Node;   % leaf: a class value
            break;
        end
        head = node.Node;       % internal node: descend on its feature
        level = level + 1;
    end
end
2. MATLAB自带函数实现
2.1. dataPreprocess
function [xtrain,ytrain,xtest,ytest] = dataPreprocess(nTrain)
% DATAPREPROCESS Load Fisher's iris data and split it into disjoint train/test sets.
%   [xtrain,ytrain,xtest,ytest] = dataPreprocess() puts a random 120 of the
%   150 iris samples in the training set and the remaining 30 in the test set.
%   dataPreprocess(nTrain) uses nTrain training samples instead (default 120).
%
%   xtrain, ytrain : training features (rows of meas) and labels (species)
%   xtest,  ytest  : test features and labels; no row is shared with training
if nargin < 1
    nTrain = 120;
end
S = load('fisheriris');
x = S.meas;
y = S.species;
nSamples = size(x,1);
% One permutation for both sets. The previous version drew the train and
% test indices with two independent randsample calls, so test rows leaked
% into training.
perm = randperm(nSamples);
train_index = perm(1:nTrain);
test_index = perm(nTrain+1:end);
xtrain = x(train_index,:);
xtest = x(test_index,:);
ytrain = y(train_index,:);
ytest = y(test_index,:);
% Optional normalization (disabled, as before). If enabled, fit the scaling
% on the training set only and apply the same settings to the test set:
%   [xt, ps] = mapminmax(xtrain');   % mapminmax normalizes each row
%   xtrain = xt';
%   xtest = mapminmax('apply', xtest', ps)';
2.2. modelTrain
function model = modelTrain(xtrain,ytrain,varargin)
% MODELTRAIN Fit a classification tree with fitctree and show it as a graph.
%   model = modelTrain(xtrain, ytrain) fits fitctree with its default options.
%   model = modelTrain(xtrain, ytrain, Name, Value, ...) forwards any extra
%   name-value options (e.g. 'MaxNumSplits', 'SplitCriterion') to fitctree.
%
%   xtrain : predictor matrix, one sample per row
%   ytrain : class labels, one per row of xtrain
%   model  : the trained tree returned by fitctree
model = fitctree(xtrain,ytrain,varargin{:});
view(model, 'Mode', 'graph');   % opens the graphical tree viewer
2.3. modelTrainOpt
function model = modelTrainOpt(xtrain,ytrain)
% MODELTRAINOPT Fit a classification tree with automatic hyperparameter tuning.
%   model = modelTrainOpt(xtrain, ytrain) calls fitctree with
%   'OptimizeHyperparameters' set to 'auto', then displays the resulting
%   tree as a graph.
%
%   xtrain : predictor matrix, one sample per row
%   ytrain : class labels, one per row of xtrain
%   model  : the tuned tree returned by fitctree
tuneOpts = {'OptimizeHyperparameters', 'auto'};
viewOpts = {'Mode', 'graph'};
model = fitctree(xtrain, ytrain, tuneOpts{:});
view(model, viewOpts{:});
2.4. modelPredict
function [train_acc,test_acc]=modelPredict(model,xtrain,ytrain,xtest,ytest)
% MODELPREDICT Training and test accuracy of a fitted classification model.
%   [train_acc, test_acc] = modelPredict(model, xtrain, ytrain, xtest, ytest)
%   predicts labels for both sets with predict(model, ...). For each set it
%   counts the predictions that equal the true label (element by element,
%   using isequal) and divides by the number of true labels.
%
%   model          : trained model accepted by predict (e.g. from fitctree)
%   xtrain, ytrain : training predictors and true labels
%   xtest,  ytest  : test predictors and true labels
%   train_acc, test_acc : fraction of correct predictions on each set
train_pre = predict(model,xtrain);
test_pre = predict(model,xtest);
% Number of positions k in 1..length(pred) where pred(k) equals truth(k)
countHits = @(pred, truth) sum(arrayfun(@(k) isequal(pred(k), truth(k)), 1:length(pred)));
train_acc = countHits(train_pre, ytrain) / length(ytrain);
test_acc = countHits(test_pre, ytest) / length(ytest);
本博客为个人笔记,欢迎交流!!!!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。