当前位置:   article > 正文

随机森林原理及其用于分类问题的matlab实现_matlab trainindices

matlab trainindices

随机森林

随机森林是多个决策树的集成学习,每个决策树用bagging的方法选数据集,并且在选择最佳属性划分的时候随机划分一些属性进行分类,比单个分类器效果更好,泛化能力更强。

代码解释

1.用结构体的嵌套实现树的结构。
2.makerandomtree递归的创建树。
3.可自动适应不同的类别标签,不同的属性个数和不同的类别个数。
4.函数ent(D)返回D的信息熵

代 码

树的主体:

function tree=makerandomtree(D,a) 
tree=struct('isnode',1,'a',0.0,'mark',0.0,'child',{});%isnode判断是否是分支还是叶子,a表示节点属性,若节点是叶子,a表示分类结果,child是孩子
tree(1).a=1;%给tree分配一个确切的内存
if length(unique(D(:,end)))==1%D中样本属于同一类别
    tree.isnode=0;%把tree标记为树叶
    tree.a=D(1,end);%把tree的类别标记为D的类别
    return
end
if sum(a)==0 ||length(D)==0 %属性划分完毕
    tree.isnode=0;%把tree标记为树叶
    tree.a=mode(D(:,end));%把tree的类别标记为D出现最多的类别
    return
end
for i=1:length(a)
    if a(i)==1
        if length(unique(D(:,i)))==1
            tree.isnode=0;%把tree标记为树叶
            tree.a=mode(D(:,end));%把tree的类别标记为D出现最多的类别
            return
        end
    end
end
k=ceil(log2(sum(a)));%随机选k个属性进行学习
randomindices=zeros(length(a),1); %随机去掉属性的索引,结束之后要恢复随机去掉的属性
if k>1
i=1;
su=sum(a);
while 1 %随机去掉一些属性,使得剩下的属性是k个
    random1=randperm(length(a),1);
    if a(random1)==1
        randomindices(random1)=1;
        a(random1)=0;
        i=i+1;
    end
    if i==(su-k+1)
        break;
    end
end
end

gain=zeros(length(a),1); %保存每个属性的信息增益
best=zeros(length(a),1); %保存每个属性的最佳划分

for i=1:length(a)
    if a(i)==1
        t=D(:,i);
        t=sort(t);
    
        gain1=zeros(length(t)-1,1);
        for j=1:length(t)-1%二分划分
            ta=(t(j)+t(j+1))/2;
         
            Df=D(D(:,i)<=ta,:);
            Dz=D(D(:,i)>ta,:);
            gain1(j)=ent(D)-(ent(Df)*length(Df(:,end))/length(D(:,end))+ent(Dz)*length(Dz(:,end))/length(D(:,end)));
        end
     
        [gain(i),j]=max(gain1);
        ta=(t(j)+t(j+1))/2;
        best(i)=ta; 
    end
end
[~,m]=max(gain);%选择信息增益最大的属性
D1=D(D(:,m)<=best(m),:);
D2=D(D(:,m)>best(m),:);
a(m)=0;
for i=1:length(a)  %恢复随机去掉的属性
    if randomindices(i)==1
        a(i)=1;
    end
end
tree.a=best(m); %建立分支
tree.mark=m;
% disp('****************************')
% tree.a
% tree.mark
tree.isnode=1;
tree.child(1)=makerandomtree(D1,a);
tree.child(2)=makerandomtree(D2,a);

end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81

计算ent

function f=ent(D)%计算信息商
l=unique(D(:,end));
if length(D)==0
    f=0;
    return
end
f=0;
t=zeros(length(l),1);
for i=1:length(D(:,end))
    for j=1:length(l)
        if D(i,end)==l(j)
            t(j)=t(j)+1;
            break;
        end
    end
end
n=length(D(:,end));
for i=1:length(l)
    f=f+(t(i)/n)*log2(t(i)/n);
end
f=-f;
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

主函数

function randomforest()
clc
clear all
T=3;%bagging采样的次数
M = importdata('D:\毕业设计\数据集1\australian.txt');  %读取数据
[sm,sn]=size(M);
% for i=1:sm             %归一化
%     mins=min(M(i,1:sn-1));
%     maxs=max(M(i,1:sn-1));
%     for j=1:sn-1
%         M(i,j)=2*(M(i,j)-mins)/(maxs-mins)-1;
%     end
% end
indices=crossvalind('Kfold',M(1:sm,sn),10); %十折交叉,划分训练集和测试集
testindices=(indices==1); %测试集索引
trainindices=~testindices;%训练集索引
trainset=M(trainindices,:); %获取训练集
testset=M(testindices,:);%获取测试集
[testm,~]=size(testset);
[trainm,trainn]=size(trainset);

predict=zeros(trainm,T);
for t=1:T %开始bagging采样
    D=[];%训练集
    for i=1:trainm%采样
        k=randperm(trainm,1);
        D=[D;trainset(k,:)];
    end
    [~,sn]=size(D);
    a=ones(sn-1,1);%属性集合a,1代表该属性未被划分
    
    tree=makerandomtree(D,a);%递归构造简单决策树
    
    for i=1:trainm
        treet=tree;
        while 1
           
            if treet.isnode==0
                predict(i,t)=treet.a;
                break;
            end
            if trainset(i,treet.mark)<=treet.a
                treet=treet.child(1);
            else
                treet=treet.child(2);
            end
          
        end
    end
    
end
acc=0;
for i=1:trainm
    if trainset(i,end)==mode(predict(i,:))
        acc=acc+1;
    end
end
acc=acc/trainm
end

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号