赞
踩
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydotplus
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import _tree


def tree_to_code(tree, feature_names, class_names=None, out_path='code.txt'):
    """Write the decision rules of a fitted tree as pseudo-Python text.

    The rules are rendered as nested ``if``/``else`` statements on the split
    thresholds, one leaf per ``return`` line, and appended to ``out_path``.
    Because the file is opened in append mode, repeated calls accumulate.

    Args:
        tree: A fitted estimator exposing the low-level ``tree_`` attribute.
        feature_names: Column names, in the same order as the columns of the
            matrix the estimator was fitted on.
        class_names: Display names indexed by class position. When omitted,
            the module-level ``target_name`` list is used.
        out_path: File the rule text is appended to.
    """
    if class_names is None:
        # Backward-compatible fallback: the function used to read this
        # global directly.
        class_names = target_name

    tree_ = tree.tree_
    # One entry per node: the name of the feature the node splits on, or a
    # marker for leaves (which have no split feature).
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print('feature_name:', feature_name)

    # Collect every line first so the file is opened exactly once.
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = " " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: left child is the "<= threshold" branch.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else: # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: report the per-class values and the majority class.
            value = tree_.value[node]
            lines.append(
                "{}return {} -- {}".format(indent, value, class_names[np.argmax(value)])
            )

    recurse(0, 1)

    with open(out_path, 'a') as f:
        f.write("\n".join(lines) + "\n")


pwd = os.getcwd()
titanic = pd.read_csv(pwd + '/ta.txt')
# Fill missing ages with the mean age. Assigning the result back avoids the
# chained in-place fill, which is unreliable under pandas copy-on-write.
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())

# Features used for splitting, and the target column.
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']
labels = [0, 1]
target_name = ["deid", "survived"]
# Sorted so the names line up with the DictVectorizer output columns
# (the vectorizer sorts feature names; this relies on every feature being
# numeric so no "name=value" one-hot columns are generated).
fea_name = ["sex", "age", "pclass"]
fea_name.sort()

# test_size is the fraction of rows held out for testing.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

dt = DictVectorizer(sparse=False)  # sparse=False: return a dense array
x_train = dt.fit_transform(x_train.to_dict(orient="records"))
# Only transform the test split: refitting here could produce a different
# column layout from the one the model was trained on.
x_test = dt.transform(x_test.to_dict(orient="records"))

# Decision tree with default hyper-parameters. Options worth trying:
# class_weight='balanced', criterion='entropy', max_features='sqrt'.
dtc = DecisionTreeClassifier()
dtc.fit(x_train, y_train)
dt_predict = dtc.predict(x_test)

tree_to_code(dtc, fea_name)  # extract the decision rules into code.txt

print(dtc.score(x_test, y_test))
print(classification_report(y_test, dt_predict, labels=labels, target_names=target_name))

# Confusion matrix for the decision-tree predictions, printed and plotted.
confmat = confusion_matrix(y_true=y_test, y_pred=dt_predict, labels=labels)
print(confmat)

fig, ax = plt.subplots(figsize=(3, 3))
ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
for i in range(confmat.shape[0]):
    for j in range(confmat.shape[1]):
        ax.text(x=j, y=i, s=confmat[i, j], va='center', ha='center')
plt.xticks(range(len(confmat)), labels)
plt.yticks(range(len(confmat)), labels)
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.savefig('confusion_matrix.png')
plt.show()

