Random forest is a classifier that trains multiple trees on samples and aggregates their predictions. It was first proposed by Leo Breiman and Adele Cutler, who registered "Random Forests" as a trademark.
In machine learning, a random forest is a classifier made up of many decision trees, and its output class is the mode of the classes output by the individual trees. Random forest is a flexible, easy-to-use machine learning algorithm that produces good results in most cases even without hyperparameter tuning. It is also one of the most commonly used algorithms, because it is simple and works for both classification and regression tasks.
Imagine a man named Andrew who wants to decide where to travel during a year of vacation, so he asks friends who know him for suggestions.
First he goes to one friend, who asks Andrew where he has been before and what he liked or disliked about those places. Based on the answers, the friend gives Andrew some advice. This is a typical decision tree algorithm: the friend derives rules from Andrew's answers and uses them to decide what to recommend.
Andrew then seeks out more and more friends; each asks him different questions and offers suggestions of their own. Finally, Andrew picks the places recommended to him most often, which is the essence of the random forest algorithm.
A random forest builds multiple decision trees and merges them to obtain a more accurate and stable prediction. One major advantage of random forests is that they can be used for both classification and regression, which together make up the majority of today's machine learning tasks.
A random forest classifier uses all of the decision tree classifier's hyperparameters plus those of a bagging classifier to control the ensemble as a whole. Instead of building a bagging classifier and passing it a decision tree, you can use the random forest classifier class directly, which is more convenient and better optimized for trees (a hedged scikit-learn sketch follows below). Note that regression problems have a corresponding random forest regressor. The way the trees are grown adds extra randomness to the model: whereas in a single decision tree each node is split on the feature that minimizes the error, a random forest searches for the best split only within a random subset of the features. The trees can be made even more random by also drawing a random threshold for each feature instead of searching for the best threshold, as a normal decision tree does. This process produces broad diversity among the trees, which usually yields a better model.
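As a hedged sketch (assuming scikit-learn; the dataset X, y is a synthetic stand-in, not part of the original post), the two constructions described above look like this:
- #a minimal sketch, assuming scikit-learn; X, y are placeholder data
- from sklearn.datasets import make_classification
- from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
- from sklearn.tree import DecisionTreeClassifier
- 
- X, y = make_classification(n_samples=500, n_features=20, random_state=0)
- 
- #option 1: assemble bagging and decision trees by hand
- bag = BaggingClassifier(DecisionTreeClassifier(), n_estimators=100, random_state=0).fit(X, y)
- 
- #option 2: the dedicated class, which adds per-split random feature subsets
- rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)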
Hyperparameters that mainly improve the model's predictive accuracy:
1. n_estimators: the number of trees the algorithm builds before taking the majority vote or averaging the predictions. In general, more trees mean better performance and more stable predictions, but also slower computation.
2. max_features: the maximum number of features the random forest considers when splitting a node in an individual tree.
3. min_samples_leaf: the minimum number of samples required at a leaf node.
Hyperparameters that mainly speed up the computation (see the sketch after this list):
1. n_jobs: the number of processors the engine is allowed to use. A value of 1 means only one processor may be used; -1 means no limit.
2. random_state: the random seed, which makes the model's output reproducible. Given a fixed value, the same hyperparameters, and the same training data, the model will always produce the same results.
3. oob_score (out-of-bag sampling): a random forest cross-validation method. Under this sampling scheme, roughly one third of the data is not used to train the model and is used to evaluate its performance instead; these records are called out-of-bag samples. The estimate is very similar to leave-one-out cross-validation, but with almost no additional computational burden.
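A minimal tuning sketch in scikit-learn, reusing the placeholder X, y from the sketch above (the values are illustrative, not tuned):
- from sklearn.ensemble import RandomForestClassifier
- 
- rf = RandomForestClassifier(
-     n_estimators=200,     #more trees: better and more stable, but slower
-     max_features='sqrt',  #number of features considered at each split
-     min_samples_leaf=5,   #minimum number of samples required at a leaf
-     n_jobs=-1,            #use all available processors
-     random_state=42,      #reproducible results
-     oob_score=True,       #evaluate on the out-of-bag samples
- )
- rf.fit(X, y)
- print(rf.oob_score_)      #OOB accuracy estimate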
Advantages:
It can be used for both regression and classification tasks, and it is easy to inspect the relative importance the model assigns to the input features.
The default hyperparameters usually already produce a good prediction.
Random forests are effective at avoiding overfitting: as long as the forest contains enough trees, the classifier will not overfit the model.
Disadvantages:
A large number of trees makes the algorithm slow and unable to predict in real time. In general, these algorithms are fast to train but slow to predict: more accurate predictions require more trees, and more trees make the model slower.
Applications:
The random forest algorithm is used in many different fields, such as banking, the stock market, medicine, and e-commerce. In banking, it is used to identify customers who use the bank's services more frequently than average and repay their debts on time, as well as to detect customers trying to defraud the bank. In finance, it can be used to predict future stock trends. In healthcare, it is used to identify the correct combinations of drug components and to analyze a patient's medical history to recognize diseases. In e-commerce, a random forest can be used to determine whether a customer will actually like a product.
A random forest is an ensemble of decision trees, with two differences:
(1) Sampling: each tree is trained on a bootstrap sample drawn with replacement from the dataset of m samples, again of size m, so the training samples of the individual trees are not all identical.
(2) Feature selection: the n split features of each tree are randomly selected from all the features (n is a parameter we have to tune ourselves; the square root of the total feature count is a common choice for classification).
The parameters a random forest needs tuned are:
(1) the number of decision trees;
(2) the number of feature attributes;
(3) the recursion depth (i.e. the depth of the decision trees).
Reference: https://blog.csdn.net/flying_sfeng/article/details/64133822
(1) Load the file and convert all features to float
- import csv
- from random import randrange
- 
- #load the data
- def loadCSV(filename):
-     dataSet=[]
-     with open(filename,'r') as file:
-         csvReader=csv.reader(file)
-         for line in csvReader:
-             dataSet.append(line)
-     return dataSet
- 
- #convert every column except the label column to float
- def column_to_float(dataSet):
-     featLen=len(dataSet[0])-1
-     for data in dataSet:
-         for column in range(featLen):
-             data[column]=float(data[column].strip())
(2) Split the dataset into n folds for cross-validation
- #split the dataset into N blocks for cross-validation
- def spiltDataSet(dataSet,n_folds):
-     fold_size=int(len(dataSet)/n_folds)
-     dataSet_copy=list(dataSet)
-     dataSet_spilt=[]
-     for i in range(n_folds):
-         fold=[]
-         while len(fold) < fold_size:  #must be while, not if: if only tests once, while keeps drawing until the fold is full
-             index=randrange(len(dataSet_copy))
-             fold.append(dataSet_copy.pop(index))  #pop() removes an element (the last by default) and returns its value
-         dataSet_spilt.append(fold)
-     return dataSet_spilt
(3) Construct data subsets by random sampling, and select the best split feature among a specified number of candidates (say m, tuned by hand)
- #construct a random subsample of the dataset
- def get_subsample(dataSet,ratio):
-     subdataSet=[]
-     lenSubdata=round(len(dataSet)*ratio)
-     while len(subdataSet) < lenSubdata:
-         index=randrange(len(dataSet))  #randrange(n) draws from 0..n-1, so len(dataSet)-1 would never sample the last row
-         subdataSet.append(dataSet[index])
-     return subdataSet
- 
- #pick n random features and, among them, find the best split
- #(data_spilt and spilt_loss are helpers defined in the referenced blog post)
- def get_best_spilt(dataSet,n_features):
-     features=[]
-     class_values=list(set(row[-1] for row in dataSet))
-     b_index,b_value,b_loss,b_left,b_right=999,999,999,None,None
-     while len(features) < n_features:
-         index=randrange(len(dataSet[0])-1)
-         if index not in features:
-             features.append(index)
-     for index in features:
-         for row in dataSet:
-             left,right=data_spilt(dataSet,index,row[index])
-             loss=spilt_loss(left,right,class_values)
-             if loss < b_loss:
-                 b_index,b_value,b_loss,b_left,b_right=index,row[index],loss,left,right
-     return {'index':b_index,'value':b_value,'left':b_left,'right':b_right}
(4) Build the decision tree
- #build a decision tree (sub_spilt, the recursive splitting routine, is defined in the referenced blog post)
- def build_tree(dataSet,n_features,max_depth,min_size):
-     root=get_best_spilt(dataSet,n_features)
-     sub_spilt(root,n_features,max_depth,min_size,1)
-     return root
(5) Create the random forest, a combination of multiple decision trees (bagging_predict, the voting helper, is sketched after step (6))
- #create the random forest
- def random_forest(train,test,ratio,n_features,max_depth,min_size,n_trees):
-     trees=[]
-     for i in range(n_trees):
-         subTrain=get_subsample(train,ratio)
-         tree=build_tree(subTrain,n_features,max_depth,min_size)
-         trees.append(tree)
-     predict_values=[bagging_predict(trees,row) for row in test]
-     return predict_values
(6) Feed in the test set and output the predictions
- #predict one row with a single tree by walking down the splits
- def predict(tree,row):
-     if row[tree['index']] < tree['value']:
-         if isinstance(tree['left'],dict):
-             return predict(tree['left'],row)
-         else:
-             return tree['left']
-     else:
-         if isinstance(tree['right'],dict):
-             return predict(tree['right'],row)
-         else:
-             return tree['right']
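The bagging_predict helper called in step (5) is not shown in the post. A minimal sketch consistent with the code above (a majority vote over the individual trees' predictions), plus a hypothetical end-to-end driver, might look like this; the filename and parameter values are placeholders, and the helpers data_spilt, spilt_loss, and sub_spilt must come from the referenced post:
- #hedged sketch: majority vote across the trees' predictions
- def bagging_predict(trees,row):
-     predictions=[predict(tree,row) for tree in trees]
-     return max(set(predictions),key=predictions.count)
- 
- #hedged usage sketch with placeholder filename and parameters
- if __name__=='__main__':
-     dataSet=loadCSV('sonar.csv')
-     column_to_float(dataSet)
-     folds=spiltDataSet(dataSet,5)
-     test,train=folds[0],sum(folds[1:],[])
-     preds=random_forest(train,test,ratio=0.8,n_features=15,max_depth=20,min_size=1,n_trees=10)
-     actual=[row[-1] for row in test]
-     print('accuracy:',sum(p==a for p,a in zip(preds,actual))/float(len(actual)))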
Reference: https://blog.csdn.net/xiuxian4728/article/details/78897134
Creating the decision tree
The random forest algorithm
1. The program is tested on the public NSL-KDD dataset [ NSL-KDD ].
2. The attributes are continuous (Continuous Attributes); when a decision tree node is split, the data is divided into two branches.
3. At the end of the code, the computation of Variable Importance is added; for the theory see [ Variable Importance ]. A hedged sketch of the idea follows.
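The idea, as a hedged Python sketch reusing the single-tree predict function from the section above (shuffle_column and tree_importance are illustrative names, not part of the Java code below): an attribute's importance for one tree is the drop in correct out-of-bag classifications after that attribute's column is randomly permuted.
- #hedged sketch of permutation-based variable importance
- import random
- 
- def shuffle_column(data,m):
-     rows=[list(r) for r in data]  #copy so the original OOB data stays intact
-     col=[r[m] for r in rows]
-     random.shuffle(col)
-     for r,v in zip(rows,col):
-         r[m]=v
-     return rows
- 
- def tree_importance(tree,oob_data,m):
-     correct=sum(predict(tree,row)==row[-1] for row in oob_data)
-     permuted=shuffle_column(oob_data,m)
-     correct_after=sum(predict(tree,row)==row[-1] for row in permuted)
-     return correct-correct_after  #the forest-wide importance sums this over all trees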
Code
DTree.java
- package com.rf.real;
-
- import java.util.*;
-
- /**
- * Creates a decision tree based on the specifications of random forest trees
- */
- public class DTree {
-
- /** Instead of checking each index we'll skip every INDEX_SKIP indices unless there's less than MIN_SIZE_TO_CHECK_EACH*/
- private static final int INDEX_SKIP = 2;
- /** If there's less than MIN_SIZE_TO_CHECK_EACH points, we'll check each one */
- private static final int MIN_SIZE_TO_CHECK_EACH = 10;
- /** If the number of data points is less than MIN_NODE_SIZE, we won't continue splitting, we'll take the majority vote */
- private static final int MIN_NODE_SIZE=5;
- /** the number of data records */
- private int N;
- /** the number of samples left out of the bootstrap of all N to test error rate
- * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#ooberr">OOB error estimate</a>
- */
- private int testN;
- /** Of the testN, the number that were correctly identified */
- private int correct;
- /** an estimate of the importance of each attribute in the data record
- * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">Variable Importance</a>
- */
- private int[] importances;
- /** This is the root of the Decision Tree */
- private TreeNode root;
- /** This is a pointer to the Random Forest this decision tree belongs to */
- private RandomForest forest;
- /** This keeps track of all the predictions done by this tree */
- public ArrayList<Integer> predictions;
-
- /**
- * This constructs a decision tree from a data matrix.
- * It first creates a bootstrap sample, the train data matrix, as well as the left out records,
- * the test data matrix. Then it creates the tree, then calculates the variable importances (not essential)
- * and then removes the links to the actual data (to save memory)
- *
- * @param data The data matrix as a List of int arrays - each array is one record,
- * each index in the array is one attribute, and the last index is the class
- * (ie [ x1, x2, . . ., xM, Y ]).
- * @param forest The random forest this decision tree belongs to
- * @param num the Tree number
- */
- public DTree(ArrayList<int[]> data, RandomForest forest, int num){
- this.forest = forest;
- N = data.size();
- importances = new int[RandomForest.M];
- predictions = new ArrayList<Integer>();
-
- //System.out.println("Make a Dtree num : "+num+" with N:"+N+" M:"+RandomForest.M+" Ms:"+RandomForest.Ms);
-
- ArrayList<int[]> train = new ArrayList<int[]>(N); //data becomes the "bootstrap" - that's all it knows
- ArrayList<int[]> val = new ArrayList<int[]>();
- //System.out.println("Creating tree No."+num);
- BootStrapSample(data, train, val, num);//populates train and test using data
- testN = val.size();
- correct = 0;
-
- root = CreateTree(train, num);//creating tree using training data
- CalcTreeVariableImportanceAndError(val, num);
- FlushData(root);
- }
- /**
- * Responsible for gauging the error rate of this tree and
- * calculating the importance values of each attribute
- *
- * @param val The left out data matrix
- * @param nv The Tree number
- */
- private void CalcTreeVariableImportanceAndError(ArrayList<int[]> val, int nv) {
- //calculate error rate
- correct = CalcTreeErrorRate(val, nv);//the num of correct prediction record
- CalculateClasses(val, nv);
- //calculate importance of each attribute
- for (int m=0; m<RandomForest.M; m++){
- ArrayList<int[]> test_data = RandomlyPermuteAttribute(CloneData(val), m);
- int correctAfterPermute = 0;
- for (int[] arr:test_data){
- int pre_label = Evaluate(arr);
- if (pre_label == GetClass(arr))
- correctAfterPermute++;
- }
- importances[m] += (correct - correctAfterPermute);
- }
- System.out.println("The importances of tree " + nv + " as follows");
- // for(int m=0; m<importances.length; m++){
- // System.out.println(" Attr" + m + ":" + importances[m]);
- // }
- }
-
- /**
- * Calculates the tree error rate,
- * displays the error rate to console,
- * and updates the total forest error rate
- *
- * @param val the left out test data matrix
- * @param nu The Tree number
- * @return the number correct
- */
- public int CalcTreeErrorRate(ArrayList<int[]> val, int nu){
- int correct = 0;
- for (int[] record:val){
- int pre_label = Evaluate(record);
- forest.UpdateOOBEstimate(record, pre_label);
- int actual_label = record[record.length-1];//actual_label
- if (pre_label == actual_label)
- correct++;
- }
-
- double err = 1 - correct/((double)val.size());
- // System.out.print("\n");
- System.out.println("Number of correct = " + correct + ", out of :" + val.size());
- System.out.println("Test-Data error rate of tree " + nu + " is: " + (err * 100) + " %");
- return correct;
- }
- /**
- * This method will get the classes and will return the updates
- * @param val The left out data matrix
- * @param nu The Tree number
- */
- public ArrayList<Integer> CalculateClasses(ArrayList<int[]> val, int nu){
- ArrayList<Integer> preList = new ArrayList<Integer>();
- int korect = 0;
- for(int[] record : val){
- int pre_label = Evaluate(record);
- preList.add(pre_label);
- int actual_label = record[record.length-1];
- if (pre_label==actual_label)
- korect++;
- }
- predictions = preList;
- return preList;
-
- }
- /**
- * This will classify a new data record by using tree
- * recursion and testing the relevant variable at each node.
- *
- * This is probably the most-used function in all of <b>GemIdent</b>.
- * It would make sense to inline this in assembly for optimal performance.
- *
- * @param record the data record to be classified
- * @return the class the data record was classified into
- */
- public int Evaluate(int[] record){
- TreeNode evalNode = root;
-
- while (true){
- if (evalNode.isLeaf)
- return evalNode.Class;
- if (record[evalNode.splitAttributeM] <= evalNode.splitValue)
- evalNode = evalNode.left;
- else
- evalNode = evalNode.right;
- }
- }
- /**
- * Takes a list of data records, and switches the m-th attribute across data records.
- * This is important in order to test the importance of the attribute. If the attribute
- * is randomly permuted and the result of the classification is the same, the attribute is
- * not important to the classification and vice versa.
- *
- * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">Variable Importance</a>
- * @param val The left out data matrix to be permuted
- * @param m The attribute index to be permuted
- * @return The data matrix with the m-th column randomly permuted
- */
- private ArrayList<int[]> RandomlyPermuteAttribute(ArrayList<int[]> val, int m){
- int num = val.size() * 2;
- for (int i=0; i<num; i++){
- int a = (int)Math.floor(Math.random() * val.size());
- int b = (int)Math.floor(Math.random() * val.size());
- int[] arrA = val.get(a);
- int[] arrB = val.get(b);
- int temp = arrA[m];
- arrA[m] = arrB[m];
- arrB[m] = temp;
- }
- return val;
- }
- /**
- * Creates a copy of the data matrix
- * @param data the data matrix to be copied
- * @return the cloned data matrix
- */
- private ArrayList<int[]> CloneData(ArrayList<int[]> data){
- ArrayList<int[]> clone=new ArrayList<int[]>(data.size());
- int M=data.get(0).length;
- for (int i=0;i<data.size();i++){
- int[] arr=data.get(i);
- int[] arrClone=new int[M];
- for (int j=0;j<M;j++){
- arrClone[j]=arr[j];
- }
- clone.add(arrClone);
- }
- return clone;
- }
- /**
- * This creates the decision tree according to the specifications of random forest trees.
- *
- * @see <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#overview">Overview of random forest decision trees</a>
- * @param train the training data matrix (a bootstrap sample of the original data)
- * @param ntree the tree number
- * @return the TreeNode object that stores information about the parent node of the created tree
- */
- private TreeNode CreateTree(ArrayList<int[]> train, int ntree){
- TreeNode root = new TreeNode();
- root.data = train; // public List<int[]> data;
- //System.out.println("creating ");
- RecursiveSplit(root, ntree);
- return root;
- }
- /**
- * @author DEGIS
- */
- private class TreeNode implements Cloneable{
- public boolean isLeaf;
- public TreeNode left;
- public TreeNode right;
- public int splitAttributeM;
- public Integer Class;
- public List<int[]> data;
- public int splitValue;
- public int generation;
- public ArrayList<Integer> attrArr;
-
- public TreeNode(){
- splitAttributeM=-99;
- splitValue=-99;
- generation=1;
- }
- public TreeNode clone(){ //"data" element always null in clone
- TreeNode treeCopy = new TreeNode();
- treeCopy.isLeaf = isLeaf;
- if (left != null) //otherwise null
- treeCopy.left = left.clone();
- if (right != null) //otherwise null
- treeCopy.right = right.clone();
- treeCopy.splitAttributeM = splitAttributeM;
- treeCopy.Class = Class;
- treeCopy.splitValue = splitValue;
- return treeCopy;
- }
- }
- private class DoubleWrap{
- public double d;
- public DoubleWrap(double d){
- this.d=d;
- }
- }
- /**
- * This is the crucial function in tree creation.
- *
- * <ul>
- * <li>Step A
- * Check if this node is a leaf, if so, it will mark isLeaf true
- * and mark Class with the leaf's class. The function will not
- * recurse past this point.
- * </li>
- * <li>Step B
- * Create a left and right node and keep their references in
- * this node's left and right fields. For debugging purposes,
- * the generation number is also recorded. The {@link RandomForest#Ms Ms} attributes are
- * now chosen by the {@link #GetVarsToInclude() GetVarsToInclude} function
- * </li>
- * <li>Step C
- * For all Ms variables, first {@link #SortAtAttribute(List,int) sort} the data records by that attribute,
- * then look through the values from lowest to highest.
- * If value i is not equal to value i+1, record i in the list of "indicesToCheck."
- * This speeds up the splitting. If the number of indices in indicesToCheck > MIN_SIZE_TO_CHECK_EACH
- * then we will only {@link #CheckPosition(int, int, int, DoubleWrap, TreeNode, int) check} the
- * entropy at every {@link #INDEX_SKIP INDEX_SKIP} index otherwise,
- * we {@link #CheckPosition(int, int, int, DoubleWrap, TreeNode, int) check}
- * the entropy for all. The "E" variable records the entropy and we are trying to find the minimum in which to split on
- * </li>
- * <li>Step D
- * The newly generated left and right nodes are now checked:
- * If the node has only one record, we mark it as a leaf and set its class equal to the class of the record.
- * If it has less than {@link #MIN_NODE_SIZE MIN_NODE_SIZE} records,
- * then we mark it as a leaf and set its class equal to the {@link #GetMajorityClass(List) majority class}.
- * If it has more, then we do a manual check on its data records and if all have the same class, then it
- * is marked as a leaf. If not, then we run {@link #RecursiveSplit(TreeNode, int) RecursiveSplit} on
- * that node
- * </li>
- * </ul>
- *
- * @param parent The node of the parent
- * @param Ntreenum the tree number
- */
- private void RecursiveSplit(TreeNode parent, int Ntreenum){
- //System.out.println("Recursivly spilitting tree : "+Ntreenum);
- if (!parent.isLeaf){
- //-------------------------------Step A
- //if all samples at this node belong to the same class, no split is needed;
- Integer Class = CheckIfLeaf(parent.data);
- if (Class != null){
- parent.isLeaf = true;
- parent.Class = Class;
- //System.out.println("leaf for this tree: "+Ntreenum);
- // System.out.print("parent leaf! Class:"+parent.Class+" ");
- // PrintOutClasses(parent.data);
- return;
- }
-
- //-------------------------------Step B
- int Nsub = parent.data.size();
- // PrintOutClasses(parent.data);
- ArrayList<Integer> vars = GetVarsToInclude();//randomly selects Ms' index of attributes from M
- parent.attrArr = vars;
- parent.left = new TreeNode();
- parent.left.generation = parent.generation + 1;
- parent.right = new TreeNode();
- parent.right.generation = parent.generation + 1;
-
- DoubleWrap lowestE = new DoubleWrap(Double.MIN_VALUE);//despite the name, this tracks the largest information gain found so far
-
- //if the current attribute set is empty, return the class with the most samples;
- if(parent.attrArr.size() == 0){
- parent.isLeaf = true;//mark as a leaf so evaluation stops here
- parent.Class = GetMajorityClass(parent.data);
- return;
- }
- //-------------------------------Step C
- //if the samples take the same values on all attributes, the node cannot be split; return the class with the most samples;
- int sameClass = 0;
- for (int m:parent.attrArr){
- SortAtAttribute(parent.data, m);//sorts on a particular column in the row
- ArrayList<Integer> indicesToCheck = new ArrayList<Integer>();
- for (int n=1; n<Nsub; n++){
- int classA = GetClass(parent.data.get(n-1));
- int classB = GetClass(parent.data.get(n));
- if (classA != classB)
- indicesToCheck.add(n);
- }
- //no class boundary exists along this attribute, so it cannot produce a split;
- if (indicesToCheck.size() == 0)
- sameClass++;
- }
- if(sameClass == parent.attrArr.size()){
- parent.isLeaf = true;
- parent.Class = GetMajorityClass(parent.data);
- return;
- }
-
- for (int m:parent.attrArr){
- SortAtAttribute(parent.data, m);//sorts on a particular column in the row
- ArrayList<Integer> indicesToCheck = new ArrayList<Integer>();
- for (int n=1; n<Nsub; n++){
- int classA = GetClass(parent.data.get(n-1));
- int classB = GetClass(parent.data.get(n));
- if (classA != classB)
- indicesToCheck.add(n);
- }
- // System.out.print("indices to check for tree : "+Ntreenum);
- // for (int n:indicesToCheck)
- // System.out.print(" "+n);
- // System.out.print("\n ");
- if (indicesToCheck.size() > MIN_SIZE_TO_CHECK_EACH){
- for (int i=0; i<indicesToCheck.size(); i+=INDEX_SKIP){
- //System.out.println("Checking positions for index : "+i+" and tree :"+Ntreenum);
- CheckPosition(m, indicesToCheck.get(i), Nsub, lowestE, parent, Ntreenum);
- if (lowestE.d == 0)//lowestE now has the minimum conditional entropy so IG is max there
- break;
- }
- }
- else {
- for (int index:indicesToCheck){
- CheckPosition(m, index, Nsub, lowestE, parent, Ntreenum);
- if (lowestE.d == 0)
- break;
- }
- }
- // BufferedReader reader=new BufferedReader(new InputStreamReader(System.in));
- // System.out.println("************************* lowest e:"+lowestE.d);
- // try {reader.readLine();} catch (IOException e){}
- if (lowestE.d == 0)
- break;
- }
- //remove the chosen split attribute from the attribute set
- Iterator<Integer> it = parent.attrArr.iterator();
- while(it.hasNext()){
- int attr = it.next();
- if (attr == parent.splitAttributeM){
- it.remove();
- }
- }
- parent.left.attrArr = parent.attrArr;
- parent.right.attrArr = parent.attrArr;
- // System.out.print("\n");
- // System.out.print("split attrubute num:"+parent.splitAttributeM+" at val:"+parent.splitValue+" n:"+parent.data.size()+" ");
- // PrintOutClasses(parent.data);
- // System.out.println("\nmadeSplit . . .");
- // PrintOutNode(parent," ");
- // PrintOutNode(parent.left," ");
- // PrintOutNode(parent.right," ");
-
- //-------------------------------Step D
- //------------Left Child
- if (parent.left.data.size() == 1){//only a single record remains
- parent.left.isLeaf = true;
- parent.left.Class = GetClass(parent.left.data.get(0));
- }
- else if (parent.left.attrArr.size() == 0){//the attribute set is empty
- parent.left.isLeaf = true;
- parent.left.Class = GetMajorityClass(parent.left.data);
- }
- // else if (parent.left.data.size() < MIN_NODE_SIZE){
- // parent.left.isLeaf = true;
- // parent.left.Class = GetMajorityClass(parent.left.data);
- // }
- else {
- Class = CheckIfLeaf(parent.left.data);
- if (Class == null){
- parent.left.isLeaf = false;
- parent.left.Class = null;
- // System.out.println("make branch left: m:"+m);
- }
- else {//all samples in this child belong to the same class
- parent.left.isLeaf = true;
- parent.left.Class = Class;
- }
- }
- //------------Right Child
- if (parent.right.data.size() == 1){//only a single record remains
- parent.right.isLeaf = true;
- parent.right.Class = GetClass(parent.right.data.get(0));
- }
- else if (parent.right.attrArr.size() == 0){//the attribute set is empty
- parent.right.isLeaf = true;
- parent.right.Class = GetMajorityClass(parent.right.data);
- }
- // else if (parent.left.data.size() < MIN_NODE_SIZE){
- // parent.left.isLeaf = true;
- // parent.left.Class = GetMajorityClass(parent.left.data);
- // }
- else {
- Class = CheckIfLeaf(parent.right.data);
- if (Class == null){
- parent.right.isLeaf = false;
- parent.right.Class = null;
- // System.out.println("make branch right: m:"+m);
- }
- else {//all samples in this child belong to the same class
- parent.right.isLeaf = true;
- parent.right.Class = Class;
- }
- }
-
-
- if (!parent.left.isLeaf){
- RecursiveSplit(parent.left, Ntreenum);
- }
-
- // else {
- // System.out.print("left leaf! Class:"+parent.left.Class+" ");
- // PrintOutClasses(parent.left.data);
- // }
- if (!parent.right.isLeaf)
- RecursiveSplit(parent.right, Ntreenum);
- // else {
- // System.out.print("leaf right! Class:"+parent.right.Class+" ");
- // PrintOutClasses(parent.right.data);
- // }
- }
- }
- /**
- * Given a data matrix, return the most popular Y value (the class)
- * @param data The data matrix
- * @return The most popular class
- */
- private int GetMajorityClass(List<int[]> data){
- int[] counts=new int[RandomForest.C];
- for (int[] record:data){
- int Class=record[record.length-1];//GetClass(record);
- counts[Class-1]++;
- }
- int index=-99;
- int max=Integer.MIN_VALUE;
- for (int i=0;i<counts.length;i++){
- if (counts[i] > max){
- max=counts[i];
- index=i+1;
- }
- }
- return index;
- }
-
- /**
- * Checks the {@link #CalcEntropy(double[]) entropy} of an index in a data matrix at a particular attribute (m)
- * and returns the entropy. If the entropy is lower than the minimum to date (lowestE), it is set to the minimum.
- *
- * The total entropy is calculated by getting the sub-entropy for below the split point and upper the split point.
- * The sub-entropy is calculated by first getting the {@link #GetClassProbs(List) proportion} of each of the classes
- * in this sub-data matrix. Then the entropy is {@link #CalcEntropy(double[]) calculated}. The lower sub-entropy
- * and upper sub-entropy are then weight averaged to obtain the total entropy.
- *
- * @param m the attribute to split on
- * @param n the index to check(rowID)
- * @param Nsub the num of records in the data matrix
- * @param lowestE the minimum entropy to date
- * @param parent the parent node
- * @return the entropy of this split
- */
- private double CheckPosition(int m, int n, int Nsub, DoubleWrap lowestE, TreeNode parent, int nTre){
- // var, index, train.size, lowest number, for a tree
- //System.out.println("Checking position of the index attribute of tree :"+nTre);
- if (n < 1) //exit conditions
- return 0;
- if (n > Nsub)
- return 0;
-
- List<int[]> lower = GetLower(parent.data, n);
- List<int[]> upper = GetUpper(parent.data, n);
- if (lower == null)
- System.out.println("lower list null");
- if (upper == null)
- System.out.println("upper list null");
- double[] p = GetClassProbs(parent.data);
- double[] pl = GetClassProbs(lower);
- double[] pu = GetClassProbs(upper);
- double eP = CalcEntropy(p);
- double eL = CalcEntropy(pl);
- double eU = CalcEntropy(pu);
-
- double e = eP - eL * lower.size()/(double)Nsub - eU * upper.size()/(double)Nsub;
- // System.out.println("g:"+parent.generation+" N:"+parent.data.size()+" M:"+RandomForest.M+" Ms:"+RandomForest.Ms+" n:"+n+" m:"+m+" val:"+parent.data.get(n)[m]+" e:"+e);
- // out.write(m+","+n+","+parent.data.get(n)[m]+","+e+"\n");
- if (e > lowestE.d){
- lowestE.d = e;
- // System.out.print("-");
- parent.splitAttributeM = m;
- parent.splitValue = parent.data.get(n)[m];
- parent.left.data = lower;
- parent.right.data = upper;
-
- }
- return e;//entropy
- }
- /**
- * Given a data record, return the Y value - take the last index
- *
- * @param record the data record
- * @return its y value (class)
- */
- public static int GetClass(int[] record){
- return record[RandomForest.M];
- }
- /**
- * Given a data matrix, check if all the y values are the same. If so,
- * return that y value, null if not
- *
- * @param data the data matrix
- * @return the common class (null if not common)
- */
- private Integer CheckIfLeaf(List<int[]> data){
- // System.out.println("checkIfLeaf");
- boolean isLeaf = true;
- int ClassA = GetClass(data.get(0));
- for (int i=1; i<data.size(); i++){
- int[] recordB = data.get(i);
- if (ClassA != GetClass(recordB)){
- isLeaf = false;
- break;
- }
- }
- if (isLeaf)
- return GetClass(data.get(0));
- else
- return null;
- }
- /**
- * Split a data matrix and return the upper portion
- *
- * @param data the data matrix to be split
- * @param nSplit index in a sub-data matrix that we will return all data records above it
- * @return the upper sub-data matrix
- */
- private List<int[]> GetUpper(List<int[]> data, int nSplit){
- int N = data.size();
- List<int[]> upper = new ArrayList<int[]>(N-nSplit);
- for (int n=nSplit; n<N; n++)
- upper.add(data.get(n));
- return upper;
- }
- /**
- * Split a data matrix and return the lower portion
- *
- * @param data the data matrix to be split
- * @param nSplit this index in a sub-data matrix that return all data records below it
- * @return the lower sub-data matrix
- */
- private List<int[]> GetLower(List<int[]> data, int nSplit){
- List<int[]> lower = new ArrayList<int[]>(nSplit);
- for (int n=0; n<nSplit; n++)
- lower.add(data.get(n));
- return lower;
- }
- /**
- * This class compares two data records by numerically comparing a specified attribute
- *
- * @author kapelner
- *
- */
- private class AttributeComparator implements Comparator{
- /** the specified attribute */
- private int m;
- /**
- * Create a new comparator
- * @param m the attribute in which to compare on
- */
- public AttributeComparator(int m){
- this.m = m;
- }
- /**
- * Compare the two data records. They must be of type int[].
- *
- * @param o1 data record A
- * @param o2 data record B
- * @return -1 if A[m] < B[m], 1 if A[m] > B[m], 0 if equal
- */
- public int compare(Object o1, Object o2){
- int a = ((int[])o1)[m];
- int b = ((int[])o2)[m];
- if (a < b)
- return -1;
- if (a > b)
- return 1;
- else
- return 0;
- }
- }
- /**
- * Sorts a data matrix by an attribute from lowest record to highest record
- * @param data the data matrix to be sorted
- * @param m the attribute to sort on
- */
- @SuppressWarnings("unchecked")
- private void SortAtAttribute(List<int[]> data, int m){
- Collections.sort(data, new AttributeComparator(m));
- }
- /**
- * Given a data matrix, return a probability mass function representing
- * the frequencies of a class in the matrix (the y values)
- *
- * @param records the data matrix to be examined
- * @return the probability mass function
- */
- private double[] GetClassProbs(List<int[]> records){
-
- double N = records.size();
-
- int[] counts = new int[RandomForest.C];//the num of target class
- // System.out.println("counts:");
- // for (int i:counts)
- // System.out.println(i+" ");
-
- for (int[] record:records)
- counts[GetClass(record)-1]++;
-
- double[] ps = new double[RandomForest.C];
- for (int j=0; j<RandomForest.C; j++)
- ps[j] = counts[j]/N;
- // System.out.print("probs:");
- // for (double p:ps)
- // System.out.print(" "+p);
- // System.out.print("\n");
- return ps;
- }
- /** ln(2) */
- private static final double logoftwo = Math.log(2);
- /**
- * Given a probability mass function indicating the frequencies of
- * class representation, calculate an "entropy" value using the method
- * in Tan Steinbach Kumar's "Data Mining" textbook
- *
- * @param ps the probability mass function
- * @return the entropy value calculated
- */
- private double CalcEntropy(double[] ps){
- double e = 0;
- for (double p:ps){
- if (p != 0) //otherwise it will divide by zero - see TSK p159
- e += p * Math.log(p)/Math.log(2);
- }
- return -e; //according to TSK p158
- }
- /**
- * Of the M attributes, select {@link RandomForest#Ms Ms} at random.
- *
- * @return The list of the Ms attributes' indices
- */
- private ArrayList<Integer> GetVarsToInclude() {
- boolean[] whichVarsToInclude = new boolean[RandomForest.M];
-
- for (int i=0; i<RandomForest.M; i++)//initialize all flags to false
- whichVarsToInclude[i]=false;
-
- while (true){
- int a = (int)Math.floor(Math.random() * RandomForest.M);//Math.random() is in [0,1), so a is uniform over 0..M-1
- whichVarsToInclude[a] = true;
- int N = 0;
- for (int i=0; i<RandomForest.M; i++)
- if (whichVarsToInclude[i])
- N++;
- if (N == RandomForest.Ms)
- break;
- }
-
- ArrayList<Integer> shortRecord = new ArrayList<Integer>(RandomForest.Ms);
-
- for (int i=0; i<RandomForest.M; i++)
- if (whichVarsToInclude[i])
- shortRecord.add(i);
- return shortRecord;
- }
-
- /**
- * Create a bootstrap sample of a data matrix
- * @param data the data matrix to be sampled
- * @param train the bootstrap sample
- * @param val the records that are absent in the bootstrap sample
- * @param numb the tree number
- */
- private void BootStrapSample(ArrayList<int[]> data, ArrayList<int[]> train, ArrayList<int[]> val, int numb){
- ArrayList<Integer> indices = new ArrayList<Integer>(N);
- for (int n=0; n<N; n++)
- indices.add((int)Math.floor(Math.random() * N));
- ArrayList<Boolean> IsIn = new ArrayList<Boolean>(N);
- for (int n=0; n<N; n++)
- IsIn.add(false); //initialize it first
- for (int index:indices){
- train.add((data.get(index)).clone());//train has duplicated record
- IsIn.set(index, true);
- }
- for (int i=0; i<N; i++)
- if (!IsIn.get(i))
- val.add((data.get(i)).clone());
- //System.out.println("created testing-data for tree : "+numb);//everywhere its set to false we get those to test data
-
- // System.out.println("bootstrap N:"+N+" size of bootstrap:"+bootstrap.size());
- }
- /**
- * Recursively deletes all data records from the tree. This is run after the tree
- * has been computed and can stand alone to classify incoming data.
- *
- * @param node initially, the root node of the tree
- */
- private void FlushData(TreeNode node){
- node.data = null;
- if (node.left != null)
- FlushData(node.left);
- if (node.right != null)
- FlushData(node.right);
- }
-
- // // possible to clone trees
- // private DTree(){}
- // public DTree clone(){
- // DTree copy=new DTree();
- // copy.root=root.clone();
- // return copy;
- // }
-
- /**
- * Get the number of data records in the test data matrix that were classified correctly
- */
- public int getNumCorrect(){
- return correct;
- }
- /**
- * Get the number of data records left out of the bootstrap sample
- */
- public int getTotalNumInTestSet(){
- return testN;
- }
- /**
- * Get the importance level of attribute m for this tree
- */
- public int getImportanceLevel(int m){
- return importances[m];
- }
- // private void PrintOutNode(TreeNode parent,String init){
- // try {
- // System.out.println(init+"node: left"+parent.left.toString());
- // } catch (Exception e){
- // System.out.println(init+"node: left null");
- // }
- // try {
- // System.out.println(init+" right:"+parent.right.toString());
- // } catch (Exception e){
- // System.out.println(init+"node: right null");
- // }
- // try {
- // System.out.println(init+" isleaf:"+parent.isLeaf);
- // } catch (Exception e){}
- // try {
- // System.out.println(init+" splitAtrr:"+parent.splitAttributeM);
- // } catch (Exception e){}
- // try {
- // System.out.println(init+" splitval:"+parent.splitValue);
- // } catch (Exception e){}
- // try {
- // System.out.println(init+" class:"+parent.Class);
- // } catch (Exception e){}
- // try {
- // System.out.println(init+" data size:"+parent.data.size());
- // PrintOutClasses(parent.data);
- // } catch (Exception e){
- // System.out.println(init+" data: null");
- // }
- // }
- // private void PrintOutClasses(List<int[]> data){
- // try {
- // System.out.print(" (n="+data.size()+") ");
- // for (int[] record:data)
- // System.out.print(GetClass(record));
- // System.out.print("\n");
- // }
- // catch (Exception e){
- // System.out.println("PrintOutClasses: data null");
- // }
- // }
- // public static void PrintBoolArray(boolean[] b) {
- // System.out.print("vars to include: ");
- // for (int i=0;i<b.length;i++)
- // if (b[i])
- // System.out.print(i+" ");
- // System.out.print("\n\n");
- // }
- //
- // public static void PrintIntArray(List<int[]> lower) {
- // System.out.println("tree");
- // for (int i=0;i<lower.size();i++){
- // int[] record=lower.get(i);
- // for (int j=0;j<record.length;j++){
- // System.out.print(record[j]+" ");
- // }
- // System.out.print("\n");
- // }
- // System.out.print("\n");
- // System.out.print("\n");
- // }
- }
DescribeTrees.java
- package com.rf.real;
-
- import java.io.BufferedReader;
- import java.io.FileReader;
- import java.io.IOException;
- import java.util.ArrayList;
-
- public class DescribeTrees {
- //reads the txt file and turns it into input for the random forest
- BufferedReader br = null;
- String path;
- public DescribeTrees(String path){
- this.path = path;
- }
-
- public ArrayList<int[]> CreateInput(String path){
- ArrayList<int[]> DataInput = new ArrayList<int[]>();
- try {
- String sCurrentLine;
- br = new BufferedReader(new FileReader(path));
-
- while ((sCurrentLine = br.readLine()) != null) {
- ArrayList<Integer> spaceIndex = new ArrayList<Integer>();//indices of the whitespace characters
- int i;
- if(sCurrentLine != null){
- sCurrentLine = " " + sCurrentLine + " ";
- for(i=0; i < sCurrentLine.length(); i++){
- if(Character.isWhitespace(sCurrentLine.charAt(i)))
- spaceIndex.add(i);
- }
- int[] DataPoint = new int[spaceIndex.size()-1];
- for(i=0; i<spaceIndex.size()-1; i++){
- DataPoint[i]=Integer.parseInt(sCurrentLine.substring(spaceIndex.get(i)+1, spaceIndex.get(i+1)));
- }
- /* print DataPoint
- for(k=0; k<DataPoint.length; k++){
- //System.out.print("-");
- System.out.print(DataPoint[k]);
- System.out.print(" ");
- }
- **/
- DataInput.add(DataPoint);
- }
- }
-
- } catch (IOException e) {
- e.printStackTrace();
- } finally {
- try {
- if (br != null)
- br.close();
- } catch (IOException ex) {
- ex.printStackTrace();
- }
- }
- return DataInput;
- }
- }
MainRun.java
- package com.rf.real;
-
- import java.util.ArrayList;
-
- public class MainRun {
- @SuppressWarnings("static-access")
- public static void main(String args[]){
-
- String trainPath = "C:\\Users\\kwok\\Desktop\\rf_data\\Data.txt";
- String testPath = "C:\\Users\\kwok\\Desktop\\rf_data\\Test.txt";
- int numTrees = 100;
-
- DescribeTrees DT = new DescribeTrees(trainPath);
- ArrayList<int[]> Train = DT.CreateInput(trainPath);
-
- DescribeTrees DT2 = new DescribeTrees(testPath);
- ArrayList<int[]> Test = DT2.CreateInput(testPath);
- int categ = 0;
-
- //the num of labels
- int trainLength = Train.get(0).length;
- for(int k=0; k<Train.size(); k++){
- if(Train.get(k)[trainLength-1] < categ)
- continue;
- else{
- categ = Train.get(k)[trainLength-1];
- }
- }
-
- RandomForest Rf = new RandomForest(numTrees, Train, Test);
- Rf.C = categ;//the num of labels
- Rf.M = Train.get(0).length-1;//the num of Attr
- //attribute perturbation: each split randomly selects Ms of the M attributes, Ms ≈ log2(M)
- Rf.Ms = (int)Math.round(Math.log(Rf.M)/Math.log(2) + 1);//the number of randomly selected attributes
- Rf.Start();
- }
- }
RandomForest.java
- package com.rf.real;
-
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.concurrent.ExecutorService;
- import java.util.concurrent.Executors;
- import java.util.concurrent.TimeUnit;
- /**
- * Random Forest
- */
- public class RandomForest {
-
- /** number of available threads */
- private static final int NUM_THREADS = Runtime.getRuntime().availableProcessors();
- /** number of target classes */
- public static int C;
- /** number of attributes (columns) */
- public static int M;
- /** attribute perturbation: each split randomly selects Ms of the M attributes, Ms = log2(M) */
- public static int Ms;
- /** the collection of the forest's decision trees */
- private ArrayList<DTree> trees;
- /** start time */
- private long time_o;
- /** the number of trees in this random forest */
- private int numTrees;
- /** progress increment per finished tree, for live progress display */
- private double update;
- /** initial progress value, for live progress display */
- private double progress;
- /** importance Array */
- private int[] importances;
- /** key = a record from the data matrix; value = the RF's vote counts for that record */
- private HashMap<int[],int[]> estimateOOB;
- /** all of the predictions from RF */
- private ArrayList<ArrayList<Integer>> Prediction;
- /** the RF's error rate */
- private double error;
- /** thread pool that controls tree growth */
- private ExecutorService treePool;
- /** the original training data */
- private ArrayList<int[]> train_data;
- /** the test data */
- private ArrayList<int[]> testdata;
- /**
- * Initializes a Random forest
- * @param numTrees the number of trees in the RF
- * @param train_data the original training data
- * @param t_data the test data
- */
- public RandomForest(int numTrees, ArrayList<int[]> train_data, ArrayList<int[]> t_data ){
- this.numTrees = numTrees;
- this.train_data = train_data;
- this.testdata = t_data;
- trees = new ArrayList<DTree>(numTrees);
- update = 100 / ((double)numTrees);
- progress = 0;
- StartTimer();
- System.out.println("creating "+numTrees+" trees in a random Forest. . .");
- System.out.println("total data size is "+train_data.size());
- System.out.println("number of attributes " + (train_data.get(0).length-1));
- System.out.println("number of selected attributes " + ((int)Math.round(Math.log(train_data.get(0).length-1)/Math.log(2) + 1)));
- estimateOOB = new HashMap<int[],int[]>(train_data.size());
- Prediction = new ArrayList<ArrayList<Integer>>();
- }
- /**
- * Begins the creation of random forest
- */
- public void Start() {
- System.out.println("Num of threads started : " + NUM_THREADS);
- System.out.println("Running...");
- treePool = Executors.newFixedThreadPool(NUM_THREADS);
- for (int t=0; t < numTrees; t++){
- System.out.println("structing " + t + " Tree");
- treePool.execute(new CreateTree(train_data,this,t+1));
- //System.out.print(".");
- }
- treePool.shutdown();
- try {
- treePool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); //effectively infinity
- } catch (InterruptedException ignored){
- System.out.println("interrupted exception in Random Forests");
- }
- System.out.println("");
- System.out.println("Finished tree construction");
- TestForest(trees, testdata);
- CalcImportances();
- System.out.println("Done in "+TimeElapsed(time_o));
- }
-
- /**
- * @param collec_tree the collection of the forest's decision trees
- * @param test_data the test dataset
- */
- private void TestForest(ArrayList<DTree> collec_tree, ArrayList<int[]> test_data ) {
- int correctness = 0;
- ArrayList<Integer> actualLabel = new ArrayList<Integer>();
- for(int[] rec:test_data){
- actualLabel.add(rec[rec.length-1]);
- }
- int treeNumber = 1;
- for(DTree dt:collec_tree){
- dt.CalculateClasses(test_data, treeNumber);
- Prediction.add(dt.predictions);
- treeNumber++;
- }
- for(int i = 0; i<test_data.size(); i++){
- ArrayList<Integer> Val = new ArrayList<Integer>();
- for(int j =0; j<collec_tree.size(); j++){
- Val.add(Prediction.get(j).get(i));//The collection of each Tree's prediction in i-th record
- }
- int pred = labelVote(Val);//Voting algorithm
- if(pred == actualLabel.get(i)){
- correctness++;
- }
- }
- System.out.println("Accuracy of Forest is : " + (100 * correstness / test_data.size()) + "%");
- }
-
- /**
- * Voting algorithm
- * @param treePredict The collection of each Tree's prediction in i-th record
- */
- private int labelVote(ArrayList<Integer> treePredict){
- int max=0, maxclass=-1;
- for(int i=0; i<treePredict.size(); i++){
- int count = 0;
- for(int j=0; j<treePredict.size(); j++){
- if(treePredict.get(j).equals(treePredict.get(i))){//use equals(): == on Integer compares object references
- count++;
- }
- if(count > max){
- maxclass = treePredict.get(i);
- max = count;
- }
- }
- }
- return maxclass;
- }
- /**
- * Calculates the RF's classification error rate
- */
- private void CalcErrorRate(){
- double N=0;
- int correct=0;
- for (int[] record:estimateOOB.keySet()){
- N++;
- int[] map=estimateOOB.get(record);
- int Class=FindMaxIndex(map)+1;//counts are indexed by class-1, so shift back to 1-based labels
- if (Class == DTree.GetClass(record))
- correct++;
- }
- error=1-correct/N;
- System.out.println("correctly mapped "+correct);
- System.out.println("Forest error rate % is: "+(error*100));
- }
- /**
- * 更新 OOBEstimate
- * @param record a record from data matrix
- * @param Class
- */
- public void UpdateOOBEstimate(int[] record, int Class){
- if (estimateOOB.get(record) == null){
- int[] map = new int[C];
- map[Class-1]++;//count this first prediction for the record
- estimateOOB.put(record,map);
- }
- else {
- int[] map = estimateOOB.get(record);
- map[Class-1]++;
- }
- }
- /**
- * calculates the importance levels for all attributes.
- */
- private void CalcImportances() {
- importances = new int[M];
- for (DTree tree:trees){
- for (int i=0; i<M; i++)
- importances[i] += tree.getImportanceLevel(i);
- }
- for (int i=0;i<M;i++)
- importances[i] /= numTrees;
- System.out.println("The forest-wide importance as follows:");
- for (int j=0; j<importances.length; j++){
- System.out.println("Attr" + j + ":" + importances[j]);
- }
- }
- /** start the timer */
- private void StartTimer(){
- time_o = System.currentTimeMillis();
- }
- /**
- * Creates a single decision tree
- */
- private class CreateTree implements Runnable{
- /** training data */
- private ArrayList<int[]> train_data;
- /** the random forest this tree belongs to */
- private RandomForest forest;
- /** the tree number */
- private int treenum;
-
- public CreateTree(ArrayList<int[]> train_data, RandomForest forest, int num){
- this.train_data = train_data;
- this.forest = forest;
- this.treenum = num;
- }
- /**
- * Create the decision tree
- */
- public void run() {
- System.out.println("Creating a Dtree num : " + treenum + " ");
- trees.add(new DTree(train_data, forest, treenum));
- //System.out.println("tree added in RandomForest.AddTree.run()");
- progress += update;
- System.out.println("---progress:" + progress);
- }
- }
-
- /**
- * Evaluates testdata
- * @param record a record to be evaluated
- */
- public int Evaluate(int[] record){
- int[] counts=new int[C];
- for (int t=0;t<numTrees;t++){
- int Class=(trees.get(t)).Evaluate(record);
- counts[Class-1]++;//class labels are 1-based, counts are 0-based
- }
- return FindMaxIndex(counts)+1;//shift back to 1-based class labels
- }
-
- public static int FindMaxIndex(int[] arr){
- int index=0;
- int max = Integer.MIN_VALUE;
- for (int i=0;i<arr.length;i++){
- if (arr[i] > max){
- max=arr[i];
- index=i;
- }
- }
- return index;
- }
-
-
- /**
- * @param timeinms the start time in milliseconds
- * @return the hr,min,s
- */
- private static String TimeElapsed(long timeinms){
- int s=(int)(System.currentTimeMillis()-timeinms)/1000;
- int h=(int)Math.floor(s/((double)3600));
- s-=(h*3600);
- int m=(int)Math.floor(s/((double)60));
- s-=(m*60);
- return ""+h+"hr "+m+"m "+s+"s";
- }
- }
Data.txt
https://github.com/Edgis/Machine-Learning-Algorithm/blob/master/randomForest/Data.txt
Test.txt
https://github.com/Edgis/Machine-Learning-Algorithm/blob/master/randomForest/Test.txt
Data.txt is the training dataset and Test.txt is the test dataset. In offline tests, a 100-tree RF model built on the training data above reaches a forest-wide accuracy of about 90% on the Test set, and Attr16 and Attr17 turn out to have comparatively high Variable Importance.