当前位置:   article > 正文

决策树ID3算法及实现_手动模拟id3算法来实现决策过程

手动模拟id3算法来实现决策过程

0. 信息论

信道模型和信息的含义

信息论是关于信息的本质和传输规律的理论。
信道模型:信源(发送端)-> 信道 -> 信宿(接收端)
1. 通信过程是在随机干扰的环境汇中传递信息的过程
2. 信宿对于信源的先验不确定性:在通信前,信宿不能确切的了解信源的状态;
3. 信宿对于信源的后验不确定性:在通信后,由于存在干扰,信宿对于接收到的信息仍然具有不确定性
4. 后验不确定性总是要小于先验不确定性的。
信息:是消除不确定性的度量。
信息量的大小:由所消除的不确定性的大小来计量。

信息的定量描述

直观理解:
若消息发生的概率很大,受信者事先已经有所估计,则该消息的信息量就很小。
若消息发生的概率很小,受信者感觉到很突然,该消息所含有的信息量就很大。
所以信息量和概率联系在了一起,信息量可以表示为概率的函数。那么怎样的函数可以用来描述信息量呢?函数f(p)应该满足以下条件:
1. f(p)应该是概率p的严格单调递减函数,
2. 当p=1时,f(p)=0
3. 当p=0时,f(p)=
4. 两个独立事件的联合信息量应该等于它们信息量之和。
以下是f(p)=log(p)的图像,满足以上的所有的要求。


自信息和熵的定义

若一个消息x出现的概率为p,那么这个消息所含有的信息量为

I=log(p)

上式称为消息 x的自信息,自信息有两种含义:
1. 当该消息发生之前,表示发生该消息的不确定性,
2. 当该消息发生之后,表示消息所含有的信息量。
信源含有的信息量是信源发出的所有可能消息的平均不确定性,香农把信源所含有的信息量称为是指每个符号所含有的信息量(自信息)的统计平均。若X是一个离散随机变量,概率分布为p(x)=P(X=x)xX,那么 X的熵为
H(X)=iNp(xi)I(xi)=iNp(xi)logp(xi)

一个随机变量的熵越大,其不确定性就越大,(不管是先验熵,后验熵,还是条件熵都是这样的)正确的估计其值的可能性就越小,越是不确定的随机变量越是需要更大的信息量来确定其值。
结合信道模型, H(X)是信源发出前的平均不确定性,是 先验熵。其中, p(xi)越接近, H(X)越大。 p(xi)相差越大, H(X)越小。在事件 yj出现的条件下,随机事件 xi发生的条件概率为 p(xi|yj),定义它的 条件自信息量为条件概率对数的负值。如下:
I(xi|yj)=logp(xi|yj)

在给定 yj的条件下,( xi的条件自信息量为 Ixi|yj),此时关于 X的不确定性定义为后验熵,(接收到一个输出信号后对于信源的不确定性)。如下:
H(X|yj)=iNp(xi|yj)I(xi|yj)=iNp(xi|yj)log(p(xi|yj))

在给定 Y(即各个yj)的条件下, X集合的条件熵(接收到了所有的输出信号后对于信源的不确定性)为H(X|Y)
H(X|Y)=jp(yj)H(X|yj)=jp(yj)ip(xi|yj)logp(xi|yj)

条件熵表示知道Y之后, 对X的不确定性。(知道了天气状态之后是否要出去活动的不确定性,其值越小,不确定性越小,说明天气情况带来的信息越大)。在通信后总能消除一定的关于信源端的不确定性,所以存在关系: H(U|V)<H(U),那么定义互信息
I(X,Y)=H(X)H(X|Y)
表示在接收到 Y后获得的关于X的信息量。


例子

已知,垒球活动进行和取消的概率分别为914514
那么是否进行活动的熵的计算方法如下:(先验熵)

H()=914log914514log514=0.94

又,已知天气情况对活动进行的影响如下:

活动活动进行活动取消
晴天2/53/5
阴天10
雨天3/52/5