# Render the fitted tree to PDF.
# The string below is a placeholder: replace it with the real Graphviz "bin"
# directory if the executables are not already on PATH (e.g. under PyCharm).
os.environ["PATH"] += os.pathsep + 'graphviz的bin路径'
dot_data = tree.export_graphviz(
    dtc,
    out_file=None,
    feature_names=fea_name,
    class_names=target_name,
    filled=True,
    rounded=True,
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("descion_tree.pdf")
# Data loading and preprocessing are the same as in the decision-tree script:
# x_train, x_test, y_train, y_test, labels, target_name, fea_name and pwd
# are expected to be defined already.

# Random forest.
# n_estimators is set explicitly: scikit-learn 0.20 warns when it is omitted,
# and the default changed to 100 in scikit-learn 0.22.
rfc = RandomForestClassifier(n_estimators=100, max_depth=6)
rfc.fit(x_train, y_train)
rfc_y_predict = rfc.predict(x_test)
print(rfc.score(x_test, y_test))
print(classification_report(y_test, rfc_y_predict, labels=labels, target_names=target_name))

# Create the output directory if needed and make it the working directory,
# so the relative PDF names below land inside it.
forest_dir = pwd + '/forest/'
os.makedirs(forest_dir, exist_ok=True)
os.chdir(forest_dir)

# Render every tree of the forest to its own PDF file.
for idx, estimator in enumerate(rfc.estimators_):
    filename = 'forest_' + str(idx) + '.pdf'
    dot_data = tree.export_graphviz(
        estimator,
        out_file=None,
        feature_names=fea_name,
        class_names=target_name,
        rounded=True,
        proportion=False,
        precision=2,
        filled=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(filename)
本地数据文件 ta.txt 的原始内容:
其中性别(sex)和舱位等级(pclass)两列均已转换为数值编码。
提取的规则代码块
def tree(age, pclass, sex): if sex <= 1.5: if age <= 10.0: if pclass <= 2.5: return [[ 0. 12.]] deid else: # if pclass > 2.5 if age <= 0.583299994468689: return [[1. 0.]] survived else: # if age > 0.583299994468689 if age <= 4.0: return [[0. 3.]] deid else: # if age > 4.0 if age <= 7.5: return [[2. 0.]] survived else: # if age > 7.5 return [[1. 2.]] deid else: # if age > 10.0 if pclass <= 1.5: if age <= 54.5: if age <= 29.0: if age <= 17.5: return [[0. 2.]] deid else: # if age > 17.5 if age <= 24.5: if age <= 20.0: return [[2. 0.]] survived else: # if age > 20.0 if age <= 23.5: if age <= 21.5: return [[0. 1.]] deid else: # if age > 21.5 if age <= 22.5: return [[1. 0.]] survived else: # if age > 22.5 return [[0. 1.]] deid else: # if age > 23.5 return [[2. 0.]] survived else: # if age > 24.5 if age <= 26.0: return [[1. 2.]] deid else: # if age > 26.0 if age <= 27.5: return [[0. 1.]] deid else: # if age > 27.5 return [[1. 2.]] deid else: # if age > 29.0 if age <= 33.5: if age <= 31.09709072113037: return [[4. 0.]] survived else: # if age > 31.09709072113037 if age <= 32.09709072113037: return [[29. 10.]] survived else: # if age > 32.09709072113037 return [[2. 0.]] survived else: # if age > 33.5 if age <= 36.5: if age <= 35.5: return [[0. 2.]] deid else: # if age > 35.5 return [[1. 4.]] deid else: # if age > 36.5 if age <= 47.5: if age <= 38.5: if age <= 37.5: return [[1. 1.]] survived else: # if age > 37.5 return [[1. 1.]] survived else: # if age > 38.5 if age <= 45.5: if age <= 41.5: if age <= 39.5: return [[3. 1.]] survived else: # if age > 39.5 return [[2. 0.]] survived else: # if age > 41.5 if age <= 43.0: return [[2. 1.]] survived else: # if age > 43.0 if age <= 44.5: return [[1. 0.]] survived else: # if age > 44.5 return [[3. 1.]] survived else: # if age > 45.5 if age <= 46.5: return [[5. 0.]] survived else: # if age > 46.5 return [[3. 1.]] survived else: # if age > 47.5 if age <= 48.5: return [[1. 
2.]] deid else: # if age > 48.5 if age <= 51.5: if age <= 49.5: return [[2. 1.]] survived else: # if age > 49.5 return [[3. 0.]] survived else: # if age > 51.5 if age <= 53.0: return [[1. 1.]] survived else: # if age > 53.0 return [[1. 1.]] survived else: # if age > 54.5 return [[14. 0.]] survived else: # if pclass > 1.5 if age <= 29.5: if age <= 25.5: if age <= 23.5: if age <= 18.5: return [[17. 0.]] survived else: # if age > 18.5 if age <= 19.5: if pclass <= 2.5: return [[1. 0.]] survived else: # if pclass > 2.5 return [[4. 1.]] survived else: # if age > 19.5 if age <= 20.5: return [[8. 0.]] survived else: # if age > 20.5 if age <= 22.5: if age <= 21.5: if pclass <= 2.5: return [[5. 0.]] survived else: # if pclass > 2.5 return [[4. 1.]] survived else: # if age > 21.5 if pclass <= 2.5: return [[3. 1.]] survived else: # if pclass > 2.5 return [[3. 0.]] survived else: # if age > 22.5 return [[7. 0.]] survived else: # if age > 23.5 if age <= 24.5: if pclass <= 2.5: return [[1. 1.]] survived else: # if pclass > 2.5 return [[6. 1.]] survived else: # if age > 24.5 if pclass <= 2.5: return [[4. 0.]] survived else: # if pclass > 2.5 return [[4. 1.]] survived else: # if age > 25.5 return [[23. 0.]] survived else: # if age > 29.5 if age <= 45.5: if age <= 44.5: if age <= 32.5: if age <= 31.59709072113037: if pclass <= 2.5: if age <= 30.59709072113037: return [[8. 0.]] survived else: # if age > 30.59709072113037 return [[32. 4.]] survived else: # if pclass > 2.5 if age <= 30.59709072113037: return [[1. 1.]] survived else: # if age > 30.59709072113037 return [[220. 32.]] survived else: # if age > 31.59709072113037 if pclass <= 2.5: return [[3. 2.]] survived else: # if pclass > 2.5 return [[5. 0.]] survived else: # if age > 32.5 if age <= 35.5: return [[11. 0.]] survived else: # if age > 35.5 if age <= 36.5: if pclass <= 2.5: return [[1. 0.]] survived else: # if pclass > 2.5 return [[0. 1.]] deid else: # if age > 36.5 if pclass <= 2.5: if age <= 40.5: return [[3. 
0.]] survived else: # if age > 40.5 if age <= 41.5: return [[1. 1.]] survived else: # if age > 41.5 return [[3. 0.]] survived else: # if pclass > 2.5 return [[11. 0.]] survived else: # if age > 44.5 if pclass <= 2.5: return [[2. 0.]] survived else: # if pclass > 2.5 return [[1. 1.]] survived else: # if age > 45.5 return [[13. 0.]] survived else: # if sex > 1.5 if pclass <= 2.5: if pclass <= 1.5: if age <= 62.5: if age <= 36.5: if age <= 35.5: if age <= 24.5: return [[ 0. 19.]] deid else: # if age > 24.5 if age <= 26.0: return [[1. 0.]] survived else: # if age > 26.0 if age <= 31.09709072113037: return [[0. 6.]] deid else: # if age > 31.09709072113037 if age <= 32.09709072113037: return [[ 1. 23.]] deid else: # if age > 32.09709072113037 return [[0. 5.]] deid else: # if age > 35.5 return [[1. 3.]] deid else: # if age > 36.5 return [[ 0. 31.]] deid else: # if age > 62.5 if age <= 63.5: return [[1. 1.]] survived else: # if age > 63.5 return [[0. 1.]] deid else: # if pclass > 1.5 if age <= 17.5: return [[0. 9.]] deid else: # if age > 17.5 if age <= 22.5: if age <= 21.5: if age <= 18.5: return [[1. 3.]] deid else: # if age > 18.5 return [[0. 5.]] deid else: # if age > 21.5 return [[2. 0.]] survived else: # if age > 22.5 if age <= 26.5: return [[0. 5.]] deid else: # if age > 26.5 if age <= 27.5: return [[1. 1.]] survived else: # if age > 27.5 if age <= 29.5: return [[0. 5.]] deid else: # if age > 29.5 if age <= 30.5: return [[1. 2.]] deid else: # if age > 30.5 if age <= 46.0: if age <= 43.0: if age <= 39.0: if age <= 37.0: if age <= 31.59709072113037: if age <= 31.09709072113037: return [[0. 2.]] deid else: # if age > 31.09709072113037 return [[ 3. 17.]] deid else: # if age > 31.59709072113037 return [[0. 9.]] deid else: # if age > 37.0 return [[1. 0.]] survived else: # if age > 39.0 return [[0. 4.]] deid else: # if age > 43.0 return [[1. 0.]] survived else: # if age > 46.0 return [[0. 
5.]] deid else: # if pclass > 2.5 if age <= 19.5: if age <= 12.0: if age <= 5.5: if age <= 1.0833500027656555: return [[0. 1.]] deid else: # if age > 1.0833500027656555 if age <= 3.5: return [[1. 0.]] survived else: # if age > 3.5 return [[0. 1.]] deid else: # if age > 5.5 return [[2. 0.]] survived else: # if age > 12.0 if age <= 17.5: if age <= 15.5: return [[0. 1.]] deid else: # if age > 15.5 if age <= 16.5: return [[1. 3.]] deid else: # if age > 16.5 return [[0. 1.]] deid else: # if age > 17.5 if age <= 18.5: return [[2. 3.]] deid else: # if age > 18.5 return [[0. 1.]] deid else: # if age > 19.5 if age <= 21.5: return [[3. 0.]] survived else: # if age > 21.5 if age <= 23.5: if age <= 22.5: return [[1. 2.]] deid else: # if age > 22.5 return [[0. 1.]] deid else: # if age > 23.5 if age <= 32.5: if age <= 31.59709072113037: if age <= 25.5: return [[1. 1.]] survived else: # if age > 25.5 if age <= 29.0: return [[2. 0.]] survived else: # if age > 29.0 if age <= 30.59709072113037: return [[1. 1.]] survived else: # if age > 30.59709072113037 return [[75. 40.]] survived else: # if age > 31.59709072113037 return [[1. 0.]] survived else: # if age > 32.5 if age <= 37.0: return [[0. 3.]] deid else: # if age > 37.0 if age <= 42.5: return [[2. 0.]] survived else: # if age > 42.5 return [[1. 1.]] survived
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。