
ID3 Algorithm for Decision Tree Learning: A Python Implementation





Suppose the samples are to be divided into $K$ classes, and the current node holds $n_1, n_2, \dots, n_K$ samples of each class. The information entropy of the node is
$$I(n_1, n_2, \dots, n_K) = -\sum_{i=1}^K \frac{n_i}{\sum_{j=1}^K n_j} \log_2 \frac{n_i}{\sum_{j=1}^K n_j}$$
Suppose attribute $A$ takes $v$ distinct values, so that splitting on $A$ sends the samples of the current node to $v$ child nodes, and the $i$-th child holds $n_{i1}, n_{i2}, \dots, n_{iK}$ samples of each class. The class entropy of the parent node after splitting on $A$ is
$$E(A) = \sum_{i=1}^v \frac{\sum_{j=1}^K n_{ij}}{\sum_{j=1}^K n_j} I(n_{i1}, n_{i2}, \dots, n_{iK})$$
The information gain of the current node on attribute $A$ is then
$$Gain(A) = I(n_1, n_2, \dots, n_K) - E(A)$$
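
As a worked check of the formulas above, the following sketch computes the entropy and the gain for the root split on `outlook` in the Play tennis tree printed later in this article (class counts `[5 9]` at the root, `[0 4]`, `[2 3]`, `[3 2]` in the children). The function names `entropy` and `gain` are chosen here for illustration and are not the names used in the article's class.

```python
import numpy as np

def entropy(counts):
    """Information entropy I(n_1, ..., n_K) of a node, in bits."""
    counts = np.asarray(counts, dtype=float)
    counts = counts[counts > 0]          # classes with no samples contribute 0
    if counts.size <= 1:
        return 0.0
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def gain(child_counts):
    """Gain(A) from a v x K table: row i holds the class counts of child i."""
    table = np.asarray(child_counts, dtype=float)
    parent = table.sum(axis=0)                    # n_1, ..., n_K of the parent
    weights = table.sum(axis=1) / table.sum()     # share of samples in each child
    e_a = sum(w * entropy(row) for w, row in zip(weights, table))  # E(A)
    return entropy(parent) - e_a

# Class counts [No, Yes] of the three children of the split on outlook
outlook = [[0, 4], [2, 3], [3, 2]]
print(entropy([5, 9]))   # about 0.940, matching the root entropy in the printed tree
print(gain(outlook))     # about 0.247
```

The pure child `[0 4]` contributes nothing to $E(A)$, so the gain comes entirely from how much the two mixed children reduce the uncertainty of the root.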


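Irrelevant attribute: the attribute is independent of the class distribution. The information gain is then too small to be useful, so splitting can stop and the current node is labeled with its most frequent class.
Inadequate attribute: samples of different classes have exactly the same feature values. The information gain is then $0$, so splitting can stop and the current node is labeled with its most frequent class.
Unknown value: some attribute values in the dataset are missing or uncertain. Samples or attributes that contain unknown values can be removed during preprocessing.
Empty branch: a child node receives $0$ samples during learning. This only happens when $v \ge 3$, because with $v = 2$ an empty child means all samples fall into the other child, the gain is $0$, and no split is made. The node can be labeled with the most frequent class of its parent.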
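
The fallback rules for these special cases can be sketched as two small helpers. This is a minimal illustration, not the article's code: the names `leaf_label` and `should_stop` are hypothetical, and the threshold `1e-9` is the one used in the implementation further down.

```python
import numpy as np

def leaf_label(class_counts, parent_label):
    """Label for a node that will not be split further."""
    counts = np.asarray(class_counts)
    if counts.sum() == 0:                 # empty branch: inherit the parent's majority class
        return parent_label
    return int(np.argmax(counts))         # most frequent class (first one on ties)

def should_stop(best_gain, eps=1e-9):
    """True when no attribute yields a usable gain (irrelevant or inadequate attributes)."""
    return best_gain < eps

print(leaf_label([0, 0], parent_label=1))   # empty branch, falls back to the parent's label
print(leaf_label([3, 2], parent_label=1))   # majority class of the node itself
```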


import numpy as np
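class ID3:
    # max_depth = 0 disables the depth limit; a node holding min_samples_split samples or fewer is not split
    def __init__(self, max_depth = 0, min_samples_split = 0):
        self.max_depth, self.min_samples_split = max_depth, min_samples_split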
    def __EI(self, *n):
        n = np.array([i for i in n if i > 0])
        if n.shape[0] <= 1:
            return 0
        p = n / np.sum(n)
        return -np.dot(p, np.log2(p))
    def __Gain(self, A: np.ndarray):
        return self.__EI(*np.sum(A, axis = 0)) - np.average(np.frompyfunc(self.__EI, A.shape[1], 1)(*A.T), weights = np.sum(A, axis = 1))
    def fit(self, X: np.ndarray, y):
        self.DX, (self.Dy, yn) = [np.unique(X[:, i]) for i in range(X.shape[1])], np.unique(y, return_inverse = True)
        self.Dy: np.ndarray
        self.value = []
        def fitcur(n, h, p = 0):
            self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))
            r: np.ndarray = np.unique(y[n])
            if r.shape[0] == 0: # Empty Branch
                return p
            elif r.shape[0] == 1:
                return yn[n[0]]
            elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Pre-pruning against overfitting
                return np.argmax(np.bincount(yn[n]))
            else:
                P = [[n[np.where(X[n, i] == j)[0]] for j in self.DX[i]] for i in range(X.shape[1])]
                G = [self.__Gain(A) for A in [np.array([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P]]
                m = np.argmax(G)
                if(G[m] < 1e-9): # Inadequate attribute
                    return np.argmax(np.bincount(yn[n]))
                return (m,) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m])
        self.tree = fitcur(np.arange(X.shape[0]), 0)
    def predict(self, X):
        def precur(n, x):
            return precur(n[1 + np.where(self.DX[n[0]] == x[n[0]])[0][0]], x) if isinstance(n, tuple) else self.Dy[n]
        return np.array([precur(self.tree, x) for x in X])
    def visualize(self, header):
        i = iter(self.value)
        def visval():
            v = next(i)
            print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')
        def viscur(n, h, c):
            for i in h[:-1]:
                print('%c   ' % ('│' if i else ' '), end = '')
            if len(h) > 0:
                print('%c── ' % ('├' if h[-1] else '└'), end = '')
                print('[%s] ' % c, end = '')
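            if isinstance(n, tuple):
                # Internal node: print the split attribute, its statistics, then each child
                print(header[n[0]], end = '')
                visval()
                print()
                for i in range(len(n) - 1):
                    viscur(n[i + 1], h + [i < len(n) - 2], str(self.DX[n[0]][i]))
            else:
                # Leaf: print the predicted class and its statistics
                print(self.Dy[n], end = '')
                visval()
                print()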
        viscur(self.tree, [], '')
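
The implementation above stores the tree as nested tuples: an internal node is `(attribute index, child for the first attribute value, child for the second, ...)` with children ordered like `np.unique` orders the attribute values, and a leaf is an index into the class array. The sketch below builds by hand the tree that the Play tennis test further down prints, and walks it. The helper name `predict_one` and the module-level `DX`, `Dy`, `tree` are illustrative stand-ins for the class attributes.

```python
import numpy as np

# Attribute domains in column order (outlook, temp, humidity, wind), sorted as np.unique would sort them
DX = [np.array(['Overcast', 'Rain', 'Sunny']),
      np.array(['Cool', 'Hot', 'Mild']),
      np.array(['High', 'Normal']),
      np.array(['Strong', 'Weak'])]
Dy = np.array(['No', 'Yes'])

# outlook -> Overcast: Yes | Rain: split on wind | Sunny: split on humidity
tree = (0, 1, (3, 0, 1), (2, 0, 1))

def predict_one(node, x):
    """Follow the branch matching each tested attribute until a leaf is reached."""
    while isinstance(node, tuple):
        a = node[0]
        branch = int(np.where(DX[a] == x[a])[0][0])   # position of x[a] in the attribute domain
        node = node[1 + branch]
    return str(Dy[node])

print(predict_one(tree, ['Sunny', 'Hot', 'High', 'Weak']))    # No
print(predict_one(tree, ['Rain', 'Mild', 'High', 'Weak']))    # Yes
```

One consequence of this layout is that an attribute value never seen during training has no position in the domain, so the lookup fails; callers would need to handle that case separately.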


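For a continuous attribute, the midpoints between adjacent distinct values of that attribute form a set of candidate cut points (the implementation below takes the values from the samples at the current node). Each cut point splits the samples of the current node into $2$ groups, and the cut point with the largest information gain is selected.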
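
The cut-point search described above can be sketched on its own as follows. The function names and the toy data are assumptions made for illustration; the article's class performs the same search inside its `Part` helper.

```python
import numpy as np

def entropy(counts):
    counts = np.asarray(counts, dtype=float)
    counts = counts[counts > 0]
    if counts.size <= 1:
        return 0.0
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def best_threshold(x, y):
    """Return (cut point, gain) of the best binary split x < cut / x >= cut."""
    x = np.asarray(x, dtype=float)
    classes, y_idx = np.unique(y, return_inverse=True)
    u = np.unique(x)                          # sorted distinct values
    if u.size < 2:                            # a single value cannot be split
        return None, 0.0
    cuts = (u[:-1] + u[1:]) / 2               # midpoints of adjacent values
    parent = entropy(np.bincount(y_idx, minlength=classes.size))
    best_cut, best_gain = None, -1.0
    for c in cuts:
        left = x < c
        g = parent
        for mask in (left, ~left):
            counts = np.bincount(y_idx[mask], minlength=classes.size)
            g -= mask.mean() * entropy(counts)    # subtract the weighted child entropy
        if g > best_gain:
            best_cut, best_gain = float(c), float(g)
    return best_cut, best_gain

# Made-up values: the classes separate cleanly between 2.0 and 3.0
print(best_threshold([1.0, 2.0, 3.0, 4.0], ['a', 'a', 'b', 'b']))   # (2.5, 1.0)
```

Because every cut point lies strictly between two distinct values, both sides of each candidate split are non-empty, so a continuous split never produces an empty branch.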


import numpy as np
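class ID3:
    # max_depth = 0 disables the depth limit; a node holding min_samples_split samples or fewer is not split
    def __init__(self, max_depth = 0, min_samples_split = 0):
        self.max_depth, self.min_samples_split = max_depth, min_samples_split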
    def __EI(self, *n):
        n = np.array([i for i in n if i > 0])
        if n.shape[0] <= 1:
            return 0
        p = n / np.sum(n)
        return -np.dot(p, np.log2(p))
    def __Gain(self, A: np.ndarray):
        return self.__EI(*np.sum(A, axis = 0)) - np.average([self.__EI(*a) for a in A], weights = np.sum(A, axis = 1))
    def fit(self, X: np.ndarray, y):
        # A column is treated as continuous when every value in it is a Python int or float
        self.c = np.array([(np.all([isinstance(j, (int, float)) for j in i])) for i in X.T])
        self.DX, (self.Dy, yn) = [np.unique(X[:, i]) if not self.c[i] else None for i in range(X.shape[1])], np.unique(y, return_inverse = True)
        self.Dy: np.ndarray
        self.value = []
        def Part(n, a):
            if self.c[a]:
                u = np.sort(np.unique(X[n, a]))
                if(u.shape[0] < 2):
                    return None
                v = np.array([(u[i - 1] + u[i]) / 2 for i in range(1, u.shape[0])])
                P = [[n[np.where(X[n, a] < i)[0]], n[np.where(X[n, a] >= i)[0]]] for i in v]
                m = np.argmax([self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P])
                return v[m], P[m]
            else:
                return None, [n[np.where(X[n, a] == i)[0]] for i in self.DX[a]]
        def fitcur(n: np.ndarray, h, p = 0):
            self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))
            r: np.ndarray = np.unique(y[n])
            if r.shape[0] == 0: # Empty Branch
                return p
            elif r.shape[0] == 1:
                return yn[n[0]]
            elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Pre-pruning against overfitting
                return np.argmax(np.bincount(yn[n]))
            else:
                P = [Part(n, i) for i in range(X.shape[1])]
                G = [self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p[1]]) if p is not None else 0 for p in P]
                m = np.argmax(G)
                if(G[m] < 1e-9): # Inadequate attribute
                    return np.argmax(np.bincount(yn[n]))
                return ((m, P[m][0]) if self.c[m] else (m,)) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m][1])
        self.tree = fitcur(np.arange(X.shape[0]), 0)
    def predict(self, X):
        def precur(n, x):
            return precur(n[(2 if x[n[0]] < n[1] else 3) if self.c[n[0]] else (1 + np.where(self.DX[n[0]] == x[n[0]])[0][0])], x) if isinstance(n, tuple) else self.Dy[n]
        return np.array([precur(self.tree, x) for x in X])
    def visualize(self, header):
        i = iter(self.value)
        def visval():
            v = next(i)
            print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')
        def viscur(n, h, c):
            for i in h[:-1]:
                print('%c   ' % ('│' if i else ' '), end = '')
            if len(h) > 0:
                print('%c── ' % ('├' if h[-1] else '└'), end = '')
                print('[%s] ' % c, end = '')
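            if isinstance(n, tuple):
                # Internal node: print the split attribute, its statistics, then each child
                print(header[n[0]], end = '')
                visval()
                print()
                if self.c[n[0]]:
                    # Continuous split: (attribute, threshold, child for <, child for >=)
                    for i in range(2):
                        viscur(n[2 + i], h + [i < 1], ('< ', '>= ')[i] + str(n[1]))
                else:
                    # Discrete split: (attribute, one child per attribute value)
                    for i in range(len(n) - 1):
                        viscur(n[1 + i], h + [i < len(n) - 2], str(self.DX[n[0]][i]))
            else:
                # Leaf: print the predicted class and its statistics
                print(self.Dy[n], end = '')
                visval()
                print()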
        viscur(self.tree, [], '')
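
In this second version a split on a continuous attribute is stored as `(attribute index, threshold, child for x < threshold, child for x >= threshold)`. The sketch below writes out by hand the tree that the Iris test near the end of this article prints and classifies a few samples with it. The helper name `predict_one` is illustrative, and the three measurement vectors are made-up inputs rather than rows quoted from the dataset.

```python
# Column order of sklearn's iris data:
# 0 sepal length, 1 sepal width, 2 petal length, 3 petal width (all in cm)
tree = (2, 2.45, 0,                      # petal length < 2.45 -> class 0
        (3, 1.75,                        # otherwise test petal width
         (2, 4.95, 1, 2),                # width < 1.75: petal length < 4.95 -> 1, else 2
         (2, 4.85, 2, 2)))               # width >= 1.75: both children predict 2

def predict_one(node, x):
    """Descend left when x[a] < threshold, right otherwise, until a leaf is reached."""
    while isinstance(node, tuple):
        a, t = node[0], node[1]
        node = node[2] if x[a] < t else node[3]
    return node

print(predict_one(tree, [5.1, 3.5, 1.4, 0.2]))   # 0
print(predict_one(tree, [6.0, 2.9, 4.5, 1.5]))   # 1
print(predict_one(tree, [6.5, 3.0, 5.5, 2.0]))   # 2
```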


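Play tennis dataset (source: kaggle): discrete attributes
Mushroom classification dataset (source: kaggle): discrete attributes
Carsdata dataset (source: kaggle): continuous attributes
Iris dataset (source: sklearn.datasets): continuous attributes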

Of these, play_tennis.csv holds 14 samples with the attributes outlook, temp, humidity and wind, and the label play.


Tests on the Play tennis dataset


import numpy as np
import pandas as pd
class Datasets:
    def __init__(self, fn):
        self.df = pd.read_csv('Datasets\\%s' % fn).map(lambda x: x.strip() if isinstance(x, str) else x)
        self.df.rename(columns = lambda x: x.strip(), inplace = True)
    def getData(self, DX, Dy, drop = False):
        dfn = self.df.loc[~self.df.eq('').any(axis = 1)].apply(pd.to_numeric, errors = 'ignore') if drop else self.df
        return dfn[DX].to_numpy(dtype = np.object_), dfn[Dy].to_numpy(dtype = np.object_)

# play_tennis.csv
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
dt11 = ID3()
dt11.fit(X, y)
dt11.visualize(a) # print the learned tree


outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])
    ├── [High] No (entropy = 0, samples = 3, value = [3 0])
    └── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])



# play_tennis.csv for inadequate attribute test and class > 2
a = ['temp', 'humidity', 'wind', 'play']
X, y = Datasets('play_tennis.csv').getData(a, 'outlook')
dt12 = ID3(10)
dt12.fit(X, y)
dt12.visualize(a) # print the learned tree
print(dt12.predict([['Cool', 'Normal', 'Weak', 'Yes']]))


play (entropy = 1.5774062828523454, samples = 14, value = [4 5 5])
├── [No] temp (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   ├── [Cool] Rain (entropy = 0, samples = 1, value = [0 1 0])
│   ├── [Hot] Sunny (entropy = 0, samples = 2, value = [0 0 2])
│   └── [Mild] wind (entropy = 1.0, samples = 2, value = [0 1 1])
│       ├── [Strong] Rain (entropy = 0, samples = 1, value = [0 1 0])
│       └── [Weak] Sunny (entropy = 0, samples = 1, value = [0 0 1])
└── [Yes] temp (entropy = 1.5304930567574826, samples = 9, value = [4 3 2])
    ├── [Cool] wind (entropy = 1.584962500721156, samples = 3, value = [1 1 1])
    │   ├── [Strong] Overcast (entropy = 0, samples = 1, value = [1 0 0])
    │   └── [Weak] Rain (entropy = 1.0, samples = 2, value = [0 1 1])
    ├── [Hot] Overcast (entropy = 0, samples = 2, value = [2 0 0])
    └── [Mild] wind (entropy = 1.5, samples = 4, value = [1 2 1])
        ├── [Strong] humidity (entropy = 1.0, samples = 2, value = [1 0 1])
        │   ├── [High] Overcast (entropy = 0, samples = 1, value = [1 0 0])
        │   └── [Normal] Sunny (entropy = 0, samples = 1, value = [0 0 1])
        └── [Weak] Rain (entropy = 0, samples = 2, value = [0 2 0])



# play_tennis.csv modified to generate empty branch
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
X[2, 2], X[13, 2] = 'Low', 'Low'
dt13 = ID3()
dt13.fit(X, y)
dt13.visualize(a) # print the learned tree
print(dt13.predict([['Sunny', 'Hot', 'Low', 'Weak']]))


outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])
    ├── [High] No (entropy = 0, samples = 3, value = [3 0])
    ├── [Low] No (entropy = 0, samples = 0, value = [0 0])
    └── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])

Tests on the Mushroom classification dataset


# mushrooms.csv ignoring attribute 'stalk-root' with unknown value
a = ['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor', 'gill-attachment', 'gill-spacing', 'gill-size',
        'gill-color', 'stalk-shape', 'stalk-surface-above-ring', 'stalk-surface-below-ring', 'stalk-color-above-ring', 'stalk-color-below-ring', 'veil-type',
        'veil-color', 'ring-number', 'ring-type', 'spore-print-color', 'population', 'habitat']
X, y = Datasets('mushrooms.csv').getData(a, 'class')
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt21 = ID3()
dt21.fit(X_train, y_train)
dt21.visualize(a) # print the learned tree
y_pred = dt21.predict(X_test)
print(classification_report(y_test, y_pred))


odor (entropy = 0.9990161113058208, samples = 6093, value = [3159 2934])
├── [a] e (entropy = 0, samples = 298, value = [298   0])
├── [c] p (entropy = 0, samples = 138, value = [  0 138])
├── [f] p (entropy = 0, samples = 1636, value = [   0 1636])
├── [l] e (entropy = 0, samples = 297, value = [297   0])
├── [m] p (entropy = 0, samples = 26, value = [ 0 26])
├── [n] spore-print-color (entropy = 0.19751069442516636, samples = 2645, value = [2564   81])
│   ├── [b] e (entropy = 0, samples = 32, value = [32  0])
│   ├── [h] e (entropy = 0, samples = 35, value = [35  0])
│   ├── [k] e (entropy = 0, samples = 974, value = [974   0])
│   ├── [n] e (entropy = 0, samples = 1013, value = [1013    0])
│   ├── [o] e (entropy = 0, samples = 33, value = [33  0])
│   ├── [r] p (entropy = 0, samples = 50, value = [ 0 50])
│   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   ├── [w] habitat (entropy = 0.34905151737109524, samples = 473, value = [442  31])
│   │   ├── [d] gill-size (entropy = 0.7062740891876007, samples = 26, value = [ 5 21])
│   │   │   ├── [b] e (entropy = 0, samples = 5, value = [5 0])
│   │   │   └── [n] p (entropy = 0, samples = 21, value = [ 0 21])
│   │   ├── [g] e (entropy = 0, samples = 222, value = [222   0])
│   │   ├── [l] cap-color (entropy = 0.7553754125614287, samples = 46, value = [36 10])
│   │   │   ├── [b] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [c] e (entropy = 0, samples = 20, value = [20  0])
│   │   │   ├── [e] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [g] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [n] e (entropy = 0, samples = 16, value = [16  0])
│   │   │   ├── [p] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [r] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [w] p (entropy = 0, samples = 7, value = [0 7])
│   │   │   └── [y] p (entropy = 0, samples = 3, value = [0 3])
│   │   ├── [m] e (entropy = 0, samples = 0, value = [0 0])
│   │   ├── [p] e (entropy = 0, samples = 35, value = [35  0])
│   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   └── [w] e (entropy = 0, samples = 144, value = [144   0])
│   └── [y] e (entropy = 0, samples = 35, value = [35  0])
├── [p] p (entropy = 0, samples = 187, value = [  0 187])
├── [s] p (entropy = 0, samples = 433, value = [  0 433])
└── [y] p (entropy = 0, samples = 433, value = [  0 433])

              precision    recall  f1-score   support

           e       1.00      1.00      1.00      1049
           p       1.00      1.00      1.00       982

    accuracy                           1.00      2031
   macro avg       1.00      1.00      1.00      2031
weighted avg       1.00      1.00      1.00      2031

Tests on the Carsdata dataset

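Three-class classification on the default label attribute. Samples with unknown values are dropped, the data is split into training and test sets, the maximum tree depth is limited to 5, and a node holding 3 samples or fewer is not split. The code is as follows: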

# cars.csv ignoring samples with unknown value
a = ['mpg', 'cylinders', 'cubicinches', 'hp', 'weightlbs', 'time-to-60', 'year']
X, y = Datasets('cars.csv').getData(a, 'brand', True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt31 = ID3(5, 3)
dt31.fit(X_train, y_train)
dt31.visualize(a) # print the learned tree
y_pred = dt31.predict(X_test)
print(classification_report(y_test, y_pred))


cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [< 191.0] year (entropy = 1.5833913647120852, samples = 105, value = [37 33 35])
│   ├── [< 1981.5] cubicinches (entropy = 1.5558899087683136, samples = 87, value = [37 27 23])
│   │   ├── [< 121.5] cubicinches (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   │   ├── [< 114.0] cubicinches (entropy = 1.4844331941390079, samples = 47, value = [18 21  8])
│   │   │   │   ├── [< 87.0] Japan. (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │   │   └── [>= 87.0] Europe. (entropy = 1.521560239117063, samples = 39, value = [17 14  8])
│   │   │   └── [>= 114.0] weightlbs (entropy = 0.5665095065529053, samples = 15, value = [13  2  0])
│   │   │       ├── [< 2571.0] Europe. (entropy = 0.9709505944546686, samples = 5, value = [3 2 0])
│   │   │       └── [>= 2571.0] Europe. (entropy = 0, samples = 10, value = [10  0  0])
│   │   └── [>= 121.5] weightlbs (entropy = 1.3593308322365363, samples = 25, value = [ 6  4 15])
│   │       ├── [< 3076.5] hp (entropy = 0.9917601481809735, samples = 20, value = [ 1  4 15])
│   │       │   ├── [< 92.5] US. (entropy = 0, samples = 11, value = [ 0  0 11])
│   │       │   └── [>= 92.5] Japan. (entropy = 1.3921472236645345, samples = 9, value = [1 4 4])
│   │       └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
│   └── [>= 1981.5] mpg (entropy = 0.9182958340544896, samples = 18, value = [ 0  6 12])
│       ├── [< 31.3] US. (entropy = 0, samples = 9, value = [0 0 9])
│       └── [>= 31.3] mpg (entropy = 0.9182958340544896, samples = 9, value = [0 6 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 4, value = [0 4 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
└── [>= 191.0] US. (entropy = 0, samples = 87, value = [ 0  0 87])

              precision    recall  f1-score   support

     Europe.       0.50      0.80      0.62        10
      Japan.       0.83      0.56      0.67        18
         US.       0.94      0.94      0.94        36

    accuracy                           0.81        64
   macro avg       0.76      0.77      0.74        64
weighted avg       0.84      0.81      0.81        64
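
Dropping the samples that contain unknown values, as done for cars.csv above through the `drop` flag of `getData`, can also be sketched without pandas. This is a minimal illustration under the assumption that an unknown value shows up as an empty string after stripping whitespace; the function name `drop_unknown` and the three rows are made up for the example.

```python
import numpy as np

def drop_unknown(X, y, unknown=''):
    """Keep only the samples whose attributes and label are all known."""
    X = np.asarray(X, dtype=object)
    y = np.asarray(y, dtype=object)
    keep = np.array([all(v != unknown for v in row) and label != unknown
                     for row, label in zip(X, y)], dtype=bool)
    return X[keep], y[keep]

# Made-up rows: the second one has an unknown first attribute
X = [[14.0, 8], ['', 4], [30.5, 4]]
y = ['US.', 'Japan.', 'Europe.']
X_clean, y_clean = drop_unknown(X, y)
print(X_clean.shape, y_clean.tolist())   # (2, 2) ['US.', 'Europe.']
```

The Mushroom test takes the other route mentioned earlier: instead of dropping samples, it leaves out the whole `stalk-root` attribute.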



# cars.csv with attribute 'cubicinches' discretized 
def find(a, n):
    def findcur(r):
        return findcur(r + 1) if r < len(a) and a[r] < n else r
    return findcur(0)
X[:, 2] = np.array([('a', 'b', 'c')[find([121.5, 191.0], i)] for i in X[:, 2]])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt32 = ID3(5, 3)
dt32.fit(X_train, y_train)
dt32.visualize(a) # print the learned tree
y_pred = dt32.predict(X_test)
print(classification_report(y_test, y_pred))


cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [a] year (entropy = 1.5231103605784926, samples = 74, value = [31 28 15])
│   ├── [< 1981.5] weightlbs (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   ├── [< 2571.0] weightlbs (entropy = 1.4729350396193688, samples = 50, value = [20 22  8])
│   │   │   ├── [< 2271.5] mpg (entropy = 1.5038892873131435, samples = 42, value = [19 15  8])
│   │   │   │   ├── [< 30.25] Europe. (entropy = 1.2640886121123147, samples = 23, value = [15  5  3])
│   │   │   │   └── [>= 30.25] Japan. (entropy = 1.4674579648482995, samples = 19, value = [ 4 10  5])
│   │   │   └── [>= 2271.5] mpg (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │       ├── [< 37.9] Japan. (entropy = 0, samples = 7, value = [0 7 0])
│   │   │       └── [>= 37.9] Europe. (entropy = 0, samples = 1, value = [1 0 0])
│   │   └── [>= 2571.0] cylinders (entropy = 0.41381685030363374, samples = 12, value = [11  1  0])
│   │       ├── [< 3.5] Japan. (entropy = 0, samples = 1, value = [0 1 0])
│   │       └── [>= 3.5] Europe. (entropy = 0, samples = 11, value = [11  0  0])
│   └── [>= 1981.5] mpg (entropy = 0.9798687566511528, samples = 12, value = [0 5 7])
│       ├── [< 31.3] US. (entropy = 0, samples = 4, value = [0 0 4])
│       └── [>= 31.3] mpg (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
├── [b] weightlbs (entropy = 1.2910357498542626, samples = 31, value = [ 6  5 20])
│   ├── [< 3076.5] hp (entropy = 0.9293550115186283, samples = 26, value = [ 1  5 20])
│   │   ├── [< 93.5] US. (entropy = 0, samples = 16, value = [ 0  0 16])
│   │   └── [>= 93.5] time-to-60 (entropy = 1.360964047443681, samples = 10, value = [1 5 4])
│   │       ├── [< 15.5] cylinders (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│   │       │   ├── [< 5.0] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│   │       │   └── [>= 5.0] US. (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   │       └── [>= 15.5] Europe. (entropy = 1.0, samples = 2, value = [1 0 1])
│   └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
└── [c] US. (entropy = 0, samples = 87, value = [ 0  0 87])
              precision    recall  f1-score   support

     Europe.       0.57      0.40      0.47        10
      Japan.       0.65      0.72      0.68        18
         US.       0.92      0.94      0.93        36

    accuracy                           0.80        64
   macro avg       0.71      0.69      0.70        64
weighted avg       0.79      0.80      0.79        64
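
The discretization of `cubicinches` used in the test above maps each value to the number of cut points that lie strictly below it. As a sketch of an alternative, `np.searchsorted` with `side='left'` returns the same bin index for a whole array at once. The sample values below are made up; `find` is rewritten iteratively here only so the two approaches can be compared side by side.

```python
import numpy as np

def find(a, n):
    """Iterative form of the helper above: count the leading cut points below n."""
    r = 0
    while r < len(a) and a[r] < n:
        r += 1
    return r

cuts = [121.5, 191.0]
values = np.array([97.0, 121.5, 140.0, 191.0, 350.0])   # made-up cubicinches values

bins = np.searchsorted(cuts, values, side='left')       # number of cut points < value
labels = np.array(['a', 'b', 'c'])[bins]

print(labels.tolist())                                  # ['a', 'a', 'b', 'b', 'c']
print([find(cuts, v) for v in values] == bins.tolist()) # True
```

Note that a value equal to a cut point falls into the lower bin in both versions, whereas the tree's own continuous splits send `x >= threshold` to the upper branch.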

Tests on the Iris dataset

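Three-class classification on the default label attribute. The data is split into training and test sets and the maximum tree depth is limited to 3. The code is as follows: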

# iris dataset
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt41 = ID3(3)
dt41.fit(X_train, y_train)
dt41.visualize(iris.feature_names) # print the learned tree
y_pred = dt41.predict(X_test)
print(classification_report(y_test, y_pred))


petal length (cm) (entropy = 1.5807197138422102, samples = 112, value = [34 41 37])
├── [< 2.45] 0 (entropy = 0, samples = 34, value = [34  0  0])
└── [>= 2.45] petal width (cm) (entropy = 0.9981021327390103, samples = 78, value = [ 0 41 37])
    ├── [< 1.75] petal length (cm) (entropy = 0.4394969869215134, samples = 44, value = [ 0 40  4])
    │   ├── [< 4.95] 1 (entropy = 0.17203694935311378, samples = 39, value = [ 0 38  1])
    │   └── [>= 4.95] 2 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
    └── [>= 1.75] petal length (cm) (entropy = 0.19143325481419343, samples = 34, value = [ 0  1 33])
        ├── [< 4.85] 2 (entropy = 0.9182958340544896, samples = 3, value = [0 1 2])
        └── [>= 4.85] 2 (entropy = 0, samples = 31, value = [ 0  0 31])

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        13

    accuracy                           1.00        38
   macro avg       1.00      1.00      1.00        38
weighted avg       1.00      1.00      1.00        38