计算已知户外的天气情况下活动的条件熵
(总的步骤是计算先验熵,在计算后验熵,在计算条件熵。现在先验熵已知)
计算后验墒:分别计算晴天对于活动的后验熵阴天对于活动的后验熵雨天对于活动的后验熵如下。

H(|)=25log2535log35=0.971

H(|)=1log10log0=0

H(|)=35log3525log25=0.971

又已知天气的状况为 p()=514p()=414p()=514
所以已知户外的天气的时候,活动的 条件熵为:
H(|)=514H(|)+414H(|)+514H(|)=0.693

平均互信息为
I()=H()H(|)=0.246


1. ID3算法

引入了信息论中的互信息(信息增益)作为选择判别因素的度量,即:以信息增益的下降速度作为选取分类属性的标准,所选的测试属性是从根节点到当前节点的路径上从没有被考虑过的具有最高的信息增益的属性。这就需要计算各个属性的信息增益的值,找出最大的作为判别的属性:
1. 计算先验熵,没有接收到其他的属性值时的平均不确定性,
2. 计算后验墒,在接收到输出符号yi时关于信源的不确定性,
3. 条件熵,对后验熵在输出符号集Y中求期望,接收到全部的付好后对信源的不确定性,
4. 互信息,先验熵和条件熵的差,

实例

是否适合打垒球的决策表如下

天气温度湿度风速活动
炎热取消
炎热取消
炎热进行
适中进行
寒冷正常进行
寒冷正常取消
寒冷正常进行
适中取消
寒冷正常进行
适中正常进行
适中正常进行
适中进行
炎热正常进行
适中取消

1.计算先验熵:在没有接收到其他的任何的属性值时候,活动进行与否的熵根据下表进行计算。
这里写图片描述

H()=914log914514log514=0.94

2.分别将各个属性作为决策属性时的条件熵(先计算后验墒,在计算条件熵)

(1) 计算已知天气情况下活动是否进行的条件熵(已知天气情况下对于活动的不确定性)
这里写图片描述
先计算后验墒

H(|)=P(|)logP(|)P(|)logP(|)=0.971

H(|)=P(|)logP(|)P(|)logP(|)=0

H(|)=P(|)logP(|)P(|)logP(|)=0.971

再计算 条件熵:(知道了Y之后,对X的不确定性:知道了天气之后,对活动的不确定性,越小是越好的)
H|=5/14H|+4/14H|+5/14H|=0.693

(2)计算已知 温度情况时对活动的条件熵(不确定性)
这里写图片描述
H(|)=0.911

(3)已知 湿度情况下对于活动是否进行的条件熵(不确定性)
这里写图片描述
H(|湿)=0.789

(4)已知 风速情况下对于活动是否进行的条件熵(不确定性)
这里写图片描述
H(|)=0.892

3.计算信息增益
I()=H()H(|)=0.940.693=0.246

I()=H()H(|)=0.940.911=0.029

I(湿)=H()H(|湿)=0.940.789=0.151

I()=H()H(|)=0.940.892=0.048

所以选择天气作为第一个判别因素
这里写图片描述
在选择了天气作为第一个判别因素之后,我们很容易看出(计算的方法和上面提到的一样),针对上图的中间的三张子表来说,第一张子表在选择湿度作为划分数据的feature的时候,分类问题可以完全解决:湿度正常的情况下进行活动,湿度高的时候取消(在天气状态为晴的条件下);第二个子表不需要划分,即,天气晴的情况下不管其他的因素是什么,活动都要进行;第三张子表当选择风速作为划分的feature时,分类问题也完全解决:风速弱的时候进行,风速强的时候取消(在天气状况为雨的条件下)。

2. java实现

2.1 计算给定数据集的香农熵

ID3算法实现中,训练数据和测试数据都是用ArrayList<ArrayList<String>> 存放,每一个子ArrayList是一个sample(feature+label)。即,data中的一列是一个属性,一行是一个样本。
uniqueLabels用来统计不同的label出现的个数。

