## Random Forest and Decision Trees: Complete MATLAB and Python Implementations

Code 1: a hand-rolled MATLAB implementation, csdn1.m. The script builds one sampled dataset and random feature subset per tree, grows CART trees with Gini-based splits, and classifies test samples by majority vote across the trees.

% Driver script: train a random forest on the NSL data and report test accuracy
treeNum = 10;       % number of trees in the forest
featureNum = 10;    % number of features sampled per tree
dataNum = 1000;     % number of training rows sampled per tree
dataTrain = xlsread('NSL_train.xlsx');
dataTest = xlsread('NSL_test.xlsx');
y = RF(treeNum, featureNum, dataNum, dataTrain, dataTest);
fprintf('\n***** Random forest classification accuracy: %f *****\n', y);
function accuracy = RF(treeNum, featureNum, dataNum, dataTrain, dataTest)
    % Build per-tree datasets, train the forest, predict, and score
    [dataAll, featureGrounp] = dataSet(dataTrain, treeNum, featureNum, dataNum);
    RF = buildRandForest(dataAll, treeNum);
    RF_prection = RFprection(RF, featureGrounp, dataTest);
    accuracy = calAccuracy(dataTest, RF_prection);
end

function RF = buildRandForest(dataTrain, treeNum)
    RF = [];
    fprintf('********* Training random forest (%d trees) **********\n', treeNum);
    for a = 1:treeNum
        data = dataTrain(:, :, a);
        note = buildCartTree(data, 0);
        fprintf('++++++ Tree %d trained\n', a);
        RF = [RF, note];   % struct array: one root node per tree
        fprintf('===============================\n');
    end
    fprintf('************ Random forest training complete! *******\n');
end

function note = buildCartTree(data, k)
    % Recursively grow a CART tree; k tracks the recursion depth
    k = k + 1;
    [m, n] = size(data);

    if m == 0
        % Guard for empty input; not reached in practice because splits
        % are only accepted when both children are non-empty
        note = struct();
    else
        currentGini = calGiniIndex(data);
        bestGini = 0;
        featureNum = n - 1;   % the last column holds the label
        for a = 1:featureNum
            feature_values = unique(data(:, a));
            m1 = size(feature_values, 1);
            for b = 1:m1
                [D1, D2] = splitData(data, a, feature_values(b));
                m2 = size(D1, 1);
                m3 = size(D2, 1);

                Gini_1 = calGiniIndex(D1);
                Gini_2 = calGiniIndex(D2);
                nowGini = (m2*Gini_1 + m3*Gini_2)/m;   % weighted child impurity
                gGini = currentGini - nowGini;         % impurity reduction

                if gGini > bestGini && m2 > 0 && m3 > 0
                    bestGini = gGini;
                    bestFeature = [a, feature_values(b)];
                    rightData = D1;
                    leftData = D2;
                end
            end
        end
        if bestGini > 0
            % Internal node: recurse on both children
            right = buildCartTree(rightData, k);
            left = buildCartTree(leftData, k);
            note = struct('bestFeature', bestFeature(1,1), 'value', bestFeature(1,2), ...
                          'rightBranch', right, 'leftBranch', left, 'leaf', []);
        else
            % Leaf node: store the majority label of the remaining rows
            % (the original took data(1, n), which is arbitrary for impure nodes)
            note = struct('leaf', mode(data(:, n)));
        end
    end
end
function [dataAll, featureAll] = dataSet(dataTrain, treeNum, featureNum, dataNum)
    % Build one sampled dataset and one feature subset per tree
    dataAll = zeros(dataNum, featureNum + 1, treeNum);
    featureAll = zeros(featureNum, 1, treeNum);
    for a = 1:treeNum
        [data, feature] = chooseSample(dataTrain, featureNum, dataNum);
        dataAll(:, :, a) = data;
        featureAll(:, :, a) = feature';
    end
end
function RF_prection_ = RFprection(RF, featureGrounp, dataTest)
    % Predict with each tree on its own feature subset, then majority-vote
    n = size(RF, 2);
    m2 = size(dataTest, 1);
    RF_prection = [];

    for a = 1:n
        RF_single = RF(:, a);
        feature = featureGrounp(:, :, a);
        data = splitData2(dataTest, feature);   % keep only this tree's features
        RF_prection_single = [];
        for b = 1:m2
            A = prection(RF_single, data(b, :));
            RF_prection_single = [RF_prection_single; A];
        end
        RF_prection = [RF_prection, RF_prection_single];
    end
    RF_prection_ = mode(RF_prection, 2);        % majority vote across trees
end
function [Data1, Data2] = splitData(data, fea, value)
    % Split rows by threshold on feature fea: >= value goes right, < value left
    m = size(data, 1);
    if m == 0
        Data1 = [];   % the original left both outputs unassigned here
        Data2 = [];
    else
        Data1 = data(data(:, fea) >= value, :);
        Data2 = data(data(:, fea) < value, :);
    end
end
function data = splitData2(dataTrain, feature)
    % Select the columns this tree was trained on
    data = dataTrain(:, feature);
end
function [data, feature] = chooseSample(data1, featureNum, dataNum)
    % Sample featureNum feature columns and dataNum rows, both WITHOUT
    % replacement (classic bagging would sample rows with replacement)
    [m, n] = size(data1);
    B = randperm(n - 1);
    feature = B(1, 1:featureNum);
    A = randperm(m);
    C = A(1, 1:dataNum);
    data = data1(C, feature);
    data = [data, data1(C, n)];   % append the label column
end
function Gini = calGiniIndex(data)
    % Gini impurity: Gini(D) = 1 - sum_k (n_k/m)^2 over the label counts n_k
    m = size(data, 1);
    if m == 0
        Gini = 0;
    else
        labelsNum = labels_num2(data);
        [m1, n1] = size(labelsNum);
        Gini = 0;
        for a = 1:m1
            Gini = Gini + labelsNum(a, n1)^2;
        end
        Gini = 1 - Gini/(m^2);
    end
end

% Count how many rows carry each distinct label
function labelsNum = labels_num2(data)
    [m, n] = size(data);
    if m == 0
        labelsNum = 0;
    else
        labels = data(:, n);
        A = unique(labels, 'sorted');
        m1 = size(A, 1);
        B = zeros(m1, 2);
        B(:, 1) = A(:, 1);              % column 1: label value
        for a = 1:m1
            B(a, 2) = sum(labels == A(a, 1));   % column 2: count
        end
        labelsNum = B;
    end
end

function A = prection(RF_single, sample)
    % Walk the tree until a leaf is reached
    if ~isempty(RF_single.leaf)
        A = RF_single.leaf;
    else
        B = sample(1, RF_single.bestFeature);
        if B >= RF_single.value
            branch = RF_single.rightBranch;
        else
            branch = RF_single.leftBranch;
        end
        A = prection(branch, sample);
    end
end
function accuracy = calAccuracy(dataTest, RF_prection)
    % Fraction of test rows whose predicted label matches the true label
    [m, n] = size(dataTest);
    A = dataTest(:, n);
    accuracy = sum(A == RF_prection)/m;
end
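
calGiniIndex above implements the standard Gini impurity, Gini(D) = 1 - Σ_k (n_k/m)^2, where n_k is the number of rows with label k and m is the total row count. As a quick sanity check, the following snippet runs it on a made-up four-row matrix (one feature column, label column last); the toy data is purely illustrative, not from the original post:

% Two labels with 2 rows each, so the expected impurity is
% 1 - (0.5^2 + 0.5^2) = 0.5
toy = [1 0; 2 0; 3 1; 4 1];
g = calGiniIndex(toy);
fprintf('Gini of toy data: %.2f (expected 0.50)\n', g);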

Code 2: TreeBagger (MATLAB's built-in bagged-tree ensemble)

% Train the model
Factor = TreeBagger(treeNumber, train, trainLabel, 'Method', 'classification', ...
    'NumPredictorsToSample', featureNum, 'OOBPredictorImportance', 'on');
% Performance evaluation (e.g. via k-fold cross-validation)
[Predict_label, Scores] = predict(Factor, test);   % predicted labels for the test set
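
The two lines above assume that treeNumber, featureNum, train, trainLabel, and test already exist in the workspace. Below is a minimal self-contained sketch of the same TreeBagger workflow; the file name (reused from Code 1), the variable names, and the 80/20 holdout split are illustrative assumptions, not part of the original code. Note that for classification, predict returns labels as a cell array of character vectors, so numeric labels must be converted back before scoring:

% Minimal TreeBagger sketch (assumed variable names and split ratio)
X = xlsread('NSL_train.xlsx');   % assumption: same file as Code 1
labels = X(:, end);              % last column holds the class label
X = X(:, 1:end-1);

cut = floor(0.8 * size(X, 1));   % simple 80/20 holdout split
Xtrain = X(1:cut, :);      Ytrain = labels(1:cut);
Xtest  = X(cut+1:end, :);  Ytest  = labels(cut+1:end);

Mdl = TreeBagger(100, Xtrain, Ytrain, 'Method', 'classification', ...
    'OOBPrediction', 'on', 'OOBPredictorImportance', 'on');

plot(oobError(Mdl));             % out-of-bag error vs. number of grown trees
xlabel('Number of grown trees'); ylabel('OOB classification error');

[predLabels, ~] = predict(Mdl, Xtest);   % cell array of char labels
accuracy = mean(str2double(predLabels) == Ytest);
fprintf('TreeBagger accuracy: %.4f\n', accuracy);

With 'OOBPredictorImportance' enabled, recent MATLAB releases also expose a permutation-based importance score per feature in Mdl.OOBPermutedPredictorDeltaError.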

Code 3: scikit-learn (Python)

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn import metrics  # classification metrics (accuracy, ROC AUC, ...)

# Dataset loading, feature selection, and cross-validation are omitted here;
# in the original these lines sat inside that (omitted) cross-validation loop.
forest_1 = RandomForestClassifier(n_estimators=2000, random_state=10,
                                  n_jobs=-1, oob_score=True)
forest_1.fit(x_train, y_train)
expected = y_test                                # expected outputs for the test samples
predicted = forest_1.predict(x_selected_test_3)  # test set from the omitted feature-selection step
# Report results
print(metrics.classification_report(expected, predicted))  # precision, recall, F1
print(metrics.confusion_matrix(expected, predicted))       # confusion matrix
auc = metrics.roc_auc_score(y_test, predicted)              # valid for binary labels
accuracy = metrics.accuracy_score(y_test, predicted)
print("RF_Accuracy: %.2f%%" % (accuracy * 100.0))