当前位置:   article > 正文

基于随机森林的分类算法的matlab简单实现_matlab随机森林分类

matlab随机森林分类

说明

关于熵、信息增益、信息增益比、基尼指数的计算不再写出

决策树构建——使用最简单的ID3算法

1.输入:训练数据集D,特征集A,阈值(后面会说明数据集的内容)
2.输出:决策树T
(1)若D中所有实例属于同一类Ck,则T为单结点树,并将Ck作为该结点的类标记,返回T;
(2)若A是空集,则T为单结点树,并将D中实例数最大的类Ck作为该结点的类标记,返回T;
(3)否则,计算A中各特征对D的信息增益,选择信息增益最大的特征Ag;
(4)如果Ag的信息增益小于阈值,则置T为单结点树,并将D中实例数最大的类Ck作为该结点的类标记,返回T;
(5)否则,对Ag的每一可能值ai,依Ag=ai将D分割为若干非空子集Di,将Di中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树T,返回T;
(6)对第i个子结点,以Di为训练集,以A-{Ag}为特征集,递归的调用步(1)到步(5),得到子树Ti,返回Ti。
注:如果采用C4.5算法构建决策树,则只需要把ID3算法中的信息增益更换为信息增益比。

数据集说明

训练集1500个数据,测试集为500个数据。
表示的是银行对是否同意贷款的分类预测。
如图:
部分数据集
第一列:1代表青年;2代表中年;3代表老年。
第二列:2代表收入最高;1代表收入一般;0代表收入低。
第三列:1代表有房;0代表没房
第四列:1表示信用很好;2表示信用好;3表示信用一般。
第五列:1代表男;0代表女。
第六列:1表示城市;0表示农村。
第七列:1表示同意贷款;2表示还需考虑;3表示不同意。

随机森林的构建

1.随机森林是集成学习的一种,它的基础单元是决策树。
2.随机森林有两个随机,(假设训练集数目为N)一是训练决策树时随机有放回的抽取n(n<N)个数据,二是随机抽取指定数目m的特征(m<M);
3.在本例中,训练集1500个数据,每次有放回抽取900个数据,有6个特征,每次随机不重复的抽取5个特征,训练一棵决策树。
4.我们可以选择构建多棵决策树,每一棵决策树可对每一个数据进行分类,得出一个结果;我们对所有决策树的对同一数据的结果作出一个投票处理,即少数服从多数,投票的结果就是随机森林对这一数据的分类结果。
5.如果不确定如何选择n,m以及决策树的棵树,可以采用其他算法针对随机森林分类的准确率进一步优化。

matlab代码及结果

1.主程序RandomForest.m

clear all;
clc;
rnode=cell(3,1);%3*1的单元数组
rchild_value=cell(3,1);%3*1的单元数组
rchild_node_num=cell(3,1);%3*1的单元数组
sn=900; %随机可重复的抽取sn个样本
tn=30;  %森林中决策树的棵树
S=xlsread('loan_train.xls');
%% 样本训练采用随机森林和ID3算法构建决策森林
    for j=1:tn         %训练十棵决策树
        Feature=randperm(6,5);%随机选取数个特征
        Sample_num=randi([1,1500],1,sn);%从1至1000内随机抽取sn个样本,1*sn矩阵
        SData=S(Sample_num,:);      %一棵树的训练集
        [node,child_value,child_node_num]=ID3(SData,Feature);
        rnode{j,1}=node;
        rchild_value{j,1}=child_value;
        rchild_node_num{j,1}=child_node_num;
    end
    
%% 样本测试
    T=xlsread('loan_test.xls');
    %TData=roundn(T,-1);
    TData=roundn(T(:,1:end-1),-1);
    len=length(TData(:,1));%测试样本的数目
    type=zeros(len,1);
    for j=1:len
        %统计函数,对输入的测试向量进行投票,然后统计出选票最高的标签类型输出
        [type(j)]=statistics(tn,rnode,rchild_value,rchild_node_num,TData(j,:));
    end
    xlswrite('loan_result.xls',[T type]);%输出测试报告
    gd = T(:,end);
    count = sum(type==gd);
    fprintf('共有%d个样本,判断正确的有%d\n准确率为:百分之%s\n',len,count,count/len*100);
  
    

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

