当前位置:   article > 正文

Graph U-Nets 笔记_graphunet

graphunet

Graph U-Nets 笔记

1. 引言

论文主要内容是提出了用于图神经网络的 U-Net 结构。论文的出发点很简单,鉴于 UNet 的 encoder-decoder 结构在图像领域取得的成功,我们能不能在 GNN 中也模仿 U-Net 设计一种 Graph U-Net 结构呢?因为 U-Net 是先做降采样(Encoder)得到从低层次到高层次的图像特征,然后再做上采样 (Decoder)将 Encoder 计算的特征融合起来得到新的特征,所以如果要在 GNN 中构造 UNet 结构,就需要解决如何进行降采样和上采样的问题。论文给出的解决方法是 graph pooling (gPool) 和 graph unpooling (gUnpool). gPool 通过学习的方式,自动地选择需要保留图中哪些节点以构建较小地图. gUnpool 是 gPool 的逆操作,它根据 gPool 记录的节点编号,直接将高层次小图的节点特征赋予到低层次大图的节点上。

论文的主要贡献为:

  • 提出的 graph pooling (gPool) 和 graph unpooling (gUnpool) 两种算子
  • 基于 gPool 和 gUnpool,将 U-Net 结构应用到图上得到 Graph U-Nets

2. 方法

论文的卷积计算是基于 GCN的,GCN 计算公式如下:
X l + 1 = σ ( D ^ − 1 2 A ^ D − 1 2 X l W l ) X_{l+1} = \sigma (\hat D^{-\frac 12} \hat A D^{-\frac 12} X_{l} W_{l}) Xl+1=σ(D^21A^D21XlWl)
其中 A ^ = A + I \hat A = A + I A^=A+I, A A A 为邻接矩阵, D ^ \hat D D^ 为节点度矩阵, X l X_l Xl 为节点特征, W l W_l Wl 为需要学习的权重参数。

2.1 Graph Pooling Layer

为了对图数据做降采样,论文提出利用投影向量 p \mathbf p p 来计算图中各个节点的重要性,并且 p \mathbf p p 是可训练的参数,不需要认为指定。假设节点的嵌入向量为 x i \mathbf x_i xi , 它在向量 p \mathbf p p 上的投影为 y i = x i p / ∥ p ∥ y_i = \mathbf x_i \mathbf p / \|\mathbf p\| yi=xip/p y i y_i yi 可以作为衡量节点重要性的度量。在 gPool 中,计算过程为

  • 根据节点特征   X ( l ) \ X^{(l)}  X(l) 计算每个节点在投影向量上的投影:$\mathbf y = X^{(l)} \mathbf p^{(l)}/|\mathbf p | $
  • 将投影 y \mathbf y y 进行排序,取最大的 k k k 个节点: idx = rank ( y , k ) \text{idx} = \text{rank} (\mathbf y, k) idx=rank(y,k)
  • 计算特征投影系数: y ~ = sigmoid ( y ( idx ) ) \tilde {\mathbf y} = \text{sigmoid} (y(\text{idx})) y~=sigmoid(y(idx))
  • 根据 idx 抽取 k k k 个节点: X ~ ( l ) = X ( l ) ( idx , : ) \tilde { X}^{(l)} = X^{(l)}(\text{idx}, :) X~(l)=X(l)(idx,:)
  • 根据 idx 抽取邻接矩阵: A ( l + 1 ) = A ( l ) ( idx , idx ) A^{(l+1)} = A^{(l)}(\text{idx}, \text{idx}) A(l+1)=A(l)(idx,idx)
  • X ~ ( l ) \tilde { X}^{(l)} X~(l) 中每个节点嵌入向量分别乘以对应的投影系数 y ~ \tilde {\mathbf y} y~ ,得到 l + 1 l+1 l+1 层的节点嵌入: X ( l + 1 ) = X ~ ( l ) ⊙ ( y ~ 1 C T ) X^{(l+1)} = \tilde X^{(l)} \odot (\tilde{\mathbf y} \mathbf 1^T_C) X(l+1)=X~(l)(y~1CT)

gPool 示意图如下

2.2 Graph Unpooling Layer

gUnPool 就是 gPool 的逆操作。假设 gPool 前节点嵌入和邻接矩阵分别为 X ( l ′ ) X^{(l')} X(l),邻接矩阵为 A ( l ′ ) A^{(l')} A(l)经过池化后选取的节点编号为 idx, 节点嵌入为 X ( l ) X^{(l)} X(l),gUnpool 得到的节点嵌入和邻接矩阵分别是为 X ( l + 1 ) , A ( l + 1 ) X^{(l+1)}, A^{(l+1)} X(l+1),A(l+1),那么
X ( l + 1 ) = distribute ( 0 N × C , X ( l ) , idx ) , A ( l + 1 ) = A ( l ′ ) X^{(l+1)} = \text{distribute}(0_{N \times C}, X^{(l)}, \text{idx}),A^{(l+1)} = A^{(l')} X(l+1)=distribute(0N×C,X(l),idx)A(l+1)=A(l)
其中 X ( l + 1 ) ( idx ) = X ( l ) X^{(l+1)}(\text{idx}) = X^{(l)} X(l+1)(idx)=X(l). gUnpool 示意图如下

2.3 Graph U-Nets 的结构

Graph U-Net 和常规的 U-Net 在整体上结构基本相同,只是把 Conv2d, maxpool, TransposedConv2d 分别换成 GCN, gPool, gUnPool。通常 GNN 的深度不宜太大,所以论文中实验的 g-UNet 最深不过5层。g-UNet 示意图如下

2.4 图连接扩展 (Graph Connectivity Augmentation via Graph Power )

在 gPool 中,移除部分节点后,其边也会被移除,下采样后图的连接变得稀疏,也有可能会产生部分孤立节点。论文中采用了连接扩展的方法,在池化的之前会将节点 k k k 跳范围内的节点都连接起来,用邻接矩阵表示为
A ~ ( l ) = A k , A ( l + 1 ) = A ~ ( l ) ( idx , idx ) \tilde A^{(l)} = A^k, \quad A^{(l+1)} = \tilde A^{(l)}(\text{idx}, \text{idx}) A~(l)=Ak,A(l+1)=A~(l)(idx,idx)
论文中推荐 k = 2 k = 2 k=2,也就是将边扩展到节点的 2 跳范围内。

2.5 改进 GCN

在 GCN 中邻接矩阵 A ^ = A + I \hat A = A + I A^=A+I , 论文中改为 A ^ = A + 2 I \hat A = A + 2I A^=A+2I ,给予中心节点更大的权重。在实验中,文章发现这样做效果更好。不过这得具体问题具体分析,效果好不好,就见仁见智。

3. 实验

实验做了几个常用数据集的节点分类分析,包括传导学习 (transductive learning) 和归纳学习 (inductive learning experiments). 结果如下

与其他方法比起来当然是要好一些,使用了 gPool 和 gUnpool 的结果比没有用的要稍微好一点,但是不明显,说明还有改进的余地。

4. 代码实现

代码实现可参考 pytorch_geometric 中的 GraphUNet代码. 在 PyG 中 gPool 是通过 TopKPooling 实现的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/85392
推荐阅读
相关标签
  

闽ICP备14008679号