当前位置:   article > 正文

sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)_sklearn randomforestclassifier 规则提炼

sklearn randomforestclassifier 规则提炼

决策树

import os
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pydotplus


def tree_to_code(tree, feature_names, target_names=None, out_path='code.txt'):
    """Append the decision rules of a fitted tree to a text file.

    The rules are rendered as Python-like pseudo-code: a ``def tree(...)``
    header followed by nested ``if``/``else`` lines, with one ``return`` line
    per leaf showing the leaf's ``tree_.value`` entry and the name of the
    class with the largest entry.

    Args:
        tree: Fitted estimator exposing a ``tree_`` attribute. NOTE: inside
            this function the parameter shadows the ``sklearn.tree`` module
            imported at the top of the file.
        feature_names: Feature names, indexed by the feature ids stored in
            ``tree_.feature``.
        target_names: Class names indexed by class position. When omitted,
            the module-level ``target_name`` list is used, which is what the
            function always read before this parameter existed.
        out_path: File the rules are appended to (append mode, so repeated
            calls accumulate in the same file).
    """
    tree_ = tree.tree_
    if target_names is None:
        # Backward-compatible fallback to the script-level class-name list.
        target_names = target_name

    # One entry per *node*: the name of the feature the node splits on, or a
    # placeholder for leaves (which have no split feature).
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print('feature_name:', feature_name)

    # Collect every output line first so the file is opened exactly once.
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        """Emit the rule lines for ``node`` and everything below it."""
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: left child is the "<= threshold" branch,
            # right child is the "> threshold" branch.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: report the stored value array and the arg-max class name.
            value = tree_.value[node]
            lines.append("{}return {} -- {}".format(
                indent, value, target_names[np.argmax(value)]))

    recurse(0, 1)

    with open(out_path, 'a', encoding='utf-8') as f:
        f.write("\n".join(lines))
        f.write("\n")

# Load the Titanic data from ``ta.txt`` in the current working directory.
pwd = os.getcwd()
titanic = pd.read_csv(os.path.join(pwd, 'ta.txt'))

# Fill missing ages with the mean age. Assign the result back instead of
# calling ``fillna(..., inplace=True)`` on the selected column: the chained
# in-place form is deprecated in recent pandas and does not update the frame
# when copy-on-write is enabled.
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())

# Features used for splitting, and the target column.
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']
labels = [0, 1]
# Display names for the classes in ``labels`` order (spelling kept as in the
# original script, since it appears in the generated reports and graphs).
target_name = ["deid", "survived"]

# test_size is the fraction of the whole data set held out for testing.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

# sparse=False makes the vectorizer return a dense array.
dt = DictVectorizer(sparse=False)

# "records" is the full orient name; the abbreviation "record" is rejected by
# pandas >= 2.0.
x_train = dt.fit_transform(x_train.to_dict(orient="records"))

# Only transform the test set: refitting the vectorizer on test data could
# produce a different column layout from the one the model was trained on.
x_test = dt.transform(x_test.to_dict(orient="records"))

# Column names in the exact order of the vectorized matrix. For purely
# numeric columns this equals the sorted list ["age", "pclass", "sex"].
fea_name = list(dt.feature_names_)

# Decision tree with default settings. Options worth experimenting with:
#   class_weight='balanced'  - compensate for class imbalance
#   criterion='entropy'      - split on information gain instead of gini
#   max_features='sqrt'
dtc = DecisionTreeClassifier()

dtc.fit(x_train, y_train)

dt_predict = dtc.predict(x_test)

# Write the learned rules to code.txt.
tree_to_code(dtc, fea_name)

print(dtc.score(x_test, y_test))

print(classification_report(y_test, dt_predict, labels=labels, target_names=target_name))

# Confusion matrix of the decision-tree predictions, printed and plotted.
# Uses ``dt_predict``; the random-forest predictions (``rfc_y_predict``) do
# not exist yet at this point in the script.
confmat = confusion_matrix(y_true=y_test, y_pred=dt_predict, labels=labels)
print(confmat)

fig, ax = plt.subplots(figsize=(3, 3))
ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
# Write each count into the centre of its cell (row = true, column = predicted).
for i in range(confmat.shape[0]):
    for j in range(confmat.shape[1]):
        ax.text(x=j, y=i, s=confmat[i, j], va='center', ha='center')

