赞
踩
论文主要内容是提出了用于图神经网络的 U-Net 结构。论文的出发点很简单,鉴于 UNet 的 encoder-decoder 结构在图像领域取得的成功,我们能不能在 GNN 中也模仿 U-Net 设计一种 Graph U-Net 结构呢?因为 U-Net 是先做降采样(Encoder)得到从低层次到高层次的图像特征,然后再做上采样 (Decoder)将 Encoder 计算的特征融合起来得到新的特征,所以如果要在 GNN 中构造 UNet 结构,就需要解决如何进行降采样和上采样的问题。论文给出的解决方法是 graph pooling (gPool) 和 graph unpooling (gUnpool). gPool 通过学习的方式,自动地选择需要保留图中哪些节点以构建较小地图. gUnpool 是 gPool 的逆操作,它根据 gPool 记录的节点编号,直接将高层次小图的节点特征赋予到低层次大图的节点上。
论文的主要贡献为:
论文的卷积计算是基于 GCN的,GCN 计算公式如下:
X
l
+
1
=
σ
(
D
^
−
1
2
A
^
D
−
1
2
X
l
W
l
)
X_{l+1} = \sigma (\hat D^{-\frac 12} \hat A D^{-\frac 12} X_{l} W_{l})
Xl+1=σ(D^−21A^D−21XlWl)
其中
A
^
=
A
+
I
\hat A = A + I
A^=A+I,
A
A
A 为邻接矩阵,
D
^
\hat D
D^ 为节点度矩阵,
X
l
X_l
Xl 为节点特征,
W
l
W_l
Wl 为需要学习的权重参数。
为了对图数据做降采样,论文提出利用投影向量 p \mathbf p p 来计算图中各个节点的重要性,并且 p \mathbf p p 是可训练的参数,不需要认为指定。假设节点的嵌入向量为 x i \mathbf x_i xi , 它在向量 p \mathbf p p 上的投影为 y i = x i p / ∥ p ∥ y_i = \mathbf x_i \mathbf p / \|\mathbf p\| yi=xip/∥p∥ , y i y_i yi 可以作为衡量节点重要性的度量。在 gPool 中,计算过程为
gPool 示意图如下
gUnPool 就是 gPool 的逆操作。假设 gPool 前节点嵌入和邻接矩阵分别为
X
(
l
′
)
X^{(l')}
X(l′),邻接矩阵为
A
(
l
′
)
A^{(l')}
A(l′)经过池化后选取的节点编号为 idx, 节点嵌入为
X
(
l
)
X^{(l)}
X(l),gUnpool 得到的节点嵌入和邻接矩阵分别是为
X
(
l
+
1
)
,
A
(
l
+
1
)
X^{(l+1)}, A^{(l+1)}
X(l+1),A(l+1),那么
X
(
l
+
1
)
=
distribute
(
0
N
×
C
,
X
(
l
)
,
idx
)
,
A
(
l
+
1
)
=
A
(
l
′
)
X^{(l+1)} = \text{distribute}(0_{N \times C}, X^{(l)}, \text{idx}),A^{(l+1)} = A^{(l')}
X(l+1)=distribute(0N×C,X(l),idx),A(l+1)=A(l′)
其中
X
(
l
+
1
)
(
idx
)
=
X
(
l
)
X^{(l+1)}(\text{idx}) = X^{(l)}
X(l+1)(idx)=X(l). gUnpool 示意图如下
Graph U-Net 和常规的 U-Net 在整体上结构基本相同,只是把 Conv2d, maxpool, TransposedConv2d 分别换成 GCN, gPool, gUnPool。通常 GNN 的深度不宜太大,所以论文中实验的 g-UNet 最深不过5层。g-UNet 示意图如下
在 gPool 中,移除部分节点后,其边也会被移除,下采样后图的连接变得稀疏,也有可能会产生部分孤立节点。论文中采用了连接扩展的方法,在池化的之前会将节点
k
k
k 跳范围内的节点都连接起来,用邻接矩阵表示为
A
~
(
l
)
=
A
k
,
A
(
l
+
1
)
=
A
~
(
l
)
(
idx
,
idx
)
\tilde A^{(l)} = A^k, \quad A^{(l+1)} = \tilde A^{(l)}(\text{idx}, \text{idx})
A~(l)=Ak,A(l+1)=A~(l)(idx,idx)
论文中推荐
k
=
2
k = 2
k=2,也就是将边扩展到节点的 2 跳范围内。
在 GCN 中邻接矩阵 A ^ = A + I \hat A = A + I A^=A+I , 论文中改为 A ^ = A + 2 I \hat A = A + 2I A^=A+2I ,给予中心节点更大的权重。在实验中,文章发现这样做效果更好。不过这得具体问题具体分析,效果好不好,就见仁见智。
实验做了几个常用数据集的节点分类分析,包括传导学习 (transductive learning) 和归纳学习 (inductive learning experiments). 结果如下
与其他方法比起来当然是要好一些,使用了 gPool 和 gUnpool 的结果比没有用的要稍微好一点,但是不明显,说明还有改进的余地。
代码实现可参考 pytorch_geometric 中的 GraphUNet 的 代码. 在 PyG 中 gPool 是通过 TopKPooling 实现的。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。