
Machine Learning 9: The Random Forest Algorithm

Random Forest: min.node.size

Definition

A random forest is a classifier that trains many trees on the samples and combines them to make predictions. It was first proposed by Leo Breiman and Adele Cutler, who registered "Random Forests" as a trademark.

In machine learning, a random forest is a classifier made up of many decision trees, and its output class is the mode (majority vote) of the classes output by the individual trees. Leo Breiman and Adele Cutler developed the algorithm. Random forests are flexible and easy to use: even without hyperparameter tuning they give good results in most cases, they are simple, and they work for both classification and regression, which makes them one of the most commonly used algorithms.

A Motivating Example

Imagine a man named Andrew who wants to decide which places to visit on a year of vacation travel, so he asks friends who know him for advice.

First he goes to one friend, who asks Andrew which places he has already visited and whether or not he liked them. Based on these answers, the friend can give Andrew some recommendations; this is a typical decision-tree algorithm.

From Andrew's answers, the friend derives rules to guide what to recommend. Andrew then seeks out more and more friends, each of whom asks him different questions and offers suggestions based on them. Finally, Andrew picks the places recommended to him most often, which is exactly a random forest algorithm.

How It Works

A random forest builds many decision trees and merges them together to obtain a more accurate and stable prediction. One big advantage of random forests is that they can be used for both classification and regression problems, which together make up the majority of today's machine learning tasks.

A random forest classifier uses all the hyperparameters of a decision-tree classifier, plus those of a bagging classifier to control the ensemble as a whole. Rather than first building a bagging classifier and passing it a decision-tree classifier, it is more convenient and better optimized to use the random forest classifier class directly. (Note that regression problems likewise have a corresponding random forest regressor.) The way trees grow in a random forest adds extra randomness to the model. In an ordinary decision tree, each node is split on the single best feature, the one that minimizes the error; in a random forest, the split is instead chosen as the best feature within a randomly selected subset of features. The trees can be made even more random by also using a random threshold for each feature, rather than searching for the best threshold the way a normal decision tree does. This process produces wide diversity among the trees, which generally yields a better model.
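
The two levels of randomness described above — random feature subsets, and optionally random split thresholds — correspond in scikit-learn to RandomForestClassifier and ExtraTreesClassifier respectively. A minimal comparison sketch, assuming scikit-learn and its bundled iris data (both illustrative assumptions, not from the original text):

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# Random forest: the best threshold is searched within a random feature subset per node
rf = RandomForestClassifier(n_estimators=100, max_features="sqrt", random_state=0)

# Extra-trees: random thresholds on top of random feature subsets, i.e. even more randomness
et = ExtraTreesClassifier(n_estimators=100, max_features="sqrt", random_state=0)

for name, model in [("random forest", rf), ("extra trees", et)]:
    scores = cross_val_score(model, X, y, cv=5)
    print(name, "mean accuracy:", round(scores.mean(), 3))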

Hyperparameters

For improving the model's predictive accuracy:

1. The "n_estimators" hyperparameter is the number of trees the algorithm builds before taking the majority vote or averaging the predictions. In general, more trees mean better performance and more stable predictions, but they also slow the computation down.

2. "max_features" is the maximum number of features the random forest considers when splitting a node in an individual tree.

3. "min_samples_leaf" sets the minimum number of samples required at a leaf node.

 

For speeding up the computation:

1. The "n_jobs" hyperparameter tells the engine how many processors it may use. A value of 1 means only one processor can be used; a value of -1 means there is no limit.

2. "random_state" is the random seed, which makes the model's output reproducible: with a fixed value, the same hyperparameters, and the same training data, the model will always produce the same results.

3. "oob_score" (out-of-bag sampling) enables a random-forest cross-validation method. In this sampling scheme, roughly one third of the data is not used to train a given tree and is used instead to evaluate its performance; these samples are called out-of-bag samples. It is very similar to the leave-one-out cross-validation method but adds almost no extra computational burden.
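
Putting these hyperparameters together, here is a minimal sketch with scikit-learn (the library, the synthetic data set, and the values below are illustrative assumptions, not part of the original walkthrough):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Illustrative synthetic data: 1000 samples, 20 features
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

clf = RandomForestClassifier(
    n_estimators=200,     # more trees: better, stabler predictions, but slower
    max_features="sqrt",  # number of features considered at each split
    min_samples_leaf=5,   # minimum samples required at a leaf (cf. min.node.size)
    n_jobs=-1,            # use all available processors
    random_state=42,      # fixed seed, reproducible results
    oob_score=True,       # evaluate on the out-of-bag samples
)
clf.fit(X, y)
print("OOB score:", clf.oob_score_)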

Advantages, Disadvantages, and Applications

Advantages:

It can be used for both regression and classification tasks, and it is easy to inspect the relative importance the model assigns to its input features.

The default hyperparameters usually already produce a good prediction.

Random forests are also effective against overfitting: as long as the forest contains enough trees, the classifier will not overfit the model.

Disadvantages:

A large number of trees makes the algorithm slow and rules out real-time prediction. In general, these algorithms are fast to train but slow to predict with; higher accuracy requires more trees, and more trees mean a slower model.

Applications:

The random forest algorithm is used in many fields, such as banking, the stock market, medicine, and e-commerce. In banking it is used to identify customers who use the bank's services more frequently than average and repay their debts on time, as well as to detect customers who intend to defraud the bank. In finance it can be used to predict future stock trends. In healthcare it is used to identify the correct combination of components in a medicine and to analyze a patient's medical history to identify diseases. Beyond that, in e-commerce a random forest can be used to determine whether a customer will actually like a product.

Implementation

A random forest is an ensemble of decision trees, with two differences:

(1) Sampling differs: from a data set containing m samples, we draw m samples with replacement to form each tree's training set. This guarantees that the decision trees are not all trained on exactly the same samples.

(2) Feature selection differs: each decision tree's n candidate split features are selected at random from the full feature set (n is a parameter we tune ourselves).

The parameters a random forest needs tuned are:

(1) the number of decision trees

(2) the number of candidate features

(3) the recursion depth (i.e., the depth of the decision trees)

Python Code

Reference: https://blog.csdn.net/flying_sfeng/article/details/64133822

(1) Load the file and convert all features to float

import csv

# Load the data
def loadCSV(filename):
    dataSet = []
    with open(filename, 'r') as file:
        csvReader = csv.reader(file)
        for line in csvReader:
            dataSet.append(line)
    return dataSet

# Convert every column except the label column to float
def column_to_float(dataSet):
    featLen = len(dataSet[0]) - 1
    for data in dataSet:
        for column in range(featLen):
            data[column] = float(data[column].strip())

(2) Split the data set into n folds for cross-validation

from random import randrange

# Split the data set into N blocks for cross-validation
def spiltDataSet(dataSet, n_folds):
    fold_size = int(len(dataSet) / n_folds)
    dataSet_copy = list(dataSet)
    dataSet_spilt = []
    for i in range(n_folds):
        fold = []
        while len(fold) < fold_size:  # 'while' keeps drawing until the fold is full; an 'if' would only test once
            index = randrange(len(dataSet_copy))
            fold.append(dataSet_copy.pop(index))  # pop() removes the element at index and returns it
        dataSet_spilt.append(fold)
    return dataSet_spilt

(3) Build a random data subset (random sampling) and select the best feature among a specified number of features (say m of them, tuned by hand)

# Build a random data subset (sampling with replacement)
def get_subsample(dataSet, ratio):
    subdataSet = []
    lenSubdata = round(len(dataSet) * ratio)
    while len(subdataSet) < lenSubdata:
        index = randrange(len(dataSet))  # randrange(len(dataSet)-1) would never pick the last row
        subdataSet.append(dataSet[index])
    return subdataSet

# Pick any n features and, among those n, select the best one to split on
def get_best_spilt(dataSet, n_features):
    features = []
    class_values = list(set(row[-1] for row in dataSet))
    b_index, b_value, b_loss, b_left, b_right = 999, 999, 999, None, None
    while len(features) < n_features:
        index = randrange(len(dataSet[0]) - 1)  # exclude the label column
        if index not in features:
            features.append(index)
    for index in features:
        for row in dataSet:
            left, right = data_spilt(dataSet, index, row[index])
            loss = spilt_loss(left, right, class_values)
            if loss < b_loss:
                b_index, b_value, b_loss, b_left, b_right = index, row[index], loss, left, right
    return {'index': b_index, 'value': b_value, 'left': b_left, 'right': b_right}

(4) Build the decision tree

# Build one decision tree
def build_tree(dataSet, n_features, max_depth, min_size):
    root = get_best_spilt(dataSet, n_features)
    sub_spilt(root, n_features, max_depth, min_size, 1)
    return root

(5) Create the random forest (a combination of multiple decision trees)

# Create the random forest
def random_forest(train, test, ratio, n_features, max_depth, min_size, n_trees):
    trees = []
    for i in range(n_trees):
        subTrain = get_subsample(train, ratio)
        tree = build_tree(subTrain, n_features, max_depth, min_size)
        trees.append(tree)
    predict_values = [bagging_predict(trees, row) for row in test]
    return predict_values

(6) Feed in the test set and output the predictions

# Predict the class of a single row with one tree
def predict(tree, row):
    if row[tree['index']] < tree['value']:
        if isinstance(tree['left'], dict):
            return predict(tree['left'], row)
        else:
            return tree['left']
    else:
        if isinstance(tree['right'], dict):
            return predict(tree['right'], row)
        else:
            return tree['right']
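
The listing above calls four helpers that the referenced post defines elsewhere: data_spilt, spilt_loss, sub_spilt, and bagging_predict. To keep the walkthrough self-contained, here is a minimal reconstruction of what they might look like, assuming a Gini-style split loss and majority-vote aggregation (these bodies are illustrative sketches, not the original author's code):

# Hypothetical helpers, reconstructed so the listing runs end to end
def data_spilt(dataSet, index, value):
    # Partition rows by comparing one feature against a threshold
    left, right = [], []
    for row in dataSet:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

def spilt_loss(left, right, class_values):
    # Gini impurity of a candidate split, weighted by branch size
    total = float(len(left) + len(right))
    loss = 0.0
    for branch in (left, right):
        if not branch:
            continue
        score = 0.0
        for class_value in class_values:
            p = [row[-1] for row in branch].count(class_value) / float(len(branch))
            score += p * p
        loss += (1.0 - score) * len(branch) / total
    return loss

def to_terminal(branch):
    # Majority class of a leaf
    outcomes = [row[-1] for row in branch]
    return max(set(outcomes), key=outcomes.count)

def sub_spilt(node, n_features, max_depth, min_size, depth):
    # Recursively grow the tree below `node`
    left, right = node['left'], node['right']
    if not left or not right:
        node['left'] = node['right'] = to_terminal((left or []) + (right or []))
        return
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    for side, branch in (('left', left), ('right', right)):
        if len(branch) <= min_size:
            node[side] = to_terminal(branch)
        else:
            node[side] = get_best_spilt(branch, n_features)
            sub_spilt(node[side], n_features, max_depth, min_size, depth + 1)

def bagging_predict(trees, row):
    # Majority vote over the predictions of all trees in the forest
    predictions = [predict(tree, row) for tree in trees]
    return max(set(predictions), key=predictions.count)

With these in place, a run might look like this (the file name is illustrative):

dataSet = loadCSV('sonar.csv')
column_to_float(dataSet)
predictions = random_forest(dataSet, dataSet[:50], ratio=0.8, n_features=5,
                            max_depth=10, min_size=1, n_trees=20)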

Java Implementation

Reference: https://blog.csdn.net/xiuxian4728/article/details/78897134

The implementation has two parts: building a single decision tree, and the random forest algorithm itself.

Main Features

1. The program is tested on the public NSL-KDD data set [ NSL-KDD ].

2. Attributes are continuous (Continuous Attributes); each decision tree node splits the data into two branches.

3. The end of the code adds a Variable Importance computation; for the theory, see [ Variable Importance ]. The sketch following this list summarizes the idea in Python.
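
The permutation-importance computation in DTree.java below permutes one attribute column of the out-of-bag data and measures how much accuracy drops. The same idea in a few lines of Python, as a minimal sketch assuming any fitted model with a predict method and held-out arrays X_val, y_val (all names here are illustrative):

import numpy as np

def permutation_importance(model, X_val, y_val, seed=0):
    # Importance of feature m = drop in accuracy after randomly permuting column m
    rng = np.random.default_rng(seed)
    baseline = np.mean(model.predict(X_val) == y_val)
    importances = []
    for m in range(X_val.shape[1]):
        X_perm = X_val.copy()
        rng.shuffle(X_perm[:, m])                # destroy only column m's information
        permuted = np.mean(model.predict(X_perm) == y_val)
        importances.append(baseline - permuted)  # larger drop => more important feature
    return importances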

 

Code

DTree.java

package com.rf.real;

import java.util.*;

/**
 * Creates a decision tree based on the specifications of random forest trees
 */
public class DTree {
    /** Instead of checking each index we'll skip every INDEX_SKIP indices unless there's less than MIN_SIZE_TO_CHECK_EACH */
    private static final int INDEX_SKIP = 2;
    /** If there's less than MIN_SIZE_TO_CHECK_EACH points, we'll check each one */
    private static final int MIN_SIZE_TO_CHECK_EACH = 10;
    /** If the number of data points is less than MIN_NODE_SIZE, we won't continue splitting; we'll take the majority vote */
    private static final int MIN_NODE_SIZE = 5;
    /** the number of data records */
    private int N;
    /** the number of samples left out of the bootstrap of all N to test error rate
     * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#ooberr">OOB error estimate</a>
     */
    private int testN;
    /** Of the testN, the number that were correctly identified */
    private int correct;
    /** an estimate of the importance of each attribute in the data record
     * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">Variable Importance</a>
     */
    private int[] importances;
    /** This is the root of the Decision Tree */
    private TreeNode root;
    /** This is a pointer to the Random Forest this decision tree belongs to */
    private RandomForest forest;
    /** This keeps track of all the predictions done by this tree */
    public ArrayList<Integer> predictions;

    /**
     * This constructs a decision tree from a data matrix.
     * It first creates a bootstrap sample, the train data matrix, as well as the left out records,
     * the test data matrix. Then it creates the tree, then calculates the variable importances (not essential)
     * and then removes the links to the actual data (to save memory).
     *
     * @param data The data matrix as a List of int arrays - each array is one record,
     * each index in the array is one attribute, and the last index is the class
     * (ie [ x1, x2, . . ., xM, Y ]).
     * @param forest The random forest this decision tree belongs to
     * @param num the Tree number
     */
    public DTree(ArrayList<int[]> data, RandomForest forest, int num){
        this.forest = forest;
        N = data.size();
        importances = new int[RandomForest.M];
        predictions = new ArrayList<Integer>();
        ArrayList<int[]> train = new ArrayList<int[]>(N); //data becomes the "bootstrap" - that's all it knows
        ArrayList<int[]> val = new ArrayList<int[]>();
        BootStrapSample(data, train, val, num);//populates train and val using data
        testN = val.size();
        correct = 0;
        root = CreateTree(train, num);//creating tree using training data
        CalcTreeVariableImportanceAndError(val, num);
        FlushData(root);
    }

    /**
     * Responsible for gauging the error rate of this tree and
     * calculating the importance values of each attribute
     *
     * @param val The left out data matrix
     * @param nv The Tree number
     */
    private void CalcTreeVariableImportanceAndError(ArrayList<int[]> val, int nv) {
        //calculate error rate
        correct = CalcTreeErrorRate(val, nv);//the number of correctly predicted records
        CalculateClasses(val, nv);
        //calculate importance of each attribute
        for (int m=0; m<RandomForest.M; m++){
            ArrayList<int[]> test_data = RandomlyPermuteAttribute(CloneData(val), m);
            int correctAfterPermute = 0;
            for (int[] arr:test_data){
                int pre_label = Evaluate(arr);
                if (pre_label == GetClass(arr))
                    correctAfterPermute++;
            }
            importances[m] += (correct - correctAfterPermute);
        }
        System.out.println("The importances of tree " + nv + " as follows");
        for (int m=0; m<importances.length; m++){
            System.out.println(" Attr" + m + ":" + importances[m]);
        }
    }

    /**
     * Calculates the tree error rate,
     * displays the error rate to console,
     * and updates the total forest error rate
     *
     * @param val the left out test data matrix
     * @param nu The Tree number
     * @return the number correct
     */
    public int CalcTreeErrorRate(ArrayList<int[]> val, int nu){
        int correct = 0;
        for (int[] record:val){
            int pre_label = Evaluate(record);
            forest.UpdateOOBEstimate(record, pre_label);
            int actual_label = record[record.length-1];
            if (pre_label == actual_label)
                correct++;
        }
        double err = 1 - correct/((double)val.size());
        System.out.println("Number of correct = " + correct + ", out of :" + val.size());
        System.out.println("Test-Data error rate of tree " + nu + " is: " + (err * 100) + " %");
        return correct;
    }

    /**
     * This method will get the classes and will return the updates
     * @param val The left out data matrix
     * @param nu The Tree number
     */
    public ArrayList<Integer> CalculateClasses(ArrayList<int[]> val, int nu){
        ArrayList<Integer> preList = new ArrayList<Integer>();
        int korect = 0;
        for (int[] record : val){
            int pre_label = Evaluate(record);
            preList.add(pre_label);
            int actual_label = record[record.length-1];
            if (pre_label == actual_label)
                korect++;
        }
        predictions = preList;
        return preList;
    }

    /**
     * This will classify a new data record by using tree
     * recursion and testing the relevant variable at each node.
     *
     * This is probably the most-used function in all of <b>GemIdent</b>.
     * It would make sense to inline this in assembly for optimal performance.
     *
     * @param record the data record to be classified
     * @return the class the data record was classified into
     */
    public int Evaluate(int[] record){
        TreeNode evalNode = root;
        while (true){
            if (evalNode.isLeaf)
                return evalNode.Class;
            if (record[evalNode.splitAttributeM] <= evalNode.splitValue)
                evalNode = evalNode.left;
            else
                evalNode = evalNode.right;
        }
    }

    /**
     * Takes a list of data records, and switches the m-th attribute across data records.
     * This is important in order to test the importance of the attribute. If the attribute
     * is randomly permuted and the result of the classification is the same, the attribute is
     * not important to the classification, and vice versa.
     *
     * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">Variable Importance</a>
     * @param val The left out data matrix to be permuted
     * @param m The attribute index to be permuted
     * @return The data matrix with the m-th column randomly permuted
     */
    private ArrayList<int[]> RandomlyPermuteAttribute(ArrayList<int[]> val, int m){
        int num = val.size() * 2;
        for (int i=0; i<num; i++){
            int a = (int)Math.floor(Math.random() * val.size());
            int b = (int)Math.floor(Math.random() * val.size());
            int[] arrA = val.get(a);
            int[] arrB = val.get(b);
            int temp = arrA[m];
            arrA[m] = arrB[m];
            arrB[m] = temp;
        }
        return val;
    }

    /**
     * Creates a copy of the data matrix
     * @param data the data matrix to be copied
     * @return the cloned data matrix
     */
    private ArrayList<int[]> CloneData(ArrayList<int[]> data){
        ArrayList<int[]> clone = new ArrayList<int[]>(data.size());
        int M = data.get(0).length;
        for (int i=0; i<data.size(); i++){
            int[] arr = data.get(i);
            int[] arrClone = new int[M];
            for (int j=0; j<M; j++){
                arrClone[j] = arr[j];
            }
            clone.add(arrClone);
        }
        return clone;
    }

    /**
     * This creates the decision tree according to the specifications of random forest trees.
     *
     * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#overview">Overview of random forest decision trees</a>
     * @param train the training data matrix (a bootstrap sample of the original data)
     * @param ntree the tree number
     * @return the TreeNode object that stores information about the parent node of the created tree
     */
    private TreeNode CreateTree(ArrayList<int[]> train, int ntree){
        TreeNode root = new TreeNode();
        root.data = train;
        RecursiveSplit(root, ntree);
        return root;
    }

    /**
     * @author DEGIS
     */
    private class TreeNode implements Cloneable {
        public boolean isLeaf;
        public TreeNode left;
        public TreeNode right;
        public int splitAttributeM;
        public Integer Class;
        public List<int[]> data;
        public int splitValue;
        public int generation;
        public ArrayList<Integer> attrArr;

        public TreeNode(){
            splitAttributeM = -99;
            splitValue = -99;
            generation = 1;
        }

        public TreeNode clone(){ //"data" element always null in clone
            TreeNode treeCopy = new TreeNode();
            treeCopy.isLeaf = isLeaf;
            if (left != null) //otherwise null
                treeCopy.left = left.clone();
            if (right != null) //otherwise null
                treeCopy.right = right.clone();
            treeCopy.splitAttributeM = splitAttributeM;
            treeCopy.Class = Class;
            treeCopy.splitValue = splitValue;
            return treeCopy;
        }
    }

    private class DoubleWrap {
        public double d;
        public DoubleWrap(double d){
            this.d = d;
        }
    }

    /**
     * This is the crucial function in tree creation.
     *
     * <ul>
     * <li>Step A
     * Check if this node is a leaf; if so, mark isLeaf true
     * and mark Class with the leaf's class. The function will not
     * recurse past this point.
     * </li>
     * <li>Step B
     * Create a left and right node and keep their references in
     * this node's left and right fields. For debugging purposes,
     * the generation number is also recorded. The {@link RandomForest#Ms Ms} attributes are
     * now chosen by the {@link #GetVarsToInclude() GetVarsToInclude} function.
     * </li>
     * <li>Step C
     * For all Ms variables, first {@link #SortAtAttribute(List,int) sort} the data records by that attribute,
     * then look through the values from lowest to highest.
     * If value i is not equal to value i+1, record i in the list of "indicesToCheck."
     * This speeds up the splitting. If the number of indices in indicesToCheck > MIN_SIZE_TO_CHECK_EACH
     * then we will only {@link #CheckPosition(int, int, int, DoubleWrap, TreeNode, int) check} the
     * entropy at every {@link #INDEX_SKIP INDEX_SKIP} index; otherwise,
     * we {@link #CheckPosition(int, int, int, DoubleWrap, TreeNode, int) check}
     * the entropy for all of them. The "E" variable records the entropy, and we are trying to find the minimum at which to split.
     * </li>
     * <li>Step D
     * The newly generated left and right nodes are now checked:
     * If the node has only one record, we mark it as a leaf and set its class equal to the class of the record.
     * If it has less than {@link #MIN_NODE_SIZE MIN_NODE_SIZE} records,
     * then we mark it as a leaf and set its class equal to the {@link #GetMajorityClass(List) majority class}.
     * If it has more, then we do a manual check on its data records, and if all have the same class, then it
     * is marked as a leaf. If not, then we run {@link #RecursiveSplit(TreeNode, int) RecursiveSplit} on
     * that node.
     * </li>
     * </ul>
     *
     * @param parent The node of the parent
     * @param Ntreenum the tree number
     */
    private void RecursiveSplit(TreeNode parent, int Ntreenum){
        if (!parent.isLeaf){
            //-------------------------------Step A
            //if all samples at this node belong to the same class, no split is needed
            Integer Class = CheckIfLeaf(parent.data);
            if (Class != null){
                parent.isLeaf = true;
                parent.Class = Class;
                return;
            }
            //-------------------------------Step B
            int Nsub = parent.data.size();
            ArrayList<Integer> vars = GetVarsToInclude();//randomly selects Ms indices of attributes from M
            parent.attrArr = vars;
            parent.left = new TreeNode();
            parent.left.generation = parent.generation + 1;
            parent.right = new TreeNode();
            parent.right.generation = parent.generation + 1;
            DoubleWrap lowestE = new DoubleWrap(Double.MIN_VALUE);
            //if the attribute set is empty, return the majority class
            if (parent.attrArr.size() == 0){
                parent.isLeaf = true;
                parent.Class = GetMajorityClass(parent.data);
                return;
            }
            //-------------------------------Step C
            //if all samples take the same value on all attributes, no split is possible: return the majority class
            int sameClass = 0;
            for (int m:parent.attrArr){
                SortAtAttribute(parent.data, m);//sorts on a particular column in the row
                ArrayList<Integer> indicesToCheck = new ArrayList<Integer>();
                for (int n=1; n<Nsub; n++){
                    int classA = GetClass(parent.data.get(n-1));
                    int classB = GetClass(parent.data.get(n));
                    if (classA != classB)
                        indicesToCheck.add(n);
                }
                if (indicesToCheck.size() == 0)
                    sameClass++;
            }
            if (sameClass == parent.attrArr.size()){
                parent.isLeaf = true;
                parent.Class = GetMajorityClass(parent.data);
                return;
            }
            for (int m:parent.attrArr){
                SortAtAttribute(parent.data, m);//sorts on a particular column in the row
                ArrayList<Integer> indicesToCheck = new ArrayList<Integer>();
                for (int n=1; n<Nsub; n++){
                    int classA = GetClass(parent.data.get(n-1));
                    int classB = GetClass(parent.data.get(n));
                    if (classA != classB)
                        indicesToCheck.add(n);
                }
                if (indicesToCheck.size() > MIN_SIZE_TO_CHECK_EACH){
                    for (int i=0; i<indicesToCheck.size(); i+=INDEX_SKIP){
                        CheckPosition(m, indicesToCheck.get(i), Nsub, lowestE, parent, Ntreenum);
                        if (lowestE.d == 0)//lowestE now has the minimum conditional entropy so IG is max there
                            break;
                    }
                }
                else {
                    for (int index:indicesToCheck){
                        CheckPosition(m, index, Nsub, lowestE, parent, Ntreenum);
                        if (lowestE.d == 0)
                            break;
                    }
                }
                if (lowestE.d == 0)
                    break;
            }
            //remove the split attribute from the attribute set
            Iterator<Integer> it = parent.attrArr.iterator();
            while (it.hasNext()){
                int attr = it.next();
                if (attr == parent.splitAttributeM){
                    it.remove();
                }
            }
            parent.left.attrArr = parent.attrArr;
            parent.right.attrArr = parent.attrArr;
            //-------------------------------Step D
            //------------Left Child
            if (parent.left.data.size() == 1){//only a single record remains
                parent.left.isLeaf = true;
                parent.left.Class = GetClass(parent.left.data.get(0));
            }
            else if (parent.left.attrArr.size() == 0){//the attribute set is empty
                parent.left.isLeaf = true;
                parent.left.Class = GetMajorityClass(parent.left.data);
            }
            else if (parent.left.data.size() < MIN_NODE_SIZE){//fewer than min.node.size records: stop splitting
                parent.left.isLeaf = true;
                parent.left.Class = GetMajorityClass(parent.left.data);
            }
            else {
                Class = CheckIfLeaf(parent.left.data);
                if (Class == null){
                    parent.left.isLeaf = false;
                    parent.left.Class = null;
                }
                else {//all remaining samples belong to the same class
                    parent.left.isLeaf = true;
                    parent.left.Class = Class;
                }
            }
            //------------Right Child
            if (parent.right.data.size() == 1){//only a single record remains
                parent.right.isLeaf = true;
                parent.right.Class = GetClass(parent.right.data.get(0));
            }
            else if (parent.right.attrArr.size() == 0){//the attribute set is empty
                parent.right.isLeaf = true;
                parent.right.Class = GetMajorityClass(parent.right.data);
            }
            else if (parent.right.data.size() < MIN_NODE_SIZE){//fewer than min.node.size records: stop splitting
                parent.right.isLeaf = true;
                parent.right.Class = GetMajorityClass(parent.right.data);
            }
            else {
                Class = CheckIfLeaf(parent.right.data);
                if (Class == null){
                    parent.right.isLeaf = false;
                    parent.right.Class = null;
                }
                else {//all remaining samples belong to the same class
                    parent.right.isLeaf = true;
                    parent.right.Class = Class;
                }
            }
            if (!parent.left.isLeaf){
                RecursiveSplit(parent.left, Ntreenum);
            }
            if (!parent.right.isLeaf)
                RecursiveSplit(parent.right, Ntreenum);
        }
    }

    /**
     * Given a data matrix, return the most popular Y value (the class)
     * @param data The data matrix
     * @return The most popular class
     */
    private int GetMajorityClass(List<int[]> data){
        int[] counts = new int[RandomForest.C];
        for (int[] record:data){
            int Class = record[record.length-1];
            counts[Class-1]++;
        }
        int index = -99;
        int max = Integer.MIN_VALUE;
        for (int i=0; i<counts.length; i++){
            if (counts[i] > max){
                max = counts[i];
                index = i+1;
            }
        }
        return index;
    }

    /**
     * Checks the {@link #CalcEntropy(double[]) entropy} of an index in a data matrix at a particular attribute (m)
     * and returns the information gain of splitting there. If the gain is greater than the best found so far
     * (tracked in lowestE), the parent's split is updated to this position.
     *
     * The gain is calculated from the sub-entropy below the split point and the sub-entropy above the split point.
     * Each sub-entropy is calculated by first getting the {@link #GetClassProbs(List) proportion} of each of the classes
     * in that sub-data matrix; the entropy is then {@link #CalcEntropy(double[]) calculated}. The lower and upper
     * sub-entropies are weight-averaged and subtracted from the parent's entropy to obtain the gain.
     *
     * @param m the attribute to split on
     * @param n the index to check (rowID)
     * @param Nsub the number of records in the data matrix
     * @param lowestE the best gain found so far
     * @param parent the parent node
     * @return the information gain of this split
     */
    private double CheckPosition(int m, int n, int Nsub, DoubleWrap lowestE, TreeNode parent, int nTre){
        if (n < 1) //exit conditions
            return 0;
        if (n > Nsub)
            return 0;
        List<int[]> lower = GetLower(parent.data, n);
        List<int[]> upper = GetUpper(parent.data, n);
        if (lower == null)
            System.out.println("lower list null");
        if (upper == null)
            System.out.println("upper list null");
        double[] p = GetClassProbs(parent.data);
        double[] pl = GetClassProbs(lower);
        double[] pu = GetClassProbs(upper);
        double eP = CalcEntropy(p);
        double eL = CalcEntropy(pl);
        double eU = CalcEntropy(pu);
        double e = eP - eL * lower.size()/(double)Nsub - eU * upper.size()/(double)Nsub;
        if (e > lowestE.d){
            lowestE.d = e;
            parent.splitAttributeM = m;
            parent.splitValue = parent.data.get(n)[m];
            parent.left.data = lower;
            parent.right.data = upper;
        }
        return e;//the information gain of this split
    }

    /**
     * Given a data record, return the Y value - take the last index
     *
     * @param record the data record
     * @return its y value (class)
     */
    public static int GetClass(int[] record){
        return record[RandomForest.M];
    }

    /**
     * Given a data matrix, check if all the y values are the same. If so,
     * return that y value, null if not
     *
     * @param data the data matrix
     * @return the common class (null if not common)
     */
    private Integer CheckIfLeaf(List<int[]> data){
        boolean isLeaf = true;
        int ClassA = GetClass(data.get(0));
        for (int i=1; i<data.size(); i++){
            int[] recordB = data.get(i);
            if (ClassA != GetClass(recordB)){
                isLeaf = false;
                break;
            }
        }
        if (isLeaf)
            return GetClass(data.get(0));
        else
            return null;
    }

    /**
     * Split a data matrix and return the upper portion
     *
     * @param data the data matrix to be split
     * @param nSplit the index in the data matrix; all data records at or above it are returned
     * @return the upper sub-data matrix
     */
    private List<int[]> GetUpper(List<int[]> data, int nSplit){
        int N = data.size();
        List<int[]> upper = new ArrayList<int[]>(N-nSplit);
        for (int n=nSplit; n<N; n++)
            upper.add(data.get(n));
        return upper;
    }

    /**
     * Split a data matrix and return the lower portion
     *
     * @param data the data matrix to be split
     * @param nSplit the index in the data matrix; all data records below it are returned
     * @return the lower sub-data matrix
     */
    private List<int[]> GetLower(List<int[]> data, int nSplit){
        List<int[]> lower = new ArrayList<int[]>(nSplit);
        for (int n=0; n<nSplit; n++)
            lower.add(data.get(n));
        return lower;
    }

    /**
     * This class compares two data records by numerically comparing a specified attribute
     *
     * @author kapelner
     */
    private class AttributeComparator implements Comparator {
        /** the specified attribute */
        private int m;
        /**
         * Create a new comparator
         * @param m the attribute to compare on
         */
        public AttributeComparator(int m){
            this.m = m;
        }
        /**
         * Compare the two data records. They must be of type int[].
         *
         * @param o1 data record A
         * @param o2 data record B
         * @return -1 if A[m] < B[m], 1 if A[m] > B[m], 0 if equal
         */
        public int compare(Object o1, Object o2){
            int a = ((int[])o1)[m];
            int b = ((int[])o2)[m];
            if (a < b)
                return -1;
            if (a > b)
                return 1;
            else
                return 0;
        }
    }

    /**
     * Sorts a data matrix by an attribute from lowest record to highest record
     * @param data the data matrix to be sorted
     * @param m the attribute to sort on
     */
    @SuppressWarnings("unchecked")
    private void SortAtAttribute(List<int[]> data, int m){
        Collections.sort(data, new AttributeComparator(m));
    }

    /**
     * Given a data matrix, return a probability mass function representing
     * the frequencies of the classes in the matrix (the y values)
     *
     * @param records the data matrix to be examined
     * @return the probability mass function
     */
    private double[] GetClassProbs(List<int[]> records){
        double N = records.size();
        int[] counts = new int[RandomForest.C];//the number of target classes
        for (int[] record:records)
            counts[GetClass(record)-1]++;
        double[] ps = new double[RandomForest.C];
        for (int j=0; j<RandomForest.C; j++)
            ps[j] = counts[j]/N;
        return ps;
    }

    /** ln(2) */
    private static final double logoftwo = Math.log(2);

    /**
     * Given a probability mass function indicating the frequencies of
     * class representation, calculate an "entropy" value using the method
     * in Tan, Steinbach and Kumar's "Data Mining" textbook
     *
     * @param ps the probability mass function
     * @return the entropy value calculated
     */
    private double CalcEntropy(double[] ps){
        double e = 0;
        for (double p:ps){
            if (p != 0) //otherwise it would take the log of zero - see TSK p159
                e += p * Math.log(p)/logoftwo;
        }
        return -e; //according to TSK p158
    }

    /**
     * Of the M attributes, select {@link RandomForest#Ms Ms} at random.
     *
     * @return The list of the Ms attributes' indices
     */
    private ArrayList<Integer> GetVarsToInclude() {
        boolean[] whichVarsToInclude = new boolean[RandomForest.M];
        for (int i=0; i<RandomForest.M; i++)//initialize all entries to false
            whichVarsToInclude[i] = false;
        while (true){
            int a = (int)Math.floor(Math.random() * RandomForest.M);//Math.random() is in [0,1)
            whichVarsToInclude[a] = true;
            int N = 0;
            for (int i=0; i<RandomForest.M; i++)
                if (whichVarsToInclude[i])
                    N++;
            if (N == RandomForest.Ms)
                break;
        }
        ArrayList<Integer> shortRecord = new ArrayList<Integer>(RandomForest.Ms);
        for (int i=0; i<RandomForest.M; i++)
            if (whichVarsToInclude[i])
                shortRecord.add(i);
        return shortRecord;
    }

    /**
     * Create a bootstrap sample of a data matrix
     * @param data the data matrix to be sampled
     * @param train the bootstrap sample
     * @param val the records that are absent in the bootstrap sample
     * @param numb the tree number
     */
    private void BootStrapSample(ArrayList<int[]> data, ArrayList<int[]> train, ArrayList<int[]> val, int numb){
        ArrayList<Integer> indices = new ArrayList<Integer>(N);
        for (int n=0; n<N; n++)
            indices.add((int)Math.floor(Math.random() * N));
        ArrayList<Boolean> IsIn = new ArrayList<Boolean>(N);
        for (int n=0; n<N; n++)
            IsIn.add(false); //initialize it first
        for (int index:indices){
            train.add((data.get(index)).clone());//train may contain duplicated records
            IsIn.set(index, true);
        }
        for (int i=0; i<N; i++)
            if (!IsIn.get(i))
                val.add((data.get(i)).clone());//records never drawn become the out-of-bag test data
    }

    /**
     * Recursively deletes all data records from the tree. This is run after the tree
     * has been computed and can stand alone to classify incoming data.
     *
     * @param node initially, the root node of the tree
     */
    private void FlushData(TreeNode node){
        node.data = null;
        if (node.left != null)
            FlushData(node.left);
        if (node.right != null)
            FlushData(node.right);
    }

    /**
     * Get the number of data records in the test data matrix that were classified correctly
     */
    public int getNumCorrect(){
        return correct;
    }

    /**
     * Get the number of data records left out of the bootstrap sample
     */
    public int getTotalNumInTestSet(){
        return testN;
    }

    /**
     * Get the importance level of attribute m for this tree
     */
    public int getImportanceLevel(int m){
        return importances[m];
    }
}

DescribeTrees.java

package com.rf.real;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

public class DescribeTrees {
    //reads a txt file and feeds it into the random forest as input
    BufferedReader br = null;
    String path;

    public DescribeTrees(String path){
        this.path = path;
    }

    public ArrayList<int[]> CreateInput(String path){
        ArrayList<int[]> DataInput = new ArrayList<int[]>();
        try {
            String sCurrentLine;
            br = new BufferedReader(new FileReader(path));
            while ((sCurrentLine = br.readLine()) != null) {
                ArrayList<Integer> spaceIndex = new ArrayList<Integer>();//indices of the whitespace characters
                int i;
                if (sCurrentLine != null){
                    sCurrentLine = " " + sCurrentLine + " ";
                    for (i=0; i<sCurrentLine.length(); i++){
                        if (Character.isWhitespace(sCurrentLine.charAt(i)))
                            spaceIndex.add(i);
                    }
                    int[] DataPoint = new int[spaceIndex.size()-1];
                    for (i=0; i<spaceIndex.size()-1; i++){
                        DataPoint[i] = Integer.parseInt(sCurrentLine.substring(spaceIndex.get(i)+1, spaceIndex.get(i+1)));
                    }
                    DataInput.add(DataPoint);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if (br != null)
                    br.close();
            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }
        return DataInput;
    }
}

MainRun.java

package com.rf.real;

import java.util.ArrayList;

public class MainRun {
    @SuppressWarnings("static-access")
    public static void main(String args[]){
        String trainPath = "C:\\Users\\kwok\\Desktop\\rf_data\\Data.txt";
        String testPath = "C:\\Users\\kwok\\Desktop\\rf_data\\Test.txt";
        int numTrees = 100;
        DescribeTrees DT = new DescribeTrees(trainPath);
        ArrayList<int[]> Train = DT.CreateInput(trainPath);
        DescribeTrees DT2 = new DescribeTrees(testPath);
        ArrayList<int[]> Test = DT2.CreateInput(testPath);
        int categ = 0;
        //find the number of labels: the largest class value in the last column
        int trainLength = Train.get(0).length;
        for (int k=0; k<Train.size(); k++){
            if (Train.get(k)[trainLength-1] < categ)
                continue;
            else {
                categ = Train.get(k)[trainLength-1];
            }
        }
        RandomForest Rf = new RandomForest(numTrees, Train, Test);
        Rf.C = categ;//the number of labels
        Rf.M = Train.get(0).length-1;//the number of attributes
        //attribute perturbation: each split draws Ms of the M attributes at random, Ms ≈ log2(M) + 1
        Rf.Ms = (int)Math.round(Math.log(Rf.M)/Math.log(2) + 1);//the number of randomly selected attributes
        Rf.Start();
    }
}

RandomForest.java

package com.rf.real;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Random Forest
 */
public class RandomForest {
    /** the number of available threads */
    private static final int NUM_THREADS = Runtime.getRuntime().availableProcessors();
    /** the number of target classes */
    public static int C;
    /** the number of attributes (columns) */
    public static int M;
    /** attribute perturbation: each split draws Ms of the M attributes at random, Ms ≈ log2(M) + 1 */
    public static int Ms;
    /** the collection of the forest's decision trees */
    private ArrayList<DTree> trees;
    /** the start time */
    private long time_o;
    /** the number of trees in this random forest */
    private int numTrees;
    /** progress increment added as each tree is built, for live progress display */
    private double update;
    /** initial progress value, for live progress display */
    private double progress;
    /** the importance array */
    private int[] importances;
    /** key = a record from the data matrix,
     * value = the RF's classification results for that record */
    private HashMap<int[],int[]> estimateOOB;
    /** all of the predictions from the RF */
    private ArrayList<ArrayList<Integer>> Prediction;
    /** the error rate of the RF */
    private double error;
    /** the thread pool that controls tree growing */
    private ExecutorService treePool;
    /** the original training data */
    private ArrayList<int[]> train_data;
    /** the test data */
    private ArrayList<int[]> testdata;

    /**
     * Initializes a Random Forest
     * @param numTrees the number of trees in the RF
     * @param train_data the original training data
     * @param t_data the test data
     */
    public RandomForest(int numTrees, ArrayList<int[]> train_data, ArrayList<int[]> t_data){
        this.numTrees = numTrees;
        this.train_data = train_data;
        this.testdata = t_data;
        trees = new ArrayList<DTree>(numTrees);
        update = 100 / ((double)numTrees);
        progress = 0;
        StartTimer();
        System.out.println("creating "+numTrees+" trees in a random Forest. . .");
        System.out.println("total data size is "+train_data.size());
        System.out.println("number of attributes " + (train_data.get(0).length-1));
        System.out.println("number of selected attributes " + ((int)Math.round(Math.log(train_data.get(0).length-1)/Math.log(2) + 1)));
        estimateOOB = new HashMap<int[],int[]>(train_data.size());
        Prediction = new ArrayList<ArrayList<Integer>>();
    }

    /**
     * Begins the creation of the random forest
     */
    public void Start() {
        System.out.println("Num of threads started : " + NUM_THREADS);
        System.out.println("Running...");
        treePool = Executors.newFixedThreadPool(NUM_THREADS);
        for (int t=0; t<numTrees; t++){
            System.out.println("constructing tree " + t);
            treePool.execute(new CreateTree(train_data, this, t+1));
        }
        treePool.shutdown();
        try {
            treePool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); //effectively infinity
        } catch (InterruptedException ignored){
            System.out.println("interrupted exception in Random Forests");
        }
        System.out.println("");
        System.out.println("Finished tree construction");
        TestForest(trees, testdata);
        CalcImportances();
        System.out.println("Done in "+TimeElapsed(time_o));
    }

    /**
     * @param collec_tree the collection of the forest's decision trees
     * @param test_data the test data set
     */
    private void TestForest(ArrayList<DTree> collec_tree, ArrayList<int[]> test_data) {
        int correctness = 0;
        ArrayList<Integer> actualLabel = new ArrayList<Integer>();
        for (int[] rec:test_data){
            actualLabel.add(rec[rec.length-1]);
        }
        int treeNumber = 1;
        for (DTree dt:collec_tree){
            dt.CalculateClasses(test_data, treeNumber);
            Prediction.add(dt.predictions);
            treeNumber++;
        }
        for (int i=0; i<test_data.size(); i++){
            ArrayList<Integer> Val = new ArrayList<Integer>();
            for (int j=0; j<collec_tree.size(); j++){
                Val.add(Prediction.get(j).get(i));//the collection of each tree's prediction for the i-th record
            }
            int pred = labelVote(Val);//voting algorithm
            if (pred == actualLabel.get(i)){
                correctness++;
            }
        }
        System.out.println("Accuracy of Forest is : " + (100 * correctness / test_data.size()) + "%");
    }

    /**
     * Voting algorithm
     * @param treePredict the collection of each tree's prediction for the i-th record
     */
    private int labelVote(ArrayList<Integer> treePredict){
        int max = 0, maxclass = -1;
        for (int i=0; i<treePredict.size(); i++){
            int count = 0;
            for (int j=0; j<treePredict.size(); j++){
                if (treePredict.get(j).equals(treePredict.get(i))){//use equals(), not ==, for boxed Integers
                    count++;
                }
                if (count > max){
                    maxclass = treePredict.get(i);
                    max = count;
                }
            }
        }
        return maxclass;
    }

    /**
     * Calculates the classification error rate of the RF from the OOB estimates
     */
    private void CalcErrorRate(){
        double N = 0;
        int correct = 0;
        for (int[] record:estimateOOB.keySet()){
            N++;
            int[] map = estimateOOB.get(record);
            int Class = FindMaxIndex(map) + 1;//map is indexed by class-1
            if (Class == DTree.GetClass(record))
                correct++;
        }
        error = 1 - correct/N;
        System.out.println("correctly mapped "+correct);
        System.out.println("Forest error rate % is: "+(error*100));
    }

    /**
     * Updates the OOB estimate for one record
     * @param record a record from the data matrix
     * @param Class the class one tree predicted for it
     */
    public void UpdateOOBEstimate(int[] record, int Class){
        if (estimateOOB.get(record) == null){
            int[] map = new int[C];
            map[Class-1]++;
            estimateOOB.put(record, map);
        }
        else {
            int[] map = estimateOOB.get(record);
            map[Class-1]++;
        }
    }

    /**
     * Calculates the importance levels for all attributes.
     */
    private void CalcImportances() {
        importances = new int[M];
        for (DTree tree:trees){
            for (int i=0; i<M; i++)
                importances[i] += tree.getImportanceLevel(i);
        }
        for (int i=0; i<M; i++)
            importances[i] /= numTrees;
        System.out.println("The forest-wide importance as follows:");
        for (int j=0; j<importances.length; j++){
            System.out.println("Attr" + j + ":" + importances[j]);
        }
    }

    /** starts the timer */
    private void StartTimer(){
        time_o = System.currentTimeMillis();
    }

    /**
     * Grows one decision tree
     */
    private class CreateTree implements Runnable {
        /** the training data */
        private ArrayList<int[]> train_data;
        /** the random forest this tree belongs to */
        private RandomForest forest;
        /** the tree number */
        private int treenum;

        public CreateTree(ArrayList<int[]> train_data, RandomForest forest, int num){
            this.train_data = train_data;
            this.forest = forest;
            this.treenum = num;
        }

        /**
         * Creates the decision tree
         */
        public void run() {
            System.out.println("Creating a Dtree num : " + treenum + " ");
            trees.add(new DTree(train_data, forest, treenum));
            progress += update;
            System.out.println("---progress:" + progress);
        }
    }

    /**
     * Evaluates a test record
     * @param record a record to be evaluated
     */
    public int Evaluate(int[] record){
        int[] counts = new int[C];
        for (int t=0; t<numTrees; t++){
            int Class = (trees.get(t)).Evaluate(record);
            counts[Class-1]++;//classes are 1-based, the counts array is 0-based
        }
        return FindMaxIndex(counts) + 1;
    }

    public static int FindMaxIndex(int[] arr){
        int index = 0;
        int max = Integer.MIN_VALUE;
        for (int i=0; i<arr.length; i++){
            if (arr[i] > max){
                max = arr[i];
                index = i;
            }
        }
        return index;
    }

    /**
     * @param timeinms the start time in milliseconds
     * @return the elapsed time as hr, min, s
     */
    private static String TimeElapsed(long timeinms){
        int s = (int)(System.currentTimeMillis()-timeinms)/1000;
        int h = (int)Math.floor(s/((double)3600));
        s -= (h*3600);
        int m = (int)Math.floor(s/((double)60));
        s -= (m*60);
        return ""+h+"hr "+m+"m "+s+"s";
    }
}

Data.txt

https://github.com/Edgis/Machine-Learning-Algorithm/blob/master/randomForest/Data.txt

Test.txt 

https://github.com/Edgis/Machine-Learning-Algorithm/blob/master/randomForest/Test.txt

 

Conclusion

Data.txt is the training set and Test.txt is the test set. In offline tests, an RF model of 100 trees built on the training data above reaches roughly 90% forest-wide accuracy on the Test set, and Attr16 and Attr17 show comparatively high Variable Importance.