2.决策树构建程序——ID3.m

% 函数返回一棵决策树
function  [node,child_value,child_node_num]=ID3(S,Feature)%%%
    clear clear global node child_value child_node_num;
    global node child_value child_node_num
    %S=xlsread('aaa.xls');%%%
    DValue=S(:,1:6);    
    DValue=roundn(DValue,-1);%四舍五入保留一位小数
    CN=S(:,7);
    CN=num2str(CN);%将标签设为string(字符串)型
    for i=1:length(CN)
        A(i)=i;
    end
    [Feature,~]=sort(Feature);
    ClassPNum=Feature;
    CLASSPNUM=[1 2 3 4 5 6];
    [CHA,~] = setdiff(CLASSPNUM,ClassPNum) ;
    DValue(:,CHA)=0;%把没用到的特征置0
    m=0;
    [node,child_value,child_node_num]=TreeNode( DValue, CN, A, ClassPNum,m ); 

end

% 生成树结点
% DValue--前6列数据
% A--参与划分的行号
% CN--属性值的集合(第7列数据)
% ClassPNum为划分的剩余属性编号
% 当前node的父亲结点为node{m}
function [node,child_value,child_node_num]=TreeNode( DValue, CN, A, ClassPNum,m)
    global node child_value child_node_num
    n=length(node);
    if m>0
        %如果父亲结点存在,将本结点的序号存入父亲结点的子结点序号集中
        k=length(child_node_num{m});
        child_node_num{m}(k+1)=n+1;  
    end     
    % 1、样本为空,则树为空
    if isempty(DValue)
        node{ n+1 }=[];
        child_value{ n+1 }=[];
        child_node_num{ n+1 }=[];
        return;
    end 
    % 2、用于划分的剩余属性为空,选择多数元组所在的类作为结点
    if isempty( ClassPNum ) 
       node{ n+1 }=find_most( CN,A );
       child_value{ n+1 }=[];
       child_node_num{ n+1 }=[];
       return;
    end 
    % 3、样本中所有数据都属于同一类,将此类作为结点
    CNRowNum=CN_sta( CN, A);
    if length( find(CNRowNum==0) )>=2 %表示两类为空,则都属于一类
        node{ n+1 }=CN(A(1));
        child_value{ n+1 }=[];
        child_node_num{ n+1 }=[];
        return;
    % 4、样本中所有数据不属于同一类
    else
        I=Exp( CN,A );
        for i=1:length( ClassPNum )   %计算针对所有特征的信息增益         
            Entropy(i)=avg_entropy( DValue(:,ClassPNum(i)), A, CN);
            Gain(i)=I-Entropy(i);
        end
        % 4.1、各属性的信息增益均小于0,选择多数元组所在的类作为结点
        if max(Gain)<=0
            node{ n+1 }=find_most( CN,A );
            child_value{ n+1 }=[];
            child_node_num{ n+1 }=[];
        return;
        % 4.2、在信息增益最大的属性上进行划分
        else
            maxG=find( Gain==max(Gain) );
            [PValue RowNum]=type_sta( DValue(:,ClassPNum(maxG(1))), A );
            node{ n+1 }=ClassPNum(maxG(1));
            child_value{ n+1 }=PValue;
            child_node_num{ n+1 }=[];
            ClassPNum(maxG)=[];     % 删除ClassPNum(maxG)--已经进行划分的属性
            for i=1:length(PValue)
                [node,child_value,child_node_num]=TreeNode( DValue, CN, RowNum{i}, ClassPNum,n+1 );
            end
            return;
        end
    end
end

% A--参与划分的行号
% DValue--数据集的前四列
% 本函数用于统计参与划分的行大多数属于哪一个类
function most_type=find_most( CN,A )
    TypeName={'1','2','3'};
    CNRowNum=CN_sta( CN, A); %1 2 3总数存在里面
    n=max(CNRowNum);%求最大数量
    maxn=find( CNRowNum==n );%maxn就是最多的类别对应的数
    most_type=TypeName{maxn};%返回最多类别的字符串
end

