赞
踩
我在另一篇博文机器学习(3) K近邻算法(KNN)介绍及C++实现中介绍了K近邻算法及KD树的实现方法,博文编写过程中需要显式绘制二叉树将其表示出来。初始方法是使用C++生成KD树,并根据graphviz的dot语言逐行编写KD树。通过查阅相关文献,发现使用Python绘制KD树的过程并不繁琐,于是本文介绍使用Python绘制graphviz二叉查找树的图形。另外,介绍博文中是绘制二维平面上KD树、收敛过程如何表示的图形化绘制方法
graphviz是一种便于绘制流程图、树形结构等的图形可视化软件。掌握基础的脚本语言就可以轻松绘制属于自己的流程图、二叉树图等内容。
安装graphviz流程:打开graphviz下载链接,依据网页提示选择属于自己的平台安装包。我安装的是windows10下的stable_windows_10_msbuild_Release_Win32_graphviz-2.46.0-win32.zip。下载完成后解压到C:/Software/graphviz等自己习惯的路径下,将C:/software/graphviz/bin加入到系统环境变量中,重启电脑以配置graphviz环境变量。使用Python绘制graphviz流程图,需要在安装python3环境后,在命令行pip install graphviz即可。
安装python3环境流程:以windows10为例,打开Python3-Windows-下载地址,下载最新版本安装包或适合自己的版本安装包,安装到本地过程中记得配置环境变量Path,不再赘述。
我将着重讲述使用graphviz绘制二叉树涉及到的语法知识,更详细的语法知识参见graphviz官网说明文档。我将讲解两方面知识,第一是使用命令行编译dot文件,第二是使用python直接生成.gv文件。
首先介绍命令行下如何使用graphviz语法编写一个二叉树。在自己的路径下生成一个demo.dot文档,文档中内容如下:
// demo.dot
digraph {
node [shape=circle]
1 [label="(7,2)"]
2 [label="(5,4)"]
3 [label="(2,3)"]
4 [label="(6,6)", style="invis"]
1 -> 2
1 -> 4
1 -> 3 [style="invis"]
}
我将文档保存在了C:\File\demo.dot路径下。在命令行中执行:
cd C:\File
dot -Tpng demo.png -o demo.dot
就会在路径C:\File下生成demo.png,图像如图所示。
digraph G{
...
}
表示这是一个有向图,图中的边都带箭头。
...
node [shape='circle']
...
表示图中的节点都是圆形。
1 [label='(7,2)']
声明一个节点,节点记为1,其内容为字符串"(7,2)"
4 [label="(6,6)", style="invis"]
声明一个节点,节点记为4,其内容为字符串"(6,6)",并且这个节点在图中不显示。
1 -> 2
声明一条从1指向2的边。
1 -> 3 [style="invis"]
声明一条从1指向3的边,并且这条边在图中不显示。
据此,为了保证二叉树有序、对齐显示,我们在绘制二叉树的过程中,左右子树中间添加一个不可见的边和不可见节点,实现图形的对齐效果。如果使用C++强行绘制graphviz,就根据.dot文件的语法格式,向文件流中采用先根遍历的方法书写dot文本,使用文件流记得#include <fstream>
。C++实现方法如下:
#include <fstream> void drawKDTree(node* root, string path) { // 等价于先根序列。 // path = "tree.dot" ofstream fout(path); string tab = " "; fout << "digraph G{" << endl; fout << tab << "node[shape=circle]" << endl; int N = data.size()+1; preOrderDraw(root, fout, N); fout << "}" << endl; fout.close(); } void preOrderDraw(node* root, ofstream& fout, int& nullIndex) { string tab = " "; // 先根序列,绘制当前节点的内容。 fout << tab << root->index << "[group=" << root->index << ", label=\"(" << data[root->index][0]; for (int i = 1; i < n; i++) { fout << "," << data[root->index][i]; } fout << ")\"]" << endl; // 绘制左节点的内容 if (root->left) { // 当左节点非空的时候,需要绘制一条伸向左节点的有向边。 fout << tab << root->index << " -> " << root->left->index << endl; // 递归遍历左子树。 preOrderDraw(root->left, fout, nullIndex); } else { // 左节点为空的时候,为了保证图形的整洁有序,绘制左侧空节点占位。边与节点都为不可见[style=invis]。 fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl; fout << tab << "_" << nullIndex++ << " [style=invis]" << endl; } // 为了二叉树的图形可以相当漂亮美观且对齐,设置一个中间空节点保证左右两侧对齐。 fout << tab << root->index<<" -> "<< "_" << root->index << "[weight=10, group=" << root->index << ", style=invis]" << endl; fout << tab << "_" << root->index << "[style=invis]" << endl; // 同上,绘制右节点的内容。 if (root->right) { // 当右节点非空的时候,需要绘制一条伸向右节点的有向边。 fout << tab << root->index << " -> " << root->right->index << endl; // 递归遍历右子树。 preOrderDraw(root->right, fout, nullIndex); } else { // 右节点为空的时候,为了保证图形的整洁有序,绘制右侧空节点占位。边与节点都为不可见[style=invis]。 fout << tab << root->index << " -> _" << nullIndex << "[style=invis]" << endl; fout << tab << "_" << nullIndex++ << " [style=invis]" << endl; } }
下面介绍Python的graphviz语法。为了绘制同样一棵上面的树,我们只需要做这几行代码,即可生成一棵二叉树并展示出来。
// demo.py
from graphviz import Digraph
dot = Digraph(node_attr={'shape': 'circle'})
dot.node(1,"(7,2)")
dot.node(2,"(5,4)")
dot.node(3,"(2,3)")
dot.node(4,"(6,6)",style="invis")
dot.edge(1,2)
dot.edge(1,4)
dot.edge(1,3,style="invis")
dot.view()
据此,使用先根遍历的方式,同样根据二叉树的节点,绘制边和点即可。
import numpy as np from graphviz import Digraph from matplotlib import pyplot as plt from matplotlib.pyplot import MultipleLocator #data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]] data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]] data = np.array(data) # 节点 class node: def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True): self.data = _data self.left = _left self.right = _right self.father = _father self.dim = _dim self.index = _index self.visiable = _visiable def getData(self): s = "(" for i in range(self.data.size): if i!=0: s += ',' s+=str(self.data[i]) s += ")" return s def __str__(self): if(self.visiable): return str(self.index) else: return "_invis"+str(self.index) dataIndex = 1 def drawKDTree(data, depth, k, dot): # 根据数据生成KD树 dim = depth % k length = data.shape[0] if(length==0): return None, dot index = [] for i in range(length): index.append([data[i][dim], i]) index.sort() root = data[index[length//2][1]] left = [data[index[i][1]] for i in range(length//2)] left = np.array(left) right = [data[index[i][1]] for i in range(length//2+1, length)] right = np.array(right) global dataIndex root_node = node(_data=root, _dim=dim, _index=dataIndex) dataIndex+=1 dot.node(str(root_node.index), root_node.getData()) root_node.left, dot=drawKDTree(left, depth+1, k, dot) if(root_node.left==None): pass dot.node("_left"+str(root_node.index), root_node.getData(), style="invis") dot.edge(str(root_node.index), "_left"+str(root_node.index), style="invis") else: dot.edge(str(root_node.index), str(root_node.left.index)) dot.node("_middle"+str(root_node.index), root_node.getData(), style="invis") dot.edge(str(root_node.index), "_middle"+str(root_node.index), style="invis", weight="10") root_node.right, dot=drawKDTree(right, depth+1, k, dot) if(root_node.right==None): pass dot.node("_right"+str(root_node.index), root_node.getData(), style="invis") dot.edge(str(root_node.index), "_right"+str(root_node.index), style="invis") else: dot.edge(str(root_node.index), str(root_node.right.index)) if(root_node.left): root_node.left.father=root_node if(root_node.right): root_node.right.father=root_node return root_node, dot dot = Digraph(node_attr={'shape': 'circle'}) _, dot = drawKDTree(data, 0, 2, dot) dot.view() print(dot.source)
只需要通过pyplot在生成KD树的过程中,控制节点的维度以及左右边界,即可绘制分类的直线段;通过绘制scatter散点图,将点标记在图中;通过计算半径,绘制以待查询节点为圆心的圆形。特别注意的是,由于pyplot不支持深拷贝、也无法撤销某一步操作,因此想要在同一个背景下绘制不同的图形,只有自己设置一个函数以保证每次都可以同样调用生成同一块背景,并在该背景上绘制新的图形。这里的Python函数不包括数据的预处理、标签、投票等内容,仅仅是用于绘制图形而用的脚本内容。
import numpy as np from graphviz import Digraph from matplotlib import pyplot as plt from matplotlib.pyplot import MultipleLocator #data = [[2,3],[6, 4],[9, 6],[4, 7],[8, 1],[7, 2], [8,2], [10,4], [6,6]] data = [[7,2], [5,4], [9,6], [2,3], [4,7], [8,1]] data = np.array(data) # 节点 class node: def __init__(self, _data=None, _left=None, _right=None, _father=None, _dim=None, _index=None, _visiable=True): self.data = _data self.left = _left self.right = _right self.father = _father self.dim = _dim self.index = _index self.visiable = _visiable def getData(self): s = "(" for i in range(self.data.size): if i!=0: s += ',' s+=str(self.data[i]) s += ")" return s def __str__(self): if(self.visiable): return str(self.index) else: return "_invis"+str(self.index) # 生成KD树,并绘制一个完整的平面图形。 def createTree(data, depth, k, l, r, d, u): dim = depth % k length = data.shape[0] if(length==0): return None index = [] for i in range(length): index.append([data[i][dim], i]) index.sort() root = data[index[length//2][1]] left = [data[index[i][1]] for i in range(length//2)] left = np.array(left) right = [data[index[i][1]] for i in range(length//2+1, length)] right = np.array(right) root_node = node(_data=root, _dim=dim) if(dim == 0): plt.plot([root[0]]*(u-d+1), range(d, u+1)) root_node.left=createTree(left, depth+1, k, l, root[0], d, u) root_node.right=createTree(right, depth+1, k, root[0], r, d, u) if(root_node.left): root_node.left.father=root_node if(root_node.right): root_node.right.father=root_node else: plt.plot(range(l, r+1), [root[1]]*(r-l+1)) root_node.left=createTree(left, depth+1, k, l, r, d, root[1]) root_node.right=createTree(right, depth+1, k, l, r, root[1], u) if(root_node.left): root_node.left.father=root_node if(root_node.right): root_node.right.father=root_node return root_node # 绘制分类超平面 def drawOri(data): fig, ax = plt.subplots() fig.set_size_inches(5, 5) data = np.array(data) mmax = np.max(data)+1 mmin = np.min(data)-1 major_locator=MultipleLocator(1) plt.scatter(data[:,0], data[:,1]) plt.xlim(mmin, mmax) plt.ylim(mmin, mmax) ax = plt.gca() ax.xaxis.set_major_locator(major_locator) ax.yaxis.set_major_locator(major_locator) return createTree(data, 0, 2, mmin, mmax, mmin, mmax) # 绘制标记点及分类超平面 def drawPic(x, data): fig, ax = plt.subplots() fig.set_size_inches(5, 5) data = np.array(data) mmax = np.max(data)+1 mmin = np.min(data)-1 major_locator=MultipleLocator(1) plt.scatter(data[:,0], data[:,1]) plt.scatter([x[0]], [x[1]], marker='x') plt.xlim(mmin, mmax) plt.ylim(mmin, mmax) ax = plt.gca() ax.xaxis.set_major_locator(major_locator) ax.yaxis.set_major_locator(major_locator) return createTree(data, 0, 2, mmin, mmax, mmin, mmax) # 计算两点间的欧式距离 def distance(a, b): return ((a[0]-b[0])**2+(a[1]-b[1])**2)**0.5 # 寻找叶节点 def findLeaf(root, x, stack): if(root==None): return stack stack.append(root) if(x[root.dim]<=root.data[root.dim]): return findLeaf(root.left, x, stack) else: return findLeaf(root.right, x, stack) # 寻找最近邻节点,并绘制图形 def searchNearest(root, x, differt_pic=True, show=False): plt.scatter([x[0]], [x[1]], marker='x') stack = [] stack = findLeaf(root, x, stack) nearN = stack[-1] minD = distance(stack[-1].data, x) visted = set() path = 1 while(stack): top = stack[-1] visted.add(top) stack.pop() dis = distance(top.data, x) if(dis < minD): minD = dis nearN = top if show: plt.show() if differt_pic: # 重新绘制一张底图。 drawPic(x, data) ax = plt.gca() ax.scatter(top.data[0], top.data[1], marker='x', s=200) ax.plot([x[0], nearN.data[0]], [x[1], nearN.data[1]]) theta = np.arange(0, 2*np.pi, 0.01) xx = x[0] + minD * np.cos(theta) yy = x[1] + minD * np.sin(theta) plt.plot(xx, yy) plt.savefig("{0}.png".format(path)) path += 1 left = x[top.dim] - minD right = x[top.dim] + minD if(left <= top.data[top.dim] and top.left != None and top.left not in visted): stack.append(top.left) if(right >= top.data[top.dim] and top.right != None and top.right not in visted): stack.append(top.right) return nearN x = [4, 3] drawOri(data) plt.savefig("0.png") root = drawPic(x, data, differt_pic=True, show=False) searchNearest(root, x)
绘制图形展示如下:
至此,《统计学习方法》第三章的全部内容都更新完毕,在我的Gtihub中有详细代码,欢迎查阅。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。