赞
踩
环境:Windows 10,Python 3.7
首先需要安装Graphviz,这里我们使用的是graphviz-2.38.msi,安装在D:\Program Files (x86)\Graphviz2.38。
代码:
import os
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz
# 系统环境变量添加Graphviz安装路径,以便下面代码可以用dot命令
os.environ["PATH"] += os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin'
iris = load_iris()
X, y = iris.data, iris.target
model = RandomForestClassifier(n_estimators=3, max_features=1)
model.fit(X, y)
# 循环打印每棵树
for idx, estimator in enumerate(model.estimators_):
# 导出dot文件
export_graphviz(estimator,
out_file='tree{}.dot'.format(idx),
feature_names=iris.feature_names,
class_names=iris.target_names,
rounded=True,
proportion=False,
precision=2,
filled=True)
# 转换为png文件
os.system('dot -Tpng tree{}.dot -o tree{}.png'.format(idx, idx))
第一棵树:
第二棵树:
第三棵树:
树结果解释:
1、有多少种类别,整棵树就有多少种颜色,比如我们这里有setosa、versicolor、virginica三个类别,颜色对应是黄、绿、紫,Gini指数越小,该节点颜色越深。
2、value表示当前节点三种类别的样本有多少,比如下面第一棵树的根节点,value = [59,45,48],表示setosa有59个样本,versicolor有45个样本,virginica有48个样本。
3、class表示当前那个类别的样本最多,比如下面第一棵树的根节点,class = setosa,可以看到当前节点它的样本数是最多的。
4、为什么三棵树根节点每个类别的样本数不一样的?Iris数据集不是每个类别都是50个样本吗?-> 务必记得随机森林每棵树所用的数据集都是有放回抽样得到的!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。