% 计算属性P的熵
% A--参与计算的行号,即计算的行范围
% Attri--求属性Attri的熵
% CN--类别属性值
function entropy=avg_entropy( Attri, A, CN )
    k=0;entropy=0;
    n=length(A);
    I=Exp( CN,A );
    [PValue,RowNum]=type_sta( Attri, A );
    for i=1:length( PValue )
        CI=Exp( CN, RowNum{i});
        entropy=entropy-length( RowNum{i} )/n*CI;
    end
end

% 计算样本分类的期望
% A--参与计算的行号
% Attri--求期望的属性值的集合
function I=Exp(CN,A)
    CNRowNum=CN_sta( CN, A );
    n=length(A);
    I=0;
    for i=1:3
        if CNRowNum(i)>0
            P(i)=CNRowNum(i)/n;
            I=I-P(i)*log2( P(i) );
        end
    end
end

% 统计属性的取值及各取值对应的行号集合
% A为参与统计的记录的行号集合
% Attri为属性值的集合
function [PValue,RowNum]=type_sta( Attri, A)
    k=1;
    PValue=Attri( A(1) );
    RowNum{1}=A(1);
    for i=2:length(A)
        n1=find( PValue==Attri( A(i) ) );
        if isempty(n1)
            k=k+1;
            PValue(k)=Attri( A(i) );
            RowNum{k}=A(i);
        else
            n2=length( RowNum{n1} );
            RowNum{n1}(n2+1)=A(i);
        end
    end            
end

% 统计类别属性的取值及各取值对应的行号集合
% A为参与统计的记录的行号集合
% CN为类别属性值的集合
function CNRowNum=CN_sta( CN, A)
    CNRowNum=[0 0 0];
    TypeName={'1','2'};
    for i=1:length( A )
        if strcmp( CN(A(i)),TypeName{1})
            CNRowNum(1)=CNRowNum(1)+1;
        elseif strcmp( CN(A(i)),TypeName{2} )
            CNRowNum(2)=CNRowNum(2)+1;
        else CNRowNum(3)=CNRowNum(3)+1;
        end
    end            
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162

3.统计及投票代码——statistics.m

function [type] = statistics(tn,rnode,rchild_value,rchild_node_num,PValue)
    TypeName={'1','2','3'};
    TypeNum=[0 0 0]; 
    for i=1:tn  %对测试向量进行投票,共有tn棵树
        [type]=vote(rnode,rchild_value,rchild_node_num,PValue,i);
        if strcmp( type,TypeName{1})
            TypeNum(1) = TypeNum(1) + 1;
        elseif strcmp( type,TypeName{2})
            TypeNum(2) = TypeNum(2) + 1;
        else TypeNum(3) = TypeNum(3) + 1;
        end
    end
    maxn=find( TypeNum==max(TypeNum) );
    type=str2num(TypeName{maxn(1)});
end
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
function [type] = vote(rnode,rchild_value,rchild_node_num,PValue,j)
    n=1;       %从树的根结点(即node{1})开始查找
    k=0;   
    while ~isempty(rchild_node_num{j,1}{n})%不为空则进入循环
         for i=1:length(rchild_value{j,1}{n})
                if PValue(rnode{j,1}{n})==rchild_value{j,1}{n}(i)
                    n=rchild_node_num{j,1}{n}(i);
                    k=0;
                    break;
                end                    
         end
        
        if i==length(rchild_value{j,1}{n})
            % 若这个值在分类器中不存在,则取其最近的值进行分类
           PValue(rnode{j,1}{n})=PValue(rnode{j,1}{n})+0.1*k;
           PValue=roundn(PValue,-1);
        end     
        k=(-1)^k*( abs(k)+1 );     
    end
    type=rnode{j,1}{n};                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  type=rnode{j,1}{n};
    end

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

分类结果

1.决策树的结构可以从rnode、rchild_node_num、rchild_value三个数组中对照着读出来。
2.分类结果为
在这里插入图片描述
3.测试集的分类结果将保存到loan_result.xls文件中。

代码缺点

1.只适用于属性值是标签值也就是离散值的类型,如果训练集中含有连续属性样本,那么训练出的决策树将存在较严重的过拟合现象。
2.未设置阈值,导致训练出的决策树结点数过多(可认为阈值为0,可以适当调大)
3.ID3算法在部分情形下不太合适。

文件

matlab文件及数据集地址

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/94505
推荐阅读
相关标签
  

闽ICP备14008679号