plt.xticks(range(len(confmat)), labels)
plt.yticks(range(len(confmat)), labels)
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.savefig('confusion_matrix.png')
plt.show()

# Render the fitted decision tree to a PDF via Graphviz.

# When run from PyCharm the Graphviz executables may not be found; append
# their bin directory to PATH (the literal below is a placeholder to replace).
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + 'graphviz的bin路径'

# out_file=None makes export_graphviz return the DOT source as a string.
dot_data = tree.export_graphviz(
    dtc,
    out_file=None,
    class_names=target_name,
    feature_names=fea_name,
    rounded=True,
    filled=True,
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("descion_tree.pdf")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117

随机森林

# Data loading and preprocessing are the same as in the decision-tree section.
# Train and evaluate a random forest.

# n_estimators is passed explicitly. NOTE(review): the original note cites
# versions "2.0"/"2.02"; presumably scikit-learn 0.20 (warning when omitted)
# and 0.22 (default became 100) were meant - confirm against the changelog.
rfc = RandomForestClassifier(max_depth=6, n_estimators=100)

rfc.fit(X=x_train, y=y_train)

rfc_y_predict = rfc.predict(X=x_test)

# Mean accuracy on the held-out test set.
print(rfc.score(X=x_test, y=y_test))

print(classification_report(y_true=y_test, y_pred=rfc_y_predict,
                            labels=labels, target_names=target_name))

# Write one PDF per tree of the forest into <pwd>/forest/.
# makedirs(exist_ok=True) creates the directory only when needed, without the
# separate existence check. The working directory is still switched to it,
# as before, so the relative file names below land inside it.
forest_dir = os.path.join(pwd, 'forest')
os.makedirs(forest_dir, exist_ok=True)
os.chdir(forest_dir)

for idx, estimator in enumerate(rfc.estimators_):
    # Build the DOT source for this tree and render it as forest_<idx>.pdf.
    filename = 'forest_{}.pdf'.format(idx)
    dot_data = tree.export_graphviz(estimator,
                                    out_file=None,
                                    feature_names=fea_name,
                                    class_names=target_name,
                                    rounded=True,
                                    proportion=False,
                                    precision=2,
                                    filled=True)

    graph = pydotplus.graph_from_dot_data(dot_data)

    graph.write_pdf(filename)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

本地数据文件 ta.txt 的原始内容如下:
其中性别(sex)和舱位等级(pclass)两列已预先转换为数值。

提取的规则代码块

def tree(age, pclass, sex):
  if sex <= 1.5:
    if age <= 10.0:
      if pclass <= 2.5:
        return [[ 0. 12.]] deid
      else:  # if pclass > 2.5
        if age <= 0.583299994468689:
          return [[1. 0.]] survived
        else:  # if age > 0.583299994468689
          if age <= 4.0:
            return [[0. 3.]] deid
          else:  # if age > 4.0
            if age <= 7.5:
              return [[2. 0.]] survived
            else:  # if age > 7.5
              return [[1. 2.]] deid
    else:  # if age > 10.0
      if pclass <= 1.5:
        if age <= 54.5:
          if age <= 29.0:
            if age <= 17.5:
              return [[0. 2.]] deid
            else:  # if age > 17.5
              if age <= 24.5:
                if age <= 20.0:
                  return [[2. 0.]] survived
                else:  # if age > 20.0
                  if age <= 23.5:
                    if age <= 21.5:
                      return [[0. 1.]] deid
                    else:  # if age > 21.5
                      if age <= 22.5:
                        return [[1. 0.]] survived
                      else:  # if age > 22.5
                        return [[0. 1.]] deid
                  else:  # if age > 23.5
                    return [[2. 0.]] survived
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 2.]] deid
                else:  # if age > 26.0
                  if age <= 27.5:
                    return [[0. 1.]] deid
                  else:  # if age > 27.5
                    return [[1. 2.]] deid
          else:  # if age > 29.0
            if age <= 33.5:
              if age <= 31.09709072113037:
                return [[4. 0.]] survived
              else:  # if age > 31.09709072113037
                if age <= 32.09709072113037:
                  return [[29. 10.]] survived
                else:  # if age > 32.09709072113037
                  return [[2. 0.]] survived
            else:  # if age > 33.5
              if age <= 36.5:
                if age <= 35.5:
                  return [[0. 2.]] deid
                else:  # if age > 35.5
                  return [[1. 4.]] deid
              else:  # if age > 36.5
                if age <= 47.5:
                  if age <= 38.5:
                    if age <= 37.5:
                      return [[1. 1.]] survived
                    else:  # if age > 37.5
                      return [[1. 1.]] survived
                  else:  # if age > 38.5
                    if age <= 45.5:
                      if age <= 41.5:
                        if age <= 39.5:
                          return [[3. 1.]] survived
                        else:  # if age > 39.5
                          return [[2. 0.]] survived
                      else:  # if age > 41.5
                        if age <= 43.0:
                          return [[2. 1.]] survived
                        else:  # if age > 43.0
                          if age <= 44.5:
                            return [[1. 0.]] survived
                          else:  # if age > 44.5
                            return [[3. 1.]] survived
                    else:  # if age > 45.5
                      if age <= 46.5:
                        return [[5. 0.]] survived
                      else:  # if age > 46.5
                        return [[3. 1.]] survived
                else:  # if age > 47.5
                  if age <= 48.5:
                    return [[1. 2.]] deid
                  else:  # if age > 48.5
                    if age <= 51.5:
                      if age <= 49.5:
                        return [[2. 1.]] survived
                      else:  # if age > 49.5
                        return [[3. 0.]] survived
                    else:  # if age > 51.5
                      if age <= 53.0:
                        return [[1. 1.]] survived
                      else:  # if age > 53.0
                        return [[1. 1.]] survived
        else:  # if age > 54.5
          return [[14.  0.]] survived
      else:  # if pclass > 1.5
        if age <= 29.5:
          if age <= 25.5:
            if age <= 23.5:
              if age <= 18.5:
                return [[17.  0.]] survived
              else:  # if age > 18.5
                if age <= 19.5:
                  if pclass <= 2.5:
                    return [[1. 0.]] survived
                  else:  # if pclass > 2.5
                    return [[4. 1.]] survived
                else:  # if age > 19.5
                  if age <= 20.5:
                    return [[8. 0.]] survived
                  else:  # if age > 20.5
                    if age <= 22.5:
                      if age <= 21.5:
                        if pclass <= 2.5:
                          return [[5. 0.]] survived
                        else:  # if pclass > 2.5
                          return [[4. 1.]] survived
                      else:  # if age > 21.5
                        if pclass <= 2.5:
                          return [[3. 1.]] survived
                        else:  # if pclass > 2.5
                          return [[3. 0.]] survived
                    else:  # if age > 22.5
                      return [[7. 0.]] survived
            else:  # if age > 23.5
              if age <= 24.5:
                if pclass <= 2.5:
                  return [[1. 1.]] survived
                else:  # if pclass > 2.5
                  return [[6. 1.]] survived
              else:  # if age > 24.5
                if pclass <= 2.5:
                  return [[4. 0.]] survived
                else:  # if pclass > 2.5
                  return [[4. 1.]] survived
          else:  # if age > 25.5
            return [[23.  0.]] survived
        else:  # if age > 29.5
          if age <= 45.5:
            if age <= 44.5:
              if age <= 32.5:
                if age <= 31.59709072113037:
                  if pclass <= 2.5:
                    if age <= 30.59709072113037:
                      return [[8. 0.]] survived
                    else:  # if age > 30.59709072113037
                      return [[32.  4.]] survived
                  else:  # if pclass > 2.5
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[220.  32.]] survived
                else:  # if age > 31.59709072113037
                  if pclass <= 2.5:
                    return [[3. 2.]] survived
                  else:  # if pclass > 2.5
                    return [[5. 0.]] survived
              else:  # if age > 32.5
                if age <= 35.5:
                  return [[11.  0.]] survived
                else:  # if age > 35.5
                  if age <= 36.5:
                    if pclass <= 2.5:
                      return [[1. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[0. 1.]] deid
                  else:  # if age > 36.5
                    if pclass <= 2.5:
                      if age <= 40.5:
                        return [[3. 0.]] survived
                      else:  # if age > 40.5
                        if age <= 41.5:
                          return [[1. 1.]] survived
                        else:  # if age > 41.5
                          return [[3. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[11.  0.]] survived
            else:  # if age > 44.5
              if pclass <= 2.5:
                return [[2. 0.]] survived
              else:  # if pclass > 2.5
                return [[1. 1.]] survived
          else:  # if age > 45.5
            return [[13.  0.]] survived
  else:  # if sex > 1.5
    if pclass <= 2.5:
      if pclass <= 1.5:
        if age <= 62.5:
          if age <= 36.5:
            if age <= 35.5:
              if age <= 24.5:
                return [[ 0. 19.]] deid
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 0.]] survived
                else:  # if age > 26.0
                  if age <= 31.09709072113037:
                    return [[0. 6.]] deid
                  else:  # if age > 31.09709072113037
                    if age <= 32.09709072113037:
                      return [[ 1. 23.]] deid
                    else:  # if age > 32.09709072113037
                      return [[0. 5.]] deid
            else:  # if age > 35.5
              return [[1. 3.]] deid
          else:  # if age > 36.5
            return [[ 0. 31.]] deid
        else:  # if age > 62.5
          if age <= 63.5:
            return [[1. 1.]] survived
          else:  # if age > 63.5
            return [[0. 1.]] deid
      else:  # if pclass > 1.5
        if age <= 17.5:
          return [[0. 9.]] deid
        else:  # if age > 17.5
          if age <= 22.5:
            if age <= 21.5:
              if age <= 18.5:
                return [[1. 3.]] deid
              else:  # if age > 18.5
                return [[0. 5.]] deid
            else:  # if age > 21.5
              return [[2. 0.]] survived
          else:  # if age > 22.5
            if age <= 26.5:
              return [[0. 5.]] deid
            else:  # if age > 26.5
              if age <= 27.5:
                return [[1. 1.]] survived
              else:  # if age > 27.5
                if age <= 29.5:
                  return [[0. 5.]] deid
                else:  # if age > 29.5
                  if age <= 30.5:
                    return [[1. 2.]] deid
                  else:  # if age > 30.5
                    if age <= 46.0:
                      if age <= 43.0:
                        if age <= 39.0:
                          if age <= 37.0:
                            if age <= 31.59709072113037:
                              if age <= 31.09709072113037:
                                return [[0. 2.]] deid
                              else:  # if age > 31.09709072113037
                                return [[ 3. 17.]] deid
                            else:  # if age > 31.59709072113037
                              return [[0. 9.]] deid
                          else:  # if age > 37.0
                            return [[1. 0.]] survived
                        else:  # if age > 39.0
                          return [[0. 4.]] deid
                      else:  # if age > 43.0
                        return [[1. 0.]] survived
                    else:  # if age > 46.0
                      return [[0. 5.]] deid
    else:  # if pclass > 2.5
      if age <= 19.5:
        if age <= 12.0:
          if age <= 5.5:
            if age <= 1.0833500027656555:
              return [[0. 1.]] deid
            else:  # if age > 1.0833500027656555
              if age <= 3.5:
                return [[1. 0.]] survived
              else:  # if age > 3.5
                return [[0. 1.]] deid
          else:  # if age > 5.5
            return [[2. 0.]] survived
        else:  # if age > 12.0
          if age <= 17.5:
            if age <= 15.5:
              return [[0. 1.]] deid
            else:  # if age > 15.5
              if age <= 16.5:
                return [[1. 3.]] deid
              else:  # if age > 16.5
                return [[0. 1.]] deid
          else:  # if age > 17.5
            if age <= 18.5:
              return [[2. 3.]] deid
            else:  # if age > 18.5
              return [[0. 1.]] deid
      else:  # if age > 19.5
        if age <= 21.5:
          return [[3. 0.]] survived
        else:  # if age > 21.5
          if age <= 23.5:
            if age <= 22.5:
              return [[1. 2.]] deid
            else:  # if age > 22.5
              return [[0. 1.]] deid
          else:  # if age > 23.5
            if age <= 32.5:
              if age <= 31.59709072113037:
                if age <= 25.5:
                  return [[1. 1.]] survived
                else:  # if age > 25.5
                  if age <= 29.0:
                    return [[2. 0.]] survived
                  else:  # if age > 29.0
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[75. 40.]] survived
              else:  # if age > 31.59709072113037
                return [[1. 0.]] survived
            else:  # if age > 32.5
              if age <= 37.0:
                return [[0. 3.]] deid
              else:  # if age > 37.0
                if age <= 42.5:
                  return [[2. 0.]] survived
                else:  # if age > 42.5
                  return [[1. 1.]] survived

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324

混淆矩阵

结果图
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/285477
推荐阅读
  

闽ICP备14008679号