public double calculateShannonEntropy(ArrayList<ArrayList<String>> data) {
        double shannon = 0.0;

        int length = data.get(0).size(); // length-1就是label的index
        HashMap<String, Integer> uniqueLabels = new HashMap<>();
        for (int i = 0; i < data.size(); i++) {
            if (uniqueLabels.containsKey(data.get(i).get(length - 1))) {
                uniqueLabels.replace(data.get(i).get(length - 1), uniqueLabels.get(data.get(i).get(length - 1)) + 1);
            } else {
                uniqueLabels.put(data.get(i).get(length - 1), 1);
            }
        }
        for (String one : uniqueLabels.keySet()) {
            shannon += -(((double) (uniqueLabels.get(one)) / (data.size()))
                    * Math.log((double) (uniqueLabels.get(one)) / (data.size())) / Math.log(2));
        }
        return shannon;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

2.2 按照给定的feature的取值划分数据集

三个参数(data, index, value)的含义: 将data中第index列上值为value的样本返回,并且在返回的结果中样本不包括index列的特征

public ArrayList<ArrayList<String>> splitDataSetByFeature(ArrayList<ArrayList<String>> data, int index,
            String value) {
        ArrayList<ArrayList<String>> subData = new ArrayList<>();
        for (int i = 0; i < data.size(); i++) {
            ArrayList<String> newSample = new ArrayList<>();
            if (data.get(i).get(index).equals(value)) {
                for (int j = 0; j < data.get(i).size(); j++) {
                    if (j != index) {
                        newSample.add(data.get(i).get(j));
                    }
                }
                subData.add(newSample);
            }
        }
        return subData;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

2.3 选择最好的数据集划分方式(选择最好的feature来划分数据集)

对于一个数据集data,要选择其中的最好的feature来划分数据, 所以需要一列一列(data中的一列是一个属性,一行是一个样本)的比较(比较使用哪个特征来划分得到的信息增益最大)。对于每一列来说,计算该列中的属性值有多少种,然后计算每种属性值的熵的大小,然后按照比例求和。最后比较每一列的熵值的总和,信息增益最大的属性就是我们想要找的最好的属性。
featureStatistic用来统计某一个特征可能的取值以及这些取值的个数

public int chooseBestFeature(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {
        int featureSize = data.get(0).size();
        int dataSize = data.size();
        int bestFuatrue = -1;
        double bestInfoGain = 0.0;
        double infoGain = 0.0;
        double baseShannon = this.calculateShannonEntropy(data);
        double shannon = 0.0;
        HashMap<String, Integer> featureStatistic = new HashMap<>(); 
        for (int i = 0; i < featureSize - 1; i++) {
            for (int j = 0; j < data.size(); j++) {
                if (featureStatistic.containsKey(data.get(j).get(i))) {
                    featureStatistic.replace(data.get(j).get(i), featureStatistic.get(data.get(j).get(i)) + 1);
                } else {
                    featureStatistic.put(data.get(j).get(i), 1);
                }
            }

            ArrayList<ArrayList<String>> subdata;
            for (String featureValue : featureStatistic.keySet()) {
                subdata = this.splitDataSetByFeature(data, i, featureValue);
                shannon += this.calculateShannonEntropy(subdata)
                        * ((double) featureStatistic.get(featureValue) / dataSize);
            }
            infoGain = baseShannon - shannon;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFuatrue = i;
            }
            shannon = 0.0;
            featureStatistic.clear();
        }
        return bestFuatrue;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

2.4 构造决策树

递归的构造决策树,注意函数的返回类型是object,而不是DecisionTree(该类的定义下面给出),这是因为当我们构造到叶子结点的时候,我们可能返回的是String(正例还是反例,yes or no,而不再是棵子树),所以使用Object

public Object createDecisionTree(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {

        int dataSize = data.size();
        int featureSize = data.get(0).size();
        // 如果没有特征了,data.get(0).size = 1 说明只剩下标签了, 开始投票。
        if (data.get(0).size() == 1) {
            return vote(data);
        }
        // 判断是不是所有的sample的label都一致了, 如果是,返回这个统一的类别标签。
        HashSet<String> labels = new HashSet<>();
        for (int i = 0; i < dataSize; i++) {
            if (!labels.contains(data.get(i).get(featureSize - 1))) {
                labels.add(data.get(i).get(featureSize - 1));
            }
        }
        if (labels.size() == 1) {
            return data.get(0).get(featureSize - 1);
        }

        // 选择最好的feature来进行决策树(子决策树)的构建
        int bestFeatureIndex = this.chooseBestFeature(data, featureName);
        String bestFeature = featureName.get(bestFeatureIndex);
        featureName.remove(bestFeatureIndex);

        // 统计上一步选出的最好的属性,都有那些取值。
        HashSet<String> bestFeatureValuesSet = new HashSet<>();
        for (int i = 0; i < data.size(); i++) {
            if (!bestFeatureValuesSet.contains(data.get(i).get(bestFeatureIndex))) {
                bestFeatureValuesSet.add(data.get(i).get(bestFeatureIndex));
            }
        }

        DecisionTree tree = new DecisionTree();
        tree.setAttributeName(bestFeature);

        // 最好的属性的每一个取值,都形成一个子树的root, 开始递归。
        Iterator<String> iterator = bestFeatureValuesSet.iterator();
        while (iterator.hasNext()) {
            ArrayList<String> subFeatureName = new ArrayList<>();
            for (int i = 0; i < featureName.size(); i++) {
                subFeatureName.add(featureName.get(i));
            } // 递归的一个关键问题。
            String featureValue = iterator.next();
            tree.children.put(featureValue,
                    createDecisionTree(splitDataSetByFeature(data, bestFeatureIndex, featureValue), subFeatureName));
        }
        return tree;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

2.5 投票函数

当已经没有属性可以作为划分的依据了, 但是这些样本的类的标签依然不同, 那么这个时候就要投票决定了。这个时候data的形式应该是只有一列标签了。那么我们就找这一列标签中最多的,作为类别返回。

public String vote(ArrayList<ArrayList<String>> data) {
        String voteResult = null;
        int dataSize = data.size();
        int length = data.get(0).size();

        HashMap<String, Integer> sta = new HashMap<>();
        for (int i = 0; i < dataSize; i++) {
            if (!sta.keySet().contains(data.get(i).get(length - 1))) {
                sta.put(data.get(i).get(length - 1), 1);
            } else {
                sta.replace(data.get(i).get(length - 1), sta.get(data.get(i).get(length - 1)) + 1);
            }
        }
        int maxValue = Collections.max(sta.values());
        for (String key : sta.keySet()) {
            if (maxValue == sta.get(key)) {
                voteResult = key;
            }
        }
        return voteResult;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

2.6 决策树的数据结构

不像python中有一个功能比较强大的字典,所以这里自定义了一个决策树的数据结构(类DecisionTree),两个域:
1. String:用来表示该树(子树)的属性(feature)。
2. HashMap<String, Object> : key的值表示feature的取值,Object是子树(DecisionTree)或者是最终的label。
典型的一个递归的定义。并且在该类中提供了:
1. 遍历树的方法。
2. 将构造的树输出到指定的文件中。

public class DecisionTree implements Serializable{
    private static final long serialVersionUID = 1L;

    private String attributeName;
    public HashMap<String, Object> children;
    private String decisionTree = "./outputTree/decisionTree.data";

    public void printTree(Object tree, ArrayList<String> record, BufferedWriter bufferedWriter) {
        if (tree instanceof String) {
            record.add((String) tree);
            System.out.println(record);
            try {
                bufferedWriter.write(record.toString());
                bufferedWriter.newLine();
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            record.remove(record.size() - 1);
            record.remove(record.size() - 1);
            return;
        }
        record.add(((DecisionTree) tree).getAttributeName());
        for (String key : ((DecisionTree) tree).children.keySet()) {
            record.add(key);
            printTree(((DecisionTree) tree).children.get(key), record, bufferedWriter);
        }
        int count = 1;
        while( record.size() > 0 && count <= 2){
            record.remove(record.size() - 1);
            count++;
        }
    }

    public void saveDecisionTree(Object tree)
    {
        try {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.decisionTree));
            this.printTree(tree, new ArrayList<>(), bufferedWriter);
            bufferedWriter.close();
            System.out.println("\r\nthe decision tree has saved in the file: './outputTree/decisionTree.data'");
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

3. ID3算法测试

3.1 要不要出去踢球?

数据集

这里写图片描述

代码实现

  1. 创建一个ID3对象,将训练数据集的名称和文件中用于分隔数据的符号(这里是空格" ")传递给构造函数
  2. 构造决策树
  3. 一个交互的函数answerYou()
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

public class Football {

    ID3 id3 = null;
    DecisionTree decisionTree = null;

    public Football() {
        this.id3 = new ID3("./football.txt", " ");  
        // 创建一个ID3对象,将trainingSet的名称和文件中用于分隔数据的符号(这里是空格" ")传递给构造函数
        this.decisionTree = id3.trainingDecisionTree();
        // 构造决策树
    }

    public String answerYou() {
        Scanner scanner = new Scanner(System.in);
        Map<String, String> userData = new HashMap<>();
        String answer = null;
        System.out.println("\r\n回答几个问题,我可以帮你决定是否应该出去踢球。");
        System.out.println("天气怎么样呢? 晴? 阴? 还是 雨?");
        userData.put("天气", scanner.next());
        System.out.println("温度怎么样呢? 炎热? 适中? 还是 寒冷?");
        userData.put("温度", scanner.next());
        System.out.println("湿度呢? 高? 还是 正常?");
        userData.put("湿度", scanner.next());
        System.out.println("风速呢? 强? 还是弱? ");
        userData.put("风速", scanner.next());

        System.out.println("show the information about your input");
        System.out.println(userData);

        DecisionTree treePointer = this.decisionTree;

        while (true) {
            System.out.println("用户在" + ((DecisionTree) treePointer).getAttributeName() + "的取值是"
                    + userData.get((((DecisionTree) treePointer).getAttributeName())));
            if (!((((DecisionTree) treePointer).children
                    .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass().getSimpleName()
                            .equals("String"))) {
                treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                        .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
            } else {
                System.out.println(((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName()))));
                answer = (String) ((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                return answer;
            }
        }
    }

    public static void main(String args[]) {
        Football football = new Football();
        System.out.println("\r\n足球活动应该" + football.answerYou());
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58

运行结果

这里写图片描述

产生的决策树

这里写图片描述

3.2 隐形眼镜的类型

数据集

这里写图片描述

代码实现

import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

public class Lenses {
    ID3 id3 = null;
    DecisionTree decisionTree = null;

    public Lenses() {
        this.id3 = new ID3("./lenses.txt", "\t");
        this.decisionTree = id3.trainingDecisionTree();
    }

    public String answerYou() {
        Scanner scanner = new Scanner(System.in);
        Map<String, String> userData = new HashMap<>();
        String answer = null;
        System.out.println(
                "\r\nI can help you choose the type of contact lenses,as long as you answer me a few quentions");
        System.out.println("ok, let's start.");
        System.out.println("tearRate? please type 'reduced' or 'normal'");
        userData.put("tearRate", scanner.next());
        System.out.println("astigmatic? please type 'yes' or 'no'");
        userData.put("astigmatic", scanner.next());
        System.out.println("prescript? please type 'hyper' or 'myope'");
        userData.put("prescript", scanner.next());
        System.out.println("age? please type 'pre', 'presbyopic' or 'young'");
        userData.put("age", scanner.next());

        System.out.println("show the information about your input");
        System.out.println(userData);

        DecisionTree treePointer = this.decisionTree;

        while (true) {
            if (!((((DecisionTree) treePointer).children
                    .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass().getSimpleName()
                            .equals("String"))) {
                treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                        .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
            } else {
                answer = (String) ((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                return answer;
            }
        }
    }

    public static void main(String args[]) {
        Lenses lenses = new Lenses();
        System.out.println("\r\nThe type of contants lenses that fits you is " + lenses.answerYou());
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

运行结果

这里写图片描述

产生的决策树

这里写图片描述

3.3 手写数字识别(MNIST)

数据集的大小为42000个处理后的手写数字图片,需要进行二值处理。选取其中的90%作为训练集,10%作为预测集。
计算得到的准确度为 88.05%
构造的决策树规则共有2318条
这里写图片描述

代码实现

  • 需要注意的是:如果在probe set中某一个sample的某个feature的value在所有的training set中的sample中的该feature上都没有出现过。所以决策树根本没有针对这个feature这个取值的判断(或者说分支),这个时候决策树什么都做不了。例如,天气这个feature,在training set中只有rain和sunny这两个取值,但是在probe set中却feature出现了windy这一取值。此时,决策树没有学过当天气取值为windy应该做什么。所以这种情况下它无法分类。
package com.alibaba;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;

public class Mnist {

    ID3 id3 = null;
    DecisionTree decisionTree = null;
    private ArrayList<ArrayList<String>> probeSet = new ArrayList<>();

    private String trainingSetFileName = "./MNIST/training.data";
    private String probeSetFileName = "./MNIST/probe.data";

    public Mnist() {
        this.id3 = new ID3(this.trainingSetFileName, ",");
        this.decisionTree = id3.trainingDecisionTree();
    }

    public void calculatePrecise() {

        this.getProbeSet();

        DecisionTree treePointer = null;
        String predictLabel = null;
        int labelIndex = this.probeSet.get(0).size() - 1;
        double success = 0.0;

        for (int i = 0; i < this.probeSet.size(); i++) {
            // treePointer = this.decisionTree;
            try {
                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
                objectOutputStream.writeObject(this.decisionTree);
                ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(
                        byteArrayOutputStream.toByteArray());
                ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
                treePointer = (DecisionTree) objectInputStream.readObject();
            } catch (Exception e) {
                e.printStackTrace();
            }

            HashMap<String, String> userData = new HashMap<>();
            for (int j = 0; j < this.probeSet.get(i).size() - 1; j++) {
                userData.put(id3.bankFeatureName.get(j), this.probeSet.get(i).get(j));
            }

            while (true) {
                if (((DecisionTree) treePointer).children
                        .get((userData).get(((DecisionTree) treePointer).getAttributeName())) != null) {

                    if (!((((DecisionTree) treePointer).children
                            .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass()
                                    .getSimpleName().equals("String"))) {
                        treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                                .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
                    } else {// 类型是string,说明到了决策节点,或者说是叶子节点。
                        predictLabel = (String) ((DecisionTree) treePointer).children
                                .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                        break;
                    }
                } else {
                    predictLabel = "-1";
                    break;
                }
            }
            if (this.probeSet.get(i).get(labelIndex).equals(predictLabel)) {
                // System.out.println("YES true label is: " + this.probeSet.get(i).get(labelIndex) + " predict label is " + predictLabel);
                success += 1;
            } else {
                // System.out.println("NO true label is: " + this.probeSet.get(i).get(labelIndex) + " predict label is " + predictLabel);
            }
        }
        System.out.println("\r\tprecise is " + success / (double) this.probeSet.size());
    }

    private void getProbeSet() {
        try {
            BufferedReader bufferedReader = new BufferedReader(new FileReader(this.probeSetFileName));
            String line = null;

            while ((line = bufferedReader.readLine()) != null) {
                ArrayList<String> temp = new ArrayList<>();
                String[] arr = line.split(",");
                for (String string : arr) {
                    temp.add(string);
                }
                this.probeSet.add(temp);
            }
            bufferedReader.close();
        } catch (FileNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        System.out.println("Probe Set Size :" + this.probeSet.size());
    }

    public static void main(String args[]) {
        Mnist mnist = new Mnist();
        mnist.calculatePrecise();
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113

完整代码下载地址

github下载地址

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/398685
推荐阅读
相关标签
  

闽ICP备14008679号