Information theory is the theory of the nature of information and the laws governing its transmission.
Channel model: source (sender) -> channel -> sink (receiver)
1. Communication is the process of transmitting information through an environment with random interference.
2. The receiver's prior uncertainty about the source: before communication, the receiver cannot know the exact state of the source.
3. The receiver's posterior uncertainty about the source: after communication, because of the interference, the receiver is still somewhat uncertain about the information it received.
4. The posterior uncertainty is always smaller than the prior uncertainty.
Information is a measure of the removal of uncertainty.
The amount of information is measured by how much uncertainty it removes.
Intuitively:
If a message occurs with high probability, the receiver has largely anticipated it, so the message carries little information.
If a message occurs with low probability, the receiver finds it surprising, so the message carries a lot of information.
So the amount of information is tied to probability, and it can be expressed as a function of the probability. What kind of function I(p) can describe the amount of information? It should satisfy:
1. I(p) is non-negative and decreases monotonically as the probability p increases.
2. When p = 1, I(p) = 0: an event that is certain to happen carries no information.
3. When p approaches 0, I(p) approaches infinity.
4. The joint information of two independent events should equal the sum of their individual amounts of information.
The logarithm satisfies all of these properties. If a message x occurs with probability p(x), its amount of information (self-information) is I(x) = -log2 p(x), and the entropy of a source is the expected self-information of its messages, H(X) = -Σ p(xi) log2 p(xi).
For the softball example used below, the probabilities that the activity proceeds and that it is canceled are 9/14 and 5/14 respectively (these follow from the decision table given later).
From these probabilities the entropy of whether the activity takes place (the prior entropy) can be computed; the arithmetic is sketched below. The conditional probabilities of the activity under each weather condition are also known:
Weather | Activity proceeds | Activity canceled |
---|---|---|
Sunny | 2/5 | 3/5 |
Overcast | 1 | 0 |
Rain | 3/5 | 2/5 |
Now compute the conditional entropy of the activity given the outdoor weather.
(The overall procedure is: compute the prior entropy, then the posterior entropies, then the conditional entropy. The prior entropy is already known.)
Compute the posterior entropies: the posterior entropy of the activity given sunny weather, given overcast weather, and given rain, as follows.
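Sketching the arithmetic with base-2 logarithms (values rounded to three decimal places):

$$
\begin{aligned}
H(\text{Activity}) &= -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
H(\text{Activity} \mid \text{Sunny}) &= -\tfrac{2}{5}\log_2\tfrac{2}{5} - \tfrac{3}{5}\log_2\tfrac{3}{5} \approx 0.971 \\
H(\text{Activity} \mid \text{Overcast}) &= -1 \cdot \log_2 1 - 0 = 0 \\
H(\text{Activity} \mid \text{Rain}) &= -\tfrac{3}{5}\log_2\tfrac{3}{5} - \tfrac{2}{5}\log_2\tfrac{2}{5} \approx 0.971
\end{aligned}
$$

The first line is the prior entropy; the remaining three are the posterior entropies for the three weather conditions.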
ID3 uses mutual information (information gain) from information theory as the measure for selecting the splitting attribute: the speed at which entropy drops is taken as the criterion, and the test attribute chosen for a node is the attribute with the highest information gain among those not yet used on the path from the root to that node. This requires computing the information gain of every candidate attribute and picking the largest one as the splitting attribute (the formulas are collected after this list):
1. Compute the prior entropy: the average uncertainty about the class before any attribute value has been received.
2. Compute the posterior entropy: the uncertainty about the source after a particular output symbol yj has been received.
3. Compute the conditional entropy: the expectation of the posterior entropy over the output symbol set Y, i.e. the uncertainty about the source that remains after all symbols have been received.
4. Compute the mutual information: the difference between the prior entropy and the conditional entropy.
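In the usual information-theoretic notation, with X the class variable and Y a candidate attribute, these four quantities are:

$$
\begin{aligned}
H(X) &= -\sum_i p(x_i)\log_2 p(x_i) \\
H(X \mid y_j) &= -\sum_i p(x_i \mid y_j)\log_2 p(x_i \mid y_j) \\
H(X \mid Y) &= \sum_j p(y_j)\, H(X \mid y_j) \\
I(X;Y) &= H(X) - H(X \mid Y)
\end{aligned}
$$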
The decision table for whether the softball activity should take place is as follows:
Weather | Temperature | Humidity | Wind | Activity |
---|---|---|---|---|
Sunny | Hot | High | Weak | Cancel |
Sunny | Hot | High | Strong | Cancel |
Overcast | Hot | High | Weak | Proceed |
Rain | Mild | High | Weak | Proceed |
Rain | Cool | Normal | Weak | Proceed |
Rain | Cool | Normal | Strong | Cancel |
Overcast | Cool | Normal | Strong | Proceed |
Sunny | Mild | High | Weak | Cancel |
Sunny | Cool | Normal | Weak | Proceed |
Rain | Mild | Normal | Weak | Proceed |
Sunny | Mild | Normal | Strong | Proceed |
Overcast | Mild | High | Strong | Proceed |
Overcast | Hot | Normal | Weak | Proceed |
Rain | Mild | High | Strong | Cancel |
1. Compute the prior entropy: the entropy of whether the activity proceeds before any attribute value has been received, computed from the decision table above.
2. Compute the conditional entropy with each attribute in turn as the test attribute (first the posterior entropies, then the conditional entropy).
(1) The conditional entropy of whether the activity proceeds given the weather (the uncertainty about the activity once the weather is known).
First compute the posterior entropies:
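The posterior entropies were already worked out above: about 0.971 for sunny, 0 for overcast and 0.971 for rain. Weighting them by how often each weather value occurs (5, 4 and 5 of the 14 samples) gives the conditional entropy, and subtracting it from the prior entropy gives the information gain:

$$
\begin{aligned}
H(\text{Activity} \mid \text{Weather}) &= \tfrac{5}{14}\cdot 0.971 + \tfrac{4}{14}\cdot 0 + \tfrac{5}{14}\cdot 0.971 \approx 0.694 \\
\operatorname{Gain}(\text{Weather}) &= 0.940 - 0.694 \approx 0.246
\end{aligned}
$$

Repeating the same calculation for the other attributes gives gains of roughly 0.151 (Humidity), 0.048 (Wind) and 0.029 (Temperature), so Weather has the largest information gain and becomes the root test attribute.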
In the ID3 implementation, both the training data and the test data are stored as ArrayList<ArrayList<String>>;
each inner ArrayList is one sample (features + label). In other words, a column of data is an attribute and a row is a sample.
uniqueLabels counts how many times each distinct label occurs.
public double calculateShannonEntropy(ArrayList<ArrayList<String>> data) {
    double shannon = 0.0;
    int length = data.get(0).size(); // length - 1 is the index of the label column
    HashMap<String, Integer> uniqueLabels = new HashMap<>();
    // Count how many times each distinct label occurs.
    for (int i = 0; i < data.size(); i++) {
        if (uniqueLabels.containsKey(data.get(i).get(length - 1))) {
            uniqueLabels.replace(data.get(i).get(length - 1), uniqueLabels.get(data.get(i).get(length - 1)) + 1);
        } else {
            uniqueLabels.put(data.get(i).get(length - 1), 1);
        }
    }
    // H = -sum(p * log2 p), where p is the relative frequency of each label.
    for (String one : uniqueLabels.keySet()) {
        shannon += -(((double) (uniqueLabels.get(one)) / (data.size()))
                * Math.log((double) (uniqueLabels.get(one)) / (data.size())) / Math.log(2));
    }
    return shannon;
}
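As a minimal usage sketch (hypothetical: it assumes an instance named id3 of the class containing these methods, and that java.util.Arrays is imported), the entropy of a hand-built two-sample data set comes out to exactly one bit:

    // Hypothetical toy data set: one feature column plus the label in the last column.
    ArrayList<ArrayList<String>> toy = new ArrayList<>();
    toy.add(new ArrayList<>(Arrays.asList("sunny", "yes")));
    toy.add(new ArrayList<>(Arrays.asList("sunny", "no")));
    // Two equally likely labels, so the entropy is 1 bit.
    System.out.println(id3.calculateShannonEntropy(toy)); // prints 1.0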
The three parameters (data, index, value) mean: return the samples in data whose value in column index equals value, with the index column itself removed from the returned samples.
public ArrayList<ArrayList<String>> splitDataSetByFeature(ArrayList<ArrayList<String>> data, int index,
        String value) {
    ArrayList<ArrayList<String>> subData = new ArrayList<>();
    for (int i = 0; i < data.size(); i++) {
        ArrayList<String> newSample = new ArrayList<>();
        if (data.get(i).get(index).equals(value)) {
            for (int j = 0; j < data.get(i).size(); j++) {
                if (j != index) {
                    newSample.add(data.get(i).get(j));
                }
            }
            subData.add(newSample);
        }
    }
    return subData;
}
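Continuing the hypothetical toy snippet above, splitting on column 0 with the value "sunny" keeps both samples and removes that column, which is exactly the shape the recursive tree construction below expects:

    // Both toy samples have "sunny" in column 0, so both are kept, minus that column.
    ArrayList<ArrayList<String>> sunnyOnly = id3.splitDataSetByFeature(toy, 0, "sunny");
    System.out.println(sunnyOnly); // prints [[yes], [no]]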
For a data set data we want to choose the best feature to split on, so the columns are compared one by one (a column of data is an attribute, a row is a sample) to see which feature yields the largest information gain. For each column, find the distinct attribute values it contains, compute the entropy of the subset corresponding to each value, and sum these entropies weighted by the proportion of samples that take that value. Finally compare across columns: the attribute with the largest information gain is the best one.
featureStatistic counts the possible values of a feature and how many samples take each value.
public int chooseBestFeature(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {
    int featureSize = data.get(0).size();
    int dataSize = data.size();
    int bestFeatureIndex = -1;
    double bestInfoGain = 0.0;
    double infoGain = 0.0;
    double baseShannon = this.calculateShannonEntropy(data); // prior entropy of the whole data set
    double shannon = 0.0; // conditional entropy for the current feature
    HashMap<String, Integer> featureStatistic = new HashMap<>();
    for (int i = 0; i < featureSize - 1; i++) {
        // Count how many samples take each value of feature i.
        for (int j = 0; j < data.size(); j++) {
            if (featureStatistic.containsKey(data.get(j).get(i))) {
                featureStatistic.replace(data.get(j).get(i), featureStatistic.get(data.get(j).get(i)) + 1);
            } else {
                featureStatistic.put(data.get(j).get(i), 1);
            }
        }
        // Conditional entropy: entropies of the subsets split by feature i, weighted by subset size.
        ArrayList<ArrayList<String>> subdata;
        for (String featureValue : featureStatistic.keySet()) {
            subdata = this.splitDataSetByFeature(data, i, featureValue);
            shannon += this.calculateShannonEntropy(subdata)
                    * ((double) featureStatistic.get(featureValue) / dataSize);
        }
        infoGain = baseShannon - shannon;
        if (infoGain > bestInfoGain) {
            bestInfoGain = infoGain;
            bestFeatureIndex = i;
        }
        shannon = 0.0;
        featureStatistic.clear();
    }
    return bestFeatureIndex;
}
The decision tree is constructed recursively. Note that the return type of the function is Object rather than DecisionTree (the class is defined below): when we reach a leaf node we may return a String (the class label, e.g. proceed or cancel, yes or no) instead of another subtree, so Object is used.
public Object createDecisionTree(ArrayList<ArrayList<String>> data, ArrayList<String> featureName) {
    int dataSize = data.size();
    int featureSize = data.get(0).size();
    // If no features are left, data.get(0).size() == 1 means only the label column remains: vote.
    if (data.get(0).size() == 1) {
        return vote(data);
    }
    // If all samples already share the same label, return that label.
    HashSet<String> labels = new HashSet<>();
    for (int i = 0; i < dataSize; i++) {
        if (!labels.contains(data.get(i).get(featureSize - 1))) {
            labels.add(data.get(i).get(featureSize - 1));
        }
    }
    if (labels.size() == 1) {
        return data.get(0).get(featureSize - 1);
    }
    // Choose the best feature to build this (sub)tree on.
    int bestFeatureIndex = this.chooseBestFeature(data, featureName);
    String bestFeature = featureName.get(bestFeatureIndex);
    featureName.remove(bestFeatureIndex);
    // Collect the values taken by the chosen feature.
    HashSet<String> bestFeatureValuesSet = new HashSet<>();
    for (int i = 0; i < data.size(); i++) {
        if (!bestFeatureValuesSet.contains(data.get(i).get(bestFeatureIndex))) {
            bestFeatureValuesSet.add(data.get(i).get(bestFeatureIndex));
        }
    }
    DecisionTree tree = new DecisionTree();
    tree.setAttributeName(bestFeature);
    // Each value of the chosen feature roots a subtree; recurse on the corresponding split of the data.
    Iterator<String> iterator = bestFeatureValuesSet.iterator();
    while (iterator.hasNext()) {
        // A key point of the recursion: copy the remaining feature names so the recursive calls
        // do not interfere with each other.
        ArrayList<String> subFeatureName = new ArrayList<>();
        for (int i = 0; i < featureName.size(); i++) {
            subFeatureName.add(featureName.get(i));
        }
        String featureValue = iterator.next();
        tree.children.put(featureValue,
                createDecisionTree(splitDataSetByFeature(data, bestFeatureIndex, featureValue), subFeatureName));
    }
    return tree;
}
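On the 14-sample softball table this recursion should reproduce the classic result for this data set (sketched by hand here, not program output): Weather is tested at the root, the overcast branch is already pure, and the sunny and rain branches are split further on Humidity and Wind respectively.

    Weather
    ├── Overcast -> Proceed
    ├── Sunny    -> Humidity: High -> Cancel, Normal -> Proceed
    └── Rain     -> Wind: Strong -> Cancel, Weak -> Proceed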
When no attributes are left to split on but the samples still carry different class labels, the decision is made by voting. At this point data should consist of a single label column, so the most frequent label in that column is returned as the class.
public String vote(ArrayList<ArrayList<String>> data) {
    String voteResult = null;
    int dataSize = data.size();
    int length = data.get(0).size();
    HashMap<String, Integer> sta = new HashMap<>();
    for (int i = 0; i < dataSize; i++) {
        if (!sta.keySet().contains(data.get(i).get(length - 1))) {
            sta.put(data.get(i).get(length - 1), 1);
        } else {
            sta.replace(data.get(i).get(length - 1), sta.get(data.get(i).get(length - 1)) + 1);
        }
    }
    int maxValue = Collections.max(sta.values());
    for (String key : sta.keySet()) {
        if (maxValue == sta.get(key)) {
            voteResult = key;
        }
    }
    return voteResult;
}
Java does not have a dictionary as flexible as Python's, so a decision-tree data structure (class DecisionTree) is defined here with two fields:
1. String attributeName: the attribute (feature) tested by this tree (or subtree).
2. HashMap<String, Object> children: each key is a value of that feature, and the corresponding Object is either a subtree (DecisionTree) or a final label.
This is a typical recursive definition. The class also provides:
1. A method for traversing the tree.
2. A method for writing the constructed tree to a specified file.
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;

public class DecisionTree implements Serializable {
    private static final long serialVersionUID = 1L;
    private String attributeName;
    public HashMap<String, Object> children = new HashMap<>();
    private String decisionTree = "./outputTree/decisionTree.data";

    public String getAttributeName() {
        return this.attributeName;
    }

    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }

    // Depth-first traversal: record holds the path (attribute, value, attribute, value, ...) from the root.
    public void printTree(Object tree, ArrayList<String> record, BufferedWriter bufferedWriter) {
        if (tree instanceof String) {
            // Reached a leaf: print and write the complete path, then backtrack.
            record.add((String) tree);
            System.out.println(record);
            try {
                bufferedWriter.write(record.toString());
                bufferedWriter.newLine();
            } catch (IOException e) {
                e.printStackTrace();
            }
            record.remove(record.size() - 1);
            record.remove(record.size() - 1);
            return;
        }
        record.add(((DecisionTree) tree).getAttributeName());
        for (String key : ((DecisionTree) tree).children.keySet()) {
            record.add(key);
            printTree(((DecisionTree) tree).children.get(key), record, bufferedWriter);
        }
        // Backtrack: remove this subtree's attribute name (and the branch value above it, if any).
        int count = 1;
        while (record.size() > 0 && count <= 2) {
            record.remove(record.size() - 1);
            count++;
        }
    }

    public void saveDecisionTree(Object tree) {
        try {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.decisionTree));
            this.printTree(tree, new ArrayList<>(), bufferedWriter);
            bufferedWriter.close();
            System.out.println("\r\nthe decision tree has been saved in the file: './outputTree/decisionTree.data'");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
" "
)传递给构造函数answerYou()
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Football {
    ID3 id3 = null;
    DecisionTree decisionTree = null;

    public Football() {
        // Create an ID3 object, passing the training-set file name and the field separator (here a space " ").
        this.id3 = new ID3("./football.txt", " ");
        // Build the decision tree.
        this.decisionTree = id3.trainingDecisionTree();
    }

    public String answerYou() {
        Scanner scanner = new Scanner(System.in);
        Map<String, String> userData = new HashMap<>();
        String answer = null;
        System.out.println("\r\n回答几个问题,我可以帮你决定是否应该出去踢球。");
        System.out.println("天气怎么样呢? 晴? 阴? 还是 雨?");
        userData.put("天气", scanner.next());
        System.out.println("温度怎么样呢? 炎热? 适中? 还是 寒冷?");
        userData.put("温度", scanner.next());
        System.out.println("湿度呢? 高? 还是 正常?");
        userData.put("湿度", scanner.next());
        System.out.println("风速呢? 强? 还是弱? ");
        userData.put("风速", scanner.next());
        System.out.println("show the information about your input");
        System.out.println(userData);
        DecisionTree treePointer = this.decisionTree;
        while (true) {
            System.out.println("用户在" + ((DecisionTree) treePointer).getAttributeName() + "的取值是"
                    + userData.get((((DecisionTree) treePointer).getAttributeName())));
            // Walk down the tree until the child reached is a String, i.e. a leaf label.
            if (!((((DecisionTree) treePointer).children
                    .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass().getSimpleName()
                    .equals("String"))) {
                treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                        .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
            } else {
                System.out.println(((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName()))));
                answer = (String) ((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                return answer;
            }
        }
    }

    public static void main(String args[]) {
        Football football = new Football();
        System.out.println("\r\n足球活动应该" + football.answerYou());
    }
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Lenses {
    ID3 id3 = null;
    DecisionTree decisionTree = null;

    public Lenses() {
        this.id3 = new ID3("./lenses.txt", "\t");
        this.decisionTree = id3.trainingDecisionTree();
    }

    public String answerYou() {
        Scanner scanner = new Scanner(System.in);
        Map<String, String> userData = new HashMap<>();
        String answer = null;
        System.out.println(
                "\r\nI can help you choose the type of contact lenses, as long as you answer me a few questions");
        System.out.println("ok, let's start.");
        System.out.println("tearRate? please type 'reduced' or 'normal'");
        userData.put("tearRate", scanner.next());
        System.out.println("astigmatic? please type 'yes' or 'no'");
        userData.put("astigmatic", scanner.next());
        System.out.println("prescript? please type 'hyper' or 'myope'");
        userData.put("prescript", scanner.next());
        System.out.println("age? please type 'pre', 'presbyopic' or 'young'");
        userData.put("age", scanner.next());
        System.out.println("show the information about your input");
        System.out.println(userData);
        DecisionTree treePointer = this.decisionTree;
        while (true) {
            // Walk down the tree until the child reached is a String, i.e. a leaf label.
            if (!((((DecisionTree) treePointer).children
                    .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass().getSimpleName()
                    .equals("String"))) {
                treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                        .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
            } else {
                answer = (String) ((DecisionTree) treePointer).children
                        .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                return answer;
            }
        }
    }

    public static void main(String args[]) {
        Lenses lenses = new Lenses();
        System.out.println("\r\nThe type of contact lenses that fits you is " + lenses.answerYou());
    }
}
The data set consists of 42,000 preprocessed handwritten-digit images, which need to be binarized first. 90% of the samples are used as the training set and 10% as the probe (test) set.
The accuracy obtained is 88.05%.
The constructed decision tree contains 2,318 rules.
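The preprocessing step itself is not shown in the post. A minimal sketch of one way it could be done (this is an assumption, not the author's code: it supposes a Kaggle-style train.csv with a header line and the digit label in the first column, binarizes each pixel at a threshold of 128, moves the label to the last column as the ID3 code above expects, and writes roughly 90% of the rows to training.data and 10% to probe.data; depending on how the ID3 constructor derives feature names, a header line may also be needed, which is not visible here):

    import java.io.*;

    public class MnistPreprocess {
        public static void main(String[] args) throws IOException {
            BufferedReader in = new BufferedReader(new FileReader("./MNIST/train.csv"));
            BufferedWriter train = new BufferedWriter(new FileWriter("./MNIST/training.data"));
            BufferedWriter probe = new BufferedWriter(new FileWriter("./MNIST/probe.data"));
            String line = in.readLine(); // skip the header line of train.csv
            int count = 0;
            while ((line = in.readLine()) != null) {
                String[] parts = line.split(",");
                StringBuilder sb = new StringBuilder();
                // Binarize every pixel: 1 if the gray value is at least 128, otherwise 0.
                for (int i = 1; i < parts.length; i++) {
                    sb.append(Integer.parseInt(parts[i]) >= 128 ? "1" : "0").append(",");
                }
                sb.append(parts[0]); // the digit label goes into the last column
                BufferedWriter out = (count % 10 == 9) ? probe : train; // every 10th row -> probe set
                out.write(sb.toString());
                out.newLine();
                count++;
            }
            in.close();
            train.close();
            probe.close();
        }
    }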
package com.alibaba;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
public class Mnist {
    ID3 id3 = null;
    DecisionTree decisionTree = null;
    private ArrayList<ArrayList<String>> probeSet = new ArrayList<>();
    private String trainingSetFileName = "./MNIST/training.data";
    private String probeSetFileName = "./MNIST/probe.data";

    public Mnist() {
        this.id3 = new ID3(this.trainingSetFileName, ",");
        this.decisionTree = id3.trainingDecisionTree();
    }

    public void calculatePrecise() {
        this.getProbeSet();
        DecisionTree treePointer = null;
        String predictLabel = null;
        int labelIndex = this.probeSet.get(0).size() - 1;
        double success = 0.0;
        for (int i = 0; i < this.probeSet.size(); i++) {
            // treePointer = this.decisionTree;
            // Deep-copy the tree through serialization before classifying each probe sample.
            try {
                ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
                objectOutputStream.writeObject(this.decisionTree);
                ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(
                        byteArrayOutputStream.toByteArray());
                ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
                treePointer = (DecisionTree) objectInputStream.readObject();
            } catch (Exception e) {
                e.printStackTrace();
            }
            HashMap<String, String> userData = new HashMap<>();
            for (int j = 0; j < this.probeSet.get(i).size() - 1; j++) {
                userData.put(id3.bankFeatureName.get(j), this.probeSet.get(i).get(j));
            }
            while (true) {
                if (((DecisionTree) treePointer).children
                        .get((userData).get(((DecisionTree) treePointer).getAttributeName())) != null) {
                    if (!((((DecisionTree) treePointer).children
                            .get((userData.get((((DecisionTree) treePointer).getAttributeName()))))).getClass()
                            .getSimpleName().equals("String"))) {
                        treePointer = (DecisionTree) ((DecisionTree) treePointer).children
                                .get(userData.get(((DecisionTree) treePointer).getAttributeName()));
                    } else {
                        // The child is a String, so we have reached a leaf (decision) node.
                        predictLabel = (String) ((DecisionTree) treePointer).children
                                .get(userData.get((((DecisionTree) treePointer).getAttributeName())));
                        break;
                    }
                } else {
                    // The tree has no branch for this feature value, so the sample is counted as incorrect.
                    predictLabel = "-1";
                    break;
                }
            }
            if (this.probeSet.get(i).get(labelIndex).equals(predictLabel)) {
                // System.out.println("YES true label is: " + this.probeSet.get(i).get(labelIndex) + " predict label is " + predictLabel);
                success += 1;
            } else {
                // System.out.println("NO true label is: " + this.probeSet.get(i).get(labelIndex) + " predict label is " + predictLabel);
            }
        }
        System.out.println("\r\nprecision is " + success / (double) this.probeSet.size());
    }

    private void getProbeSet() {
        try {
            BufferedReader bufferedReader = new BufferedReader(new FileReader(this.probeSetFileName));
            String line = null;
            while ((line = bufferedReader.readLine()) != null) {
                ArrayList<String> temp = new ArrayList<>();
                String[] arr = line.split(",");
                for (String string : arr) {
                    temp.add(string);
                }
                this.probeSet.add(temp);
            }
            bufferedReader.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("Probe Set Size :" + this.probeSet.size());
    }

    public static void main(String args[]) {
        Mnist mnist = new Mnist();
        mnist.calculatePrecise();
    